From 2a2051cd7b326167a8d060c7f655cd8a0878d088 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Tue, 11 Oct 2022 14:24:10 -0500 Subject: [PATCH] Default to Xor CSRF tokens in CsrfFilter Issue gh-11960 --- .../web/configurers/DefaultFiltersTests.java | 11 ++- .../DefaultLoginPageConfigurerTests.java | 36 +++++--- ...ionManagementConfigurerServlet31Tests.java | 9 +- ...MockMvcRequestPostProcessorsCsrfTests.java | 10 ++- .../security/web/csrf/CsrfFilter.java | 2 +- .../security/web/csrf/CsrfFilterTests.java | 84 +++++++++++-------- 6 files changed, 96 insertions(+), 56 deletions(-) diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java index c1a074af76..bf94a98ee3 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultFiltersTests.java @@ -51,8 +51,11 @@ import org.springframework.security.web.context.SecurityContextHolderFilter; import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.CsrfTokenRepository; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; import org.springframework.security.web.csrf.DefaultCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; +import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler; import org.springframework.security.web.header.HeaderWriterFilter; import org.springframework.security.web.savedrequest.RequestCacheAwareFilter; import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter; @@ -121,8 +124,12 @@ public class DefaultFiltersTests { MockHttpServletRequest request = new MockHttpServletRequest("POST", ""); request.setServletPath("/logout"); CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "BaseSpringSpec_CSRFTOKEN"); - new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request, response); - request.setParameter(csrfToken.getParameterName(), csrfToken.getToken()); + CsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); + repository.saveToken(csrfToken, request, response); + CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); + handler.handle(request, response, () -> csrfToken); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); + request.setParameter(token.getParameterName(), token.getToken()); this.spring.getContext().getBean("springSecurityFilterChain", Filter.class).doFilter(request, response, new MockFilterChain()); assertThat(response.getRedirectedUrl()).isEqualTo("/login?logout"); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java index 75746f223e..e906f89f47 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/DefaultLoginPageConfigurerTests.java @@ -85,7 +85,9 @@ public class DefaultLoginPageConfigurerTests { String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); // @formatter:off this.mvc.perform(get("/login").sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string("\n" + .andExpect((result) -> { + CsrfToken token = (CsrfToken) result.getRequest().getAttribute(CsrfToken.class.getName()); + assertThat(result.getResponse().getContentAsString()).isEqualTo("\n" + "\n" + " \n" + " \n" @@ -108,11 +110,12 @@ public class DefaultLoginPageConfigurerTests { + " \n" + " \n" + "

\n" - + "\n" + + "\n" + " \n" + " \n" + "\n" - + "")); + + ""); + }); // @formatter:on } @@ -131,7 +134,9 @@ public class DefaultLoginPageConfigurerTests { // @formatter:off this.mvc.perform(get("/login?error").session((MockHttpSession) mvcResult.getRequest().getSession()) .sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string("\n" + .andExpect((result) -> { + CsrfToken token = (CsrfToken) result.getRequest().getAttribute(CsrfToken.class.getName()); + assertThat(result.getResponse().getContentAsString()).isEqualTo("\n" + "\n" + " \n" + " \n" @@ -153,11 +158,12 @@ public class DefaultLoginPageConfigurerTests { + " \n" + " \n" + "

\n" - + "\n" + + "\n" + " \n" + " \n" + "\n" - + "")); + + ""); + }); // @formatter:on } @@ -180,7 +186,9 @@ public class DefaultLoginPageConfigurerTests { String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); // @formatter:off this.mvc.perform(get("/login?logout").sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string("\n" + .andExpect((result) -> { + CsrfToken token = (CsrfToken) result.getRequest().getAttribute(CsrfToken.class.getName()); + assertThat(result.getResponse().getContentAsString()).isEqualTo("\n" + "\n" + " \n" + " \n" @@ -203,11 +211,12 @@ public class DefaultLoginPageConfigurerTests { + " \n" + " \n" + "

\n" - + "\n" + + "\n" + " \n" + " \n" + "\n" - + "")); + + ""); + }); // @formatter:on } @@ -230,7 +239,9 @@ public class DefaultLoginPageConfigurerTests { String csrfAttributeName = HttpSessionCsrfTokenRepository.class.getName().concat(".CSRF_TOKEN"); // @formatter:off this.mvc.perform(get("/login").sessionAttr(csrfAttributeName, csrfToken)) - .andExpect(content().string("\n" + .andExpect((result) -> { + CsrfToken token = (CsrfToken) result.getRequest().getAttribute(CsrfToken.class.getName()); + assertThat(result.getResponse().getContentAsString()).isEqualTo("\n" + "\n" + " \n" + " \n" @@ -254,11 +265,12 @@ public class DefaultLoginPageConfigurerTests { + " \n" + "

\n" + "

Remember me on this computer.

