diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 10219b2b2d..0725bcf890 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -86,7 +86,7 @@ public final class CsrfFilter extends OncePerRequestFilter { private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); - private String csrfRequestAttributeName; + private String csrfRequestAttributeName = "_csrf"; public CsrfFilter(CsrfTokenRepository csrfTokenRepository) { Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null"); diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index f5789804f2..4844b2d6b5 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -75,6 +75,8 @@ public class CsrfFilterTests { private CsrfToken token; + private String csrfAttrName = "_csrf"; + private CsrfFilter filter; @BeforeEach @@ -108,7 +110,7 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(false); given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - CsrfToken attrToken = (CsrfToken) this.request.getAttribute(this.token.getParameterName()); + CsrfToken attrToken = (CsrfToken) this.request.getAttribute(this.csrfAttrName); // no CsrfToken should have been saved yet verify(this.tokenRepository, times(0)).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class)); @@ -125,7 +127,7 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); @@ -137,7 +139,7 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); @@ -149,7 +151,7 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); @@ -163,7 +165,7 @@ public class CsrfFilterTests { this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); @@ -174,7 +176,7 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(false); given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); @@ -185,7 +187,7 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(false); given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); @@ -197,7 +199,7 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); @@ -211,7 +213,7 @@ public class CsrfFilterTests { this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); @@ -223,7 +225,7 @@ public class CsrfFilterTests { given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); @@ -237,7 +239,7 @@ public class CsrfFilterTests { given(this.tokenRepository.generateToken(this.request)).willReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); // LazyCsrfTokenRepository requires the response as an attribute assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); @@ -303,7 +305,7 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadToken(this.request)).willReturn(this.token); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); verifyNoMoreInteractions(this.filterChain);