diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java index 5e5eb58a95..ca6ab9314c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java @@ -17,7 +17,10 @@ package org.springframework.security.config.annotation.web.configurers; import java.net.URI; +import java.util.Arrays; +import java.util.List; +import jakarta.servlet.http.Cookie; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; @@ -27,6 +30,7 @@ import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.mock.web.MockHttpSession; import org.springframework.security.config.Customizer; @@ -42,6 +46,7 @@ import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.access.AccessDeniedHandler; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; +import org.springframework.security.web.csrf.CookieCsrfTokenRepository; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.CsrfTokenRepository; import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler; @@ -308,6 +313,7 @@ public class CsrfConfigurerTests { public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception { CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class); DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); + given(CsrfTokenRepositoryConfig.REPO.loadToken(any())).willReturn(csrfToken); given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken)); this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire(); @@ -318,6 +324,7 @@ public class CsrfConfigurerTests { .param("password", "password"); // @formatter:on this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); + verify(CsrfTokenRepositoryConfig.REPO).loadToken(any(HttpServletRequest.class)); verify(CsrfTokenRepositoryConfig.REPO).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class)); } @@ -449,6 +456,7 @@ public class CsrfConfigurerTests { public void loginWhenCsrfTokenRequestAttributeHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception { CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); + given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(csrfToken); given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))) .willReturn(new TestDeferredCsrfToken(csrfToken)); CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository; @@ -462,6 +470,7 @@ public class CsrfConfigurerTests { .param("password", "password"); // @formatter:on this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); + verify(csrfTokenRepository).loadToken(any(HttpServletRequest.class)); verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class)); verify(csrfTokenRepository, times(2)).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)); @@ -487,6 +496,7 @@ public class CsrfConfigurerTests { public void loginWhenXorCsrfTokenRequestAttributeHandlerSetAndMaskedCsrfTokenThenSuccess() throws Exception { CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"); CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class); + given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(csrfToken); given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class))) .willReturn(new TestDeferredCsrfToken(csrfToken)); CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository; @@ -503,12 +513,93 @@ public class CsrfConfigurerTests { .param("password", "password"); // @formatter:on this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")); + verify(csrfTokenRepository).loadToken(any(HttpServletRequest.class)); verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class)); verify(csrfTokenRepository, times(3)).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)); verifyNoMoreInteractions(csrfTokenRepository); } + @Test + public void loginWhenFormLoginAndCookieCsrfTokenRepositorySetAndExistingTokenThenRemovesAndGeneratesNewToken() + throws Exception { + CsrfToken csrfToken = new DefaultCsrfToken("X-XSRF-TOKEN", "_csrf", "token"); + Cookie existingCookie = new Cookie("XSRF-TOKEN", csrfToken.getToken()); + CookieCsrfTokenRepository csrfTokenRepository = CookieCsrfTokenRepository.withHttpOnlyFalse(); + csrfTokenRepository.setCookieName(existingCookie.getName()); + CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository; + CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler(); + this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire(); + + // @formatter:off + MockHttpServletRequestBuilder loginRequest = post("/login") + .cookie(existingCookie) + .header(csrfToken.getHeaderName(), csrfToken.getToken()) + .param("username", "user") + .param("password", "password"); + // @formatter:on + MvcResult mvcResult = this.mvc.perform(loginRequest).andExpect(redirectedUrl("/")).andReturn(); + List cookies = Arrays.asList(mvcResult.getResponse().getCookies()); + cookies.removeIf((cookie) -> !cookie.getName().equalsIgnoreCase(existingCookie.getName())); + assertThat(cookies).hasSize(2); + assertThat(cookies.get(0).getValue()).isEmpty(); + assertThat(cookies.get(1).getValue()).isNotEmpty(); + } + + @Test + public void postWhenHttpBasicAndCookieCsrfTokenRepositorySetAndExistingTokenThenRemovesAndGeneratesNewToken() + throws Exception { + CsrfToken csrfToken = new DefaultCsrfToken("X-XSRF-TOKEN", "_csrf", "token"); + Cookie existingCookie = new Cookie("XSRF-TOKEN", csrfToken.getToken()); + CookieCsrfTokenRepository csrfTokenRepository = CookieCsrfTokenRepository.withHttpOnlyFalse(); + csrfTokenRepository.setCookieName(existingCookie.getName()); + HttpBasicCsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository; + HttpBasicCsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler(); + this.spring.register(HttpBasicCsrfTokenRequestHandlerConfig.class, BasicController.class).autowire(); + + HttpHeaders headers = new HttpHeaders(); + headers.set(csrfToken.getHeaderName(), csrfToken.getToken()); + headers.setBasicAuth("user", "password"); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(post("/") + .cookie(existingCookie) + .headers(headers)) + .andExpect(status().isOk()) + .andReturn(); + // @formatter:on + List cookies = Arrays.asList(mvcResult.getResponse().getCookies()); + cookies.removeIf((cookie) -> !cookie.getName().equalsIgnoreCase(existingCookie.getName())); + assertThat(cookies).hasSize(2); + assertThat(cookies.get(0).getValue()).isEmpty(); + assertThat(cookies.get(1).getValue()).isNotEmpty(); + } + + @Test + public void getWhenHttpBasicAndCookieCsrfTokenRepositorySetAndNoExistingCookieThenGeneratesNewToken() + throws Exception { + CsrfToken csrfToken = new DefaultCsrfToken("X-XSRF-TOKEN", "_csrf", "token"); + Cookie expectedCookie = new Cookie("XSRF-TOKEN", csrfToken.getToken()); + CookieCsrfTokenRepository csrfTokenRepository = CookieCsrfTokenRepository.withHttpOnlyFalse(); + csrfTokenRepository.setCookieName(expectedCookie.getName()); + HttpBasicCsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository; + HttpBasicCsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler(); + this.spring.register(HttpBasicCsrfTokenRequestHandlerConfig.class, BasicController.class).autowire(); + + HttpHeaders headers = new HttpHeaders(); + headers.set(csrfToken.getHeaderName(), csrfToken.getToken()); + headers.setBasicAuth("user", "password"); + // @formatter:off + MvcResult mvcResult = this.mvc.perform(get("/") + .headers(headers)) + .andExpect(status().isOk()) + .andReturn(); + // @formatter:on + List cookies = Arrays.asList(mvcResult.getResponse().getCookies()); + cookies.removeIf((cookie) -> !cookie.getName().equalsIgnoreCase(expectedCookie.getName())); + assertThat(cookies).hasSize(1); + assertThat(cookies.get(0).getValue()).isNotEmpty(); + } + @Configuration static class AllowHttpMethodsFirewallConfig { @@ -902,6 +993,42 @@ public class CsrfConfigurerTests { } + @Configuration + @EnableWebSecurity + static class HttpBasicCsrfTokenRequestHandlerConfig { + + static CsrfTokenRepository REPO; + + static CsrfTokenRequestHandler HANDLER; + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeHttpRequests((authorize) -> authorize + .anyRequest().authenticated() + ) + .httpBasic(Customizer.withDefaults()) + .csrf((csrf) -> csrf + .csrfTokenRepository(REPO) + .csrfTokenRequestHandler(HANDLER) + ); + // @formatter:on + + return http.build(); + } + + @Autowired + void configure(AuthenticationManagerBuilder auth) throws Exception { + // @formatter:off + auth + .inMemoryAuthentication() + .withUser(PasswordEncodedUser.user()); + // @formatter:on + } + + } + @RestController static class BasicController { diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java index ebfaaa64aa..87862d56a2 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java @@ -32,6 +32,7 @@ import org.springframework.util.Assert; * the next request. * * @author Rob Winch + * @author Steve Riesenberg * @since 3.2 */ public final class CsrfAuthenticationStrategy implements SessionAuthenticationStrategy { @@ -64,10 +65,13 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt @Override public void onAuthentication(Authentication authentication, HttpServletRequest request, HttpServletResponse response) throws SessionAuthenticationException { - this.tokenRepository.saveToken(null, request, response); - DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response); - this.requestHandler.handle(request, response, deferredCsrfToken::get); - this.logger.debug("Replaced CSRF Token"); + boolean containsToken = this.tokenRepository.loadToken(request) != null; + if (containsToken) { + this.tokenRepository.saveToken(null, request, response); + DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response); + this.requestHandler.handle(request, response, deferredCsrfToken::get); + this.logger.debug("Replaced CSRF Token"); + } } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java index df3497486e..94bae781b3 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java @@ -35,6 +35,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -81,6 +82,7 @@ public class CsrfAuthenticationStrategyTests { @Test public void onAuthenticationWhenCustomRequestHandlerThenUsed() { + given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken); given(this.csrfTokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.existingToken, false)); @@ -88,16 +90,20 @@ public class CsrfAuthenticationStrategyTests { this.strategy.setRequestHandler(requestHandler); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); + verify(this.csrfTokenRepository).loadToken(this.request); + verify(this.csrfTokenRepository).loadDeferredToken(this.request, this.response); verify(requestHandler).handle(eq(this.request), eq(this.response), any()); verifyNoMoreInteractions(requestHandler); } @Test public void logoutRemovesCsrfTokenAndLoadsNewDeferredCsrfToken() { + given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken); given(this.csrfTokenRepository.loadDeferredToken(this.request, this.response)) .willReturn(new TestDeferredCsrfToken(this.generatedToken, false)); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); + verify(this.csrfTokenRepository).loadToken(this.request); verify(this.csrfTokenRepository).saveToken(null, this.request, this.response); verify(this.csrfTokenRepository).loadDeferredToken(this.request, this.response); // SEC-2404, SEC-2832 @@ -112,6 +118,7 @@ public class CsrfAuthenticationStrategyTests { @Test public void delaySavingCsrf() { this.strategy = new CsrfAuthenticationStrategy(new LazyCsrfTokenRepository(this.csrfTokenRepository)); + given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken, (CsrfToken) null); given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken); this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request, this.response); @@ -120,7 +127,7 @@ public class CsrfAuthenticationStrategyTests { any(HttpServletResponse.class)); CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName()); tokenInRequest.getToken(); - verify(this.csrfTokenRepository).loadToken(this.request); + verify(this.csrfTokenRepository, times(2)).loadToken(this.request); verify(this.csrfTokenRepository).generateToken(this.request); verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class));