From 08d2c93713caac96217e72caa13009490fda9bb5 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 26 Sep 2019 17:18:47 -0400 Subject: [PATCH] Polish gh-7466 --- .../config/web/server/ServerHttpSecurity.java | 17 ++-- .../config/web/server/OAuth2LoginTests.java | 85 ++++++++++++++++--- .../web/server/ServerHttpSecurityTests.java | 82 +----------------- 3 files changed, 85 insertions(+), 99 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 46abc0e7bc..2a1f43cb2a 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -76,7 +76,6 @@ import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserSer import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter; import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; -import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; @@ -1106,13 +1105,14 @@ public class ServerHttpSecurity { } /** - * Sets authorization request repository for {@link OAuth2AuthorizationRequestRedirectWebFilter}. + * Sets the repository to use for storing {@link OAuth2AuthorizationRequest}'s. * - * @param authorizationRequestRepository authorization request repository, must not be null + * @since 5.2 + * @param authorizationRequestRepository the repository to use for storing {@link OAuth2AuthorizationRequest}'s * @return the {@link OAuth2LoginSpec} for further configuration */ - public OAuth2LoginSpec authorizationRequestRepository(ServerAuthorizationRequestRepository authorizationRequestRepository) { - Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null"); + public OAuth2LoginSpec authorizationRequestRepository( + ServerAuthorizationRequestRepository authorizationRequestRepository) { this.authorizationRequestRepository = authorizationRequestRepository; return this; } @@ -1163,9 +1163,7 @@ public class ServerHttpSecurity { OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter(); ServerAuthorizationRequestRepository authorizationRequestRepository = getAuthorizationRequestRepository(); - if (authorizationRequestRepository != null) { - oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository); - } + oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository); oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); ReactiveAuthenticationManager manager = getAuthenticationManager(); @@ -1267,10 +1265,9 @@ public class ServerHttpSecurity { return result; } - @SuppressWarnings("unchecked") private ServerAuthorizationRequestRepository getAuthorizationRequestRepository() { if (this.authorizationRequestRepository == null) { - this.authorizationRequestRepository = getBeanOrNull(ServerAuthorizationRequestRepository.class); + this.authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); } return this.authorizationRequestRepository; } diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index 96f1113daf..046675934b 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -16,16 +16,10 @@ package org.springframework.security.config.web.server; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - import org.junit.Rule; import org.junit.Test; import org.mockito.stubbing.Answer; import org.openqa.selenium.WebDriver; -import reactor.core.publisher.Mono; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; @@ -41,6 +35,8 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; @@ -53,7 +49,9 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -84,20 +82,25 @@ import org.springframework.security.web.server.authentication.ServerAuthenticati import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; import org.springframework.security.web.server.context.ServerSecurityContextRepository; +import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.config.EnableWebFlux; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebHandler; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.springframework.security.oauth2.jwt.TestJwts.jwt; /** @@ -189,6 +192,68 @@ public class OAuth2LoginTests { } } + @Test + public void oauth2AuthorizeWhenCustomObjectsThenUsed() { + this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, + OAuth2AuthorizeWithMockObjectsConfig.class, + AuthorizedClientController.class).autowire(); + + OAuth2AuthorizeWithMockObjectsConfig config = this.spring.getContext().getBean(OAuth2AuthorizeWithMockObjectsConfig.class); + + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = config.authorizedClientRepository; + ServerAuthorizationRequestRepository authorizationRequestRepository = config.authorizationRequestRepository; + ServerRequestCache requestCache = config.requestCache; + + when(authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + when(authorizationRequestRepository.saveAuthorizationRequest(any(), any())).thenReturn(Mono.empty()); + when(requestCache.removeMatchingRequest(any())).thenReturn(Mono.empty()); + when(requestCache.saveRequest(any())).thenReturn(Mono.empty()); + + this.client.get() + .uri("/") + .exchange() + .expectStatus().is3xxRedirection(); + + verify(authorizedClientRepository).loadAuthorizedClient(any(), any(), any()); + verify(authorizationRequestRepository).saveAuthorizationRequest(any(), any()); + verify(requestCache).saveRequest(any()); + } + + @EnableWebFlux + static class OAuth2AuthorizeWithMockObjectsConfig { + ServerOAuth2AuthorizedClientRepository authorizedClientRepository = + mock(ServerOAuth2AuthorizedClientRepository.class); + + ServerAuthorizationRequestRepository authorizationRequestRepository = + mock(ServerAuthorizationRequestRepository.class); + + ServerRequestCache requestCache = mock(ServerRequestCache.class); + + @Bean + SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { + http + .requestCache() + .requestCache(this.requestCache) + .and() + .oauth2Login() + .authorizationRequestRepository(this.authorizationRequestRepository); + return http.build(); + } + + @Bean + ServerOAuth2AuthorizedClientRepository authorizedClientRepository() { + return this.authorizedClientRepository; + } + } + + @RestController + static class AuthorizedClientController { + @GetMapping("/") + String home(@RegisteredOAuth2AuthorizedClient("github") OAuth2AuthorizedClient authorizedClient) { + return "home"; + } + } + @Test public void oauth2LoginWhenCustomObjectsThenUsed() { this.spring.register(OAuth2LoginWithSingleClientRegistrations.class, diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index 01be4c914d..c95f8bd17d 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -20,14 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.BDDMockito.given; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import static org.springframework.security.config.Customizer.withDefaults; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -43,7 +41,6 @@ import org.mockito.junit.MockitoJUnitRunner; import org.springframework.security.core.Authentication; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; -import org.springframework.web.server.handler.FilteringWebHandler; import reactor.core.publisher.Mono; import reactor.test.publisher.TestPublisher; @@ -51,29 +48,18 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; -import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; -import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; -import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; -import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.LogoutWebFilter; import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler; -import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler; import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; -import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.web.reactive.server.EntityExchangeResult; import org.springframework.test.web.reactive.server.FluxExchangeResult; @@ -82,7 +68,10 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilter; +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.web.server.WebFilterChain; +import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; +import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; /** * @author Rob Winch @@ -486,71 +475,6 @@ public class ServerHttpSecurityTests { verify(customServerCsrfTokenRepository).loadToken(any()); } - @SuppressWarnings("UnassignedFluxMonoInstance") - @Test - public void configureOAuth2LoginUsingCustomCommonServerRequestCache() { - ServerRequestCache requestCacheMock = mock(ServerRequestCache.class); - when(requestCacheMock.saveRequest(any(ServerWebExchange.class))).thenReturn(Mono.empty()); - - ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - String registrationId = clientRegistration.getRegistrationId(); - - ReactiveClientRegistrationRepository clientRegistrationRepositoryMock = - mock(ReactiveClientRegistrationRepository.class); - when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId)) - .thenReturn(Mono.just(clientRegistration)); - - SecurityWebFilterChain filterChain = http.requestCache().requestCache(requestCacheMock) - .and().oauth2Login().clientRegistrationRepository(clientRegistrationRepositoryMock) - .and().build(); - - Optional redirectWebFilter = - getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class); - assertThat(redirectWebFilter.isPresent()).isTrue(); - - FilteringWebHandler webHandler = new FilteringWebHandler( - e -> Mono.error(new ClientAuthorizationRequiredException(registrationId)), - Collections.singletonList(redirectWebFilter.get()) - ); - WebTestClient client = WebTestClient.bindToWebHandler(webHandler).build(); - client.get().uri("/foo/bar").exchange(); - verify(requestCacheMock, times(1)).saveRequest(any(ServerWebExchange.class)); - } - - @Test(expected = IllegalArgumentException.class) - public void throwExceptionWhenNullPassedForOAuth2LoginAuthorizationRequestRepository() { - http.oauth2Login().authorizationRequestRepository(null).and().build(); - } - - @SuppressWarnings({"UnassignedFluxMonoInstance", "unchecked"}) - @Test - public void configureOAuth2LoginUsingCustomAuthorizationRequestRepository() { - ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - String registrationId = clientRegistration.getRegistrationId(); - - ReactiveClientRegistrationRepository clientRegistrationRepositoryMock = - mock(ReactiveClientRegistrationRepository.class); - when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId)) - .thenReturn(Mono.just(clientRegistration)); - - ServerAuthorizationRequestRepository requestRepositoryMock = mock(ServerAuthorizationRequestRepository.class); - SecurityWebFilterChain filterChain = http.oauth2Login() - .clientRegistrationRepository(clientRegistrationRepositoryMock) - .authorizationRequestRepository(requestRepositoryMock) - .and().build(); - - Optional redirectWebFilter = - getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class); - assertThat(redirectWebFilter.isPresent()).isTrue(); - - WebTestClient client = WebTestClient.bindToController(new SubscriberContextController()) - .webFilter(redirectWebFilter.get()) - .build(); - client.get().uri("/oauth2/authorization/" + registrationId).exchange(); - verify(requestRepositoryMock, times(1)).saveAuthorizationRequest(any(OAuth2AuthorizationRequest.class), - any(ServerWebExchange.class)); - } - private boolean isX509Filter(WebFilter filter) { try { Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");