diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index 1891dd4e3b..309042a772 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -24,8 +24,8 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.util.Assert; @@ -34,6 +34,7 @@ import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.ExchangeFunction; +import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; import java.net.URI; @@ -60,16 +61,22 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements */ private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName(); + /** + * The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}. + */ + private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName(); + private Clock clock = Clock.systemUTC(); private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); - private ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; public ServerOAuth2AuthorizedClientExchangeFilterFunction() {} - public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientService authorizedClientService) { - this.authorizedClientService = authorizedClientService; + public ServerOAuth2AuthorizedClientExchangeFilterFunction( + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + this.authorizedClientRepository = authorizedClientRepository; } /** @@ -78,7 +85,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements * *
 	 * WebClient webClient = WebClient.builder()
-	 *    .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService))
+	 *    .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
 	 *    .build();
 	 * Mono response = webClient
 	 *    .get()
@@ -110,6 +117,30 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 		return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
 	}
 
+
+	/**
+	 * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
+	 * providing the Bearer Token. Example usage:
+	 *
+	 * 
+	 * WebClient webClient = WebClient.builder()
+	 *    .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
+	 *    .build();
+	 * Mono response = webClient
+	 *    .get()
+	 *    .uri(uri)
+	 *    .attributes(serverWebExchange(serverWebExchange))
+	 *    // ...
+	 *    .retrieve()
+	 *    .bodyToMono(String.class);
+	 * 
+ * @param serverWebExchange the {@link ServerWebExchange} to use + * @return the {@link Consumer} to populate the client request attributes + */ + public static Consumer> serverWebExchange(ServerWebExchange serverWebExchange) { + return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange); + } + /** * An access token will be considered expired by comparing its expiration to now + * this skewed Duration. The default is 1 minute. @@ -124,22 +155,23 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements public Mono filter(ClientRequest request, ExchangeFunction next) { Optional attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME) .map(OAuth2AuthorizedClient.class::cast); + ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME); return Mono.justOrEmpty(attribute) - .flatMap(authorizedClient -> authorizedClient(next, authorizedClient)) + .flatMap(authorizedClient -> authorizedClient(next, authorizedClient, exchange)) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) .switchIfEmpty(next.exchange(request)); } - private Mono authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { + private Mono authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) { if (shouldRefresh(authorizedClient)) { - return refreshAuthorizedClient(next, authorizedClient); + return refreshAuthorizedClient(next, authorizedClient, exchange); } return Mono.just(authorizedClient); } private Mono refreshAuthorizedClient(ExchangeFunction next, - OAuth2AuthorizedClient authorizedClient) { + OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) { ClientRegistration clientRegistration = authorizedClient .getClientRegistration(); String tokenUri = clientRegistration @@ -155,12 +187,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements .flatMap(result -> ReactiveSecurityContextHolder.getContext() .map(SecurityContext::getAuthentication) .defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName())) - .flatMap(principal -> this.authorizedClientService.saveAuthorizedClient(result, principal)) + .flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange)) .thenReturn(result)); } private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) { - if (this.authorizedClientService == null) { + if (this.authorizedClientRepository == null) { return false; } OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 99df5b92b2..595ea57f9f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -36,9 +36,9 @@ import org.springframework.mock.http.client.reactive.MockClientHttpRequest; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; @@ -70,7 +70,7 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c @RunWith(MockitoJUnitRunner.class) public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock - private ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository; private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(); @@ -124,7 +124,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredThenRefresh() { - when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(3600) @@ -139,7 +139,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -153,7 +153,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .block(); - verify(this.authorizedClientService).saveAuthorizedClient(any(), eq(authentication)); + verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(2); @@ -173,7 +173,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { - when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty()); + when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") .tokenType(OAuth2AccessToken.TokenType.BEARER) .expiresIn(3600) @@ -188,7 +188,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -200,7 +200,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.function.filter(request, this.exchange) .block(); - verify(this.authorizedClientService).saveAuthorizedClient(any(), any()); + verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(2); @@ -220,7 +220,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); @@ -242,7 +242,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenNotExpiredThenShouldRefreshFalse() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService); + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,