\n" - + "\n" + + "\n" + " \n" + " \n" + "\n" - + "")); + + ""); + }); // @formatter:on } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java index 241a45308b..b17ce7635a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java @@ -39,7 +39,10 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; +import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler; import static org.assertj.core.api.Assertions.assertThat; @@ -82,8 +85,10 @@ public class SessionManagementConfigurerServlet31Tests { request.setParameter("username", "user"); request.setParameter("password", "password"); HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository(); - CsrfToken token = repository.generateToken(request); - repository.saveToken(token, request, this.response); + CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); + DeferredCsrfToken deferredCsrfToken = repository.loadDeferredToken(request, this.response); + handler.handle(request, this.response, deferredCsrfToken::get); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); request.setParameter(token.getParameterName(), token.getToken()); request.getSession().setAttribute("attribute1", "value1"); loadConfig(SessionManagementDefaultSessionFixationServlet31Config.class); diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java index 8ce4959ef5..b5d0742a93 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java @@ -40,7 +40,10 @@ import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequ import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.CsrfTokenRequestHandler; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository; +import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.context.web.WebAppConfiguration; @@ -157,9 +160,12 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests { // @formatter:off this.mockMvc.perform(post("/").with(csrf())); MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); HttpSessionCsrfTokenRepository repo = new HttpSessionCsrfTokenRepository(); - CsrfToken token = repo.generateToken(request); - repo.saveToken(token, request, new MockHttpServletResponse()); + CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); + DeferredCsrfToken deferredCsrfToken = repo.loadDeferredToken(request, response); + handler.handle(request, response, deferredCsrfToken::get); + CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); MockHttpServletRequestBuilder requestWithCsrf = post("/") .param(token.getParameterName(), token.getToken()) .session((MockHttpSession) request.getSession()); diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 890a7cf7ba..c3294c93e7 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -87,7 +87,7 @@ public final class CsrfFilter extends OncePerRequestFilter { private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl(); - private CsrfTokenRequestHandler requestHandler = new CsrfTokenRequestAttributeHandler(); + private CsrfTokenRequestHandler requestHandler = new XorCsrfTokenRequestAttributeHandler(); /** * Creates a new instance. diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index 0f4f0c8e80..c57e1bfc7d 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -130,8 +130,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -143,8 +143,8 @@ public class CsrfFilterTests { .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -156,8 +156,8 @@ public class CsrfFilterTests { .willReturn(new TestDeferredCsrfToken(this.token, false)); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -168,11 +168,14 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); - this.request.setParameter(this.token.getParameterName(), this.token.getToken()); - this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); + CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); + handler.handle(this.request, this.response, () -> this.token); + CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); + this.request.setParameter(csrfToken.getParameterName(), csrfToken.getToken()); + this.request.addHeader(csrfToken.getHeaderName(), csrfToken.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -183,8 +186,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -195,8 +198,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, true)); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -206,10 +209,13 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); - this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); + CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); + handler.handle(this.request, this.response, () -> this.token); + CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); + this.request.addHeader(csrfToken.getHeaderName(), csrfToken.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -220,11 +226,14 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); - this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); - this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); + CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); + handler.handle(this.request, this.response, () -> this.token); + CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); + this.request.setParameter(csrfToken.getParameterName(), csrfToken.getToken() + " INVALID"); + this.request.addHeader(csrfToken.getHeaderName(), csrfToken.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -234,10 +243,13 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); - this.request.setParameter(this.token.getParameterName(), this.token.getToken()); + CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); + handler.handle(this.request, this.response, () -> this.token); + CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); + this.request.setParameter(csrfToken.getParameterName(), csrfToken.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), @@ -249,10 +261,13 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, true)); - this.request.setParameter(this.token.getParameterName(), this.token.getToken()); + CsrfTokenRequestHandler handler = new XorCsrfTokenRequestAttributeHandler(); + handler.handle(this.request, this.response, () -> this.token); + CsrfToken csrfToken = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); + this.request.setParameter(csrfToken.getParameterName(), csrfToken.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); // LazyCsrfTokenRepository requires the response as an attribute assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); verify(this.filterChain).doFilter(this.request, this.response); @@ -320,8 +335,8 @@ public class CsrfFilterTests { given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); this.filter.doFilter(this.request, this.response, this.filterChain); - assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token); - assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isNotNull(); + assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); verifyNoMoreInteractions(this.filterChain); } @@ -371,12 +386,9 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(false); given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); - XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler(); - requestHandler.setCsrfRequestAttributeName(this.token.getParameterName()); - this.filter.setRequestHandler(requestHandler); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); - assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull(); + assertThat(this.request.getAttribute("_csrf")).isNotNull(); verify(this.filterChain).doFilter(this.request, this.response); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); @@ -397,8 +409,6 @@ public class CsrfFilterTests { given(this.requestMatcher.matches(this.request)).willReturn(true); given(this.tokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.token, false)); - XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler(); - this.filter.setRequestHandler(requestHandler); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class)); @@ -421,7 +431,7 @@ public class CsrfFilterTests { throws ServletException, IOException { CsrfFilter filter = createCsrfFilter(this.tokenRepository); String csrfAttrName = "_csrf"; - CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler(); + CsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler(); requestHandler.setCsrfRequestAttributeName(csrfAttrName); filter.setRequestHandler(requestHandler); CsrfToken expectedCsrfToken = mock(CsrfToken.class); @@ -432,7 +442,7 @@ public class CsrfFilterTests { verifyNoInteractions(expectedCsrfToken); CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName); - assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken); + assertThatCsrfToken(tokenFromRequest).isNotNull(); } }