ServerOAuth2AuthorizedClientExchangeFilterFunction clientRegistrationId
Issue: gh-4921
This commit is contained in:
		
							parent
							
								
									28537fa3b6
								
							
						
					
					
						commit
						158b8aa6d5
					
				| 
						 | 
				
			
			@ -27,12 +27,15 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 | 
			
		|||
import org.springframework.security.core.context.SecurityContext;
 | 
			
		||||
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 | 
			
		||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 | 
			
		||||
import org.springframework.security.oauth2.client.OAuth2ClientException;
 | 
			
		||||
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 | 
			
		||||
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
 | 
			
		||||
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
 | 
			
		||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
 | 
			
		||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 | 
			
		||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 | 
			
		||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 | 
			
		||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 | 
			
		||||
import org.springframework.util.Assert;
 | 
			
		||||
import org.springframework.web.reactive.function.BodyInserters;
 | 
			
		||||
import org.springframework.web.reactive.function.client.ClientRequest;
 | 
			
		||||
| 
						 | 
				
			
			@ -75,18 +78,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 | 
			
		|||
	 * The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
 | 
			
		||||
	 */
 | 
			
		||||
	private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
 | 
			
		||||
	public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
 | 
			
		||||
 | 
			
		||||
	private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
 | 
			
		||||
			AuthorityUtils.createAuthorityList("ROLE_USER"));
 | 
			
		||||
 | 
			
		||||
	private Clock clock = Clock.systemUTC();
 | 
			
		||||
 | 
			
		||||
	private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
 | 
			
		||||
 | 
			
		||||
	private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
 | 
			
		||||
			new WebClientReactiveClientCredentialsTokenResponseClient();
 | 
			
		||||
 | 
			
		||||
	private ReactiveClientRegistrationRepository clientRegistrationRepository;
 | 
			
		||||
 | 
			
		||||
	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
 | 
			
		||||
 | 
			
		||||
	public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
 | 
			
		||||
 | 
			
		||||
	public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
 | 
			
		||||
	public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
 | 
			
		||||
		this.clientRegistrationRepository = clientRegistrationRepository;
 | 
			
		||||
		this.authorizedClientRepository = authorizedClientRepository;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -164,6 +174,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 | 
			
		|||
		return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
	 * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
 | 
			
		||||
	 * client_credentials grant.
 | 
			
		||||
	 * @param clientCredentialsTokenResponseClient the client to use
 | 
			
		||||
	 */
 | 
			
		||||
	public void setClientCredentialsTokenResponseClient(
 | 
			
		||||
			ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
 | 
			
		||||
		Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
 | 
			
		||||
		this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
	 * An access token will be considered expired by comparing its expiration to now +
 | 
			
		||||
	 * this skewed Duration. The default is 1 minute.
 | 
			
		||||
| 
						 | 
				
			
			@ -208,7 +229,39 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 | 
			
		|||
	private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
 | 
			
		||||
			ServerWebExchange exchange, Authentication principal) {
 | 
			
		||||
		return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
 | 
			
		||||
			.switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)));
 | 
			
		||||
			.switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange));
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private Mono<OAuth2AuthorizedClient> authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) {
 | 
			
		||||
		return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
 | 
			
		||||
			.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
 | 
			
		||||
			.flatMap(clientRegistration -> {
 | 
			
		||||
				if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
 | 
			
		||||
					return clientCredentials(clientRegistration, exchange);
 | 
			
		||||
				}
 | 
			
		||||
				return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
 | 
			
		||||
			});
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
 | 
			
		||||
			ClientRegistration clientRegistration, ServerWebExchange exchange) {
 | 
			
		||||
		OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
 | 
			
		||||
		return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
 | 
			
		||||
			.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange));
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) {
 | 
			
		||||
		return currentAuthentication()
 | 
			
		||||
			.flatMap(principal -> {
 | 
			
		||||
				OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
 | 
			
		||||
						clientRegistration, (principal != null ?
 | 
			
		||||
						principal.getName() :
 | 
			
		||||
						"anonymousUser"),
 | 
			
		||||
						tokenResponse.getAccessToken());
 | 
			
		||||
 | 
			
		||||
				return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null)
 | 
			
		||||
						.thenReturn(authorizedClient);
 | 
			
		||||
			});
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -37,6 +37,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 | 
			
		|||
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 | 
			
		||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 | 
			
		||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
 | 
			
		||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 | 
			
		||||
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 | 
			
		||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
 | 
			
		||||
| 
						 | 
				
			
			@ -71,7 +72,10 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c
 | 
			
		|||
@RunWith(MockitoJUnitRunner.class)
 | 
			
		||||
public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		||||
	@Mock
 | 
			
		||||
	private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository;
 | 
			
		||||
	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
 | 
			
		||||
 | 
			
		||||
	@Mock
 | 
			
		||||
	private ReactiveClientRegistrationRepository clientRegistrationRepository;
 | 
			
		||||
 | 
			
		||||
	private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -125,7 +129,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		|||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void filterWhenRefreshRequiredThenRefresh() {
 | 
			
		||||
		when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
 | 
			
		||||
		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
 | 
			
		||||
		OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
 | 
			
		||||
				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 | 
			
		||||
				.expiresIn(3600)
 | 
			
		||||
| 
						 | 
				
			
			@ -140,7 +144,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		|||
				this.accessToken.getTokenValue(),
 | 
			
		||||
				issuedAt,
 | 
			
		||||
				accessTokenExpiresAt);
 | 
			
		||||
		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
 | 
			
		||||
		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
 | 
			
		||||
 | 
			
		||||
		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
 | 
			
		||||
		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 | 
			
		||||
| 
						 | 
				
			
			@ -154,7 +158,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		|||
				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
 | 
			
		||||
				.block();
 | 
			
		||||
 | 
			
		||||
		verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
 | 
			
		||||
		verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
 | 
			
		||||
 | 
			
		||||
		List<ClientRequest> requests = this.exchange.getRequests();
 | 
			
		||||
		assertThat(requests).hasSize(2);
 | 
			
		||||
| 
						 | 
				
			
			@ -174,7 +178,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		|||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
 | 
			
		||||
		when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
 | 
			
		||||
		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
 | 
			
		||||
		OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
 | 
			
		||||
				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 | 
			
		||||
				.expiresIn(3600)
 | 
			
		||||
| 
						 | 
				
			
			@ -189,7 +193,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		|||
				this.accessToken.getTokenValue(),
 | 
			
		||||
				issuedAt,
 | 
			
		||||
				accessTokenExpiresAt);
 | 
			
		||||
		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
 | 
			
		||||
		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
 | 
			
		||||
 | 
			
		||||
		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
 | 
			
		||||
		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 | 
			
		||||
| 
						 | 
				
			
			@ -201,7 +205,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		|||
		this.function.filter(request, this.exchange)
 | 
			
		||||
				.block();
 | 
			
		||||
 | 
			
		||||
		verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
 | 
			
		||||
		verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any());
 | 
			
		||||
 | 
			
		||||
		List<ClientRequest> requests = this.exchange.getRequests();
 | 
			
		||||
		assertThat(requests).hasSize(2);
 | 
			
		||||
| 
						 | 
				
			
			@ -221,7 +225,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		|||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
 | 
			
		||||
		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
 | 
			
		||||
		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
 | 
			
		||||
 | 
			
		||||
		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 | 
			
		||||
				"principalName", this.accessToken);
 | 
			
		||||
| 
						 | 
				
			
			@ -243,7 +247,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		|||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void filterWhenNotExpiredThenShouldRefreshFalse() {
 | 
			
		||||
		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
 | 
			
		||||
		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
 | 
			
		||||
 | 
			
		||||
		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
 | 
			
		||||
		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 | 
			
		||||
| 
						 | 
				
			
			@ -266,12 +270,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		|||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
 | 
			
		||||
		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
 | 
			
		||||
		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
 | 
			
		||||
 | 
			
		||||
		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
 | 
			
		||||
		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 | 
			
		||||
				"principalName", this.accessToken, refreshToken);
 | 
			
		||||
		when(this.auth2AuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
 | 
			
		||||
		when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
 | 
			
		||||
		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
 | 
			
		||||
		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
 | 
			
		||||
				.attributes(clientRegistrationId(this.registration.getRegistrationId()))
 | 
			
		||||
				.build();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,7 +18,9 @@ package sample.config;
 | 
			
		|||
 | 
			
		||||
import org.springframework.context.annotation.Bean;
 | 
			
		||||
import org.springframework.context.annotation.Configuration;
 | 
			
		||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 | 
			
		||||
import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction;
 | 
			
		||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 | 
			
		||||
import org.springframework.web.reactive.function.client.WebClient;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
| 
						 | 
				
			
			@ -29,9 +31,10 @@ import org.springframework.web.reactive.function.client.WebClient;
 | 
			
		|||
public class WebClientConfig {
 | 
			
		||||
 | 
			
		||||
	@Bean
 | 
			
		||||
	WebClient webClient() {
 | 
			
		||||
	WebClient webClient(ReactiveClientRegistrationRepository clientRegistrationRepository,
 | 
			
		||||
			ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
 | 
			
		||||
		return WebClient.builder()
 | 
			
		||||
				.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction())
 | 
			
		||||
				.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository))
 | 
			
		||||
				.build();
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue