diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java index 760e3aecf0..a089518a56 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java @@ -21,6 +21,7 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; @@ -53,17 +54,25 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector { @Configuration static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer { + private ReactiveClientRegistrationRepository clientRegistrationRepository; + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private ReactiveOAuth2AuthorizedClientService authorizedClientService; @Override public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { - if (this.authorizedClientRepository != null) { - configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(getAuthorizedClientRepository())); + if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) { + configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, getAuthorizedClientRepository())); } } + @Autowired(required = false) + public void setClientRegistrationRepository( + ReactiveClientRegistrationRepository clientRegistrationRepository) { + this.clientRegistrationRepository = clientRegistrationRepository; + } + @Autowired(required = false) public void setAuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { this.authorizedClientRepository = authorizedClientRepository; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java new file mode 100644 index 0000000000..6b381c6d2b --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java @@ -0,0 +1,185 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.reactive.function.client; + +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import java.util.Optional; + +/** + * @author Rob Winch + * @since 5.1 + */ +class OAuth2AuthorizedClientResolver { + + private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", + AuthorityUtils.createAuthorityList("ROLE_USER")); + + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + + private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = new WebClientReactiveClientCredentialsTokenResponseClient(); + + private boolean defaultOAuth2AuthorizedClient; + + public OAuth2AuthorizedClientResolver( + ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + + /** + * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is + * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be + * resolved from the current Authentication. + * @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false. + * Default is false. + */ + public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) { + this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient; + } + + /** + * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for + * client_credentials grant. + * @param clientCredentialsTokenResponseClient the client to use + */ + public void setClientCredentialsTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { + Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); + this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + } + + Mono createDefaultedRequest(String clientRegistrationId, + Authentication authentication, ServerWebExchange exchange) { + Mono defaultedAuthentication = Mono.justOrEmpty(authentication) + .switchIfEmpty(currentAuthentication()); + + Mono defaultedRegistrationId = Mono.justOrEmpty(clientRegistrationId) + .switchIfEmpty(clientRegistrationId(defaultedAuthentication)); + + Mono> defaultedExchange = Mono.justOrEmpty(exchange) + .switchIfEmpty(currentServerWebExchange()).map(Optional::of) + .defaultIfEmpty(Optional.empty()); + + return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange) + .map(t3 -> new Request(t3.getT1(), t3.getT2(), t3.getT3().orElse(null))); + } + + Mono loadAuthorizedClient(Request request) { + String clientRegistrationId = request.getClientRegistrationId(); + Authentication authentication = request.getAuthentication(); + ServerWebExchange exchange = request.getExchange(); + return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange) + .switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange)); + } + + private Mono authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) { + return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found"))) + .flatMap(clientRegistration -> { + if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + return clientCredentials(clientRegistration, authentication, exchange); + } + return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)); + }); +} + + private Mono clientCredentials( + ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) { + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest) + .flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, authentication, exchange, tokenResponse)); + } + + private Mono clientCredentialsResponse(ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange, OAuth2AccessTokenResponse tokenResponse) { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, authentication.getName(), tokenResponse.getAccessToken()); + return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, exchange) + .thenReturn(authorizedClient); + } + + /** + * Attempts to load the client registration id from the current {@link Authentication} + * @return + */ + private Mono clientRegistrationId(Mono authentication) { + return authentication + .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) + .cast(OAuth2AuthenticationToken.class) + .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); + } + + private Mono currentAuthentication() { + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + } + + private Mono currentServerWebExchange() { + return Mono.subscriberContext() + .filter(c -> c.hasKey(ServerWebExchange.class)) + .map(c -> c.get(ServerWebExchange.class)); + } + + static class Request { + private final String clientRegistrationId; + private final Authentication authentication; + private final ServerWebExchange exchange; + + public Request(String clientRegistrationId, Authentication authentication, + ServerWebExchange exchange) { + this.clientRegistrationId = clientRegistrationId; + this.authentication = authentication; + this.exchange = exchange; + } + + public String getClientRegistrationId() { + return this.clientRegistrationId; + } + + public Authentication getAuthentication() { + return this.authentication; + } + + public ServerWebExchange getExchange() { + return this.exchange; + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index ce915f120a..3d66f295a8 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -16,28 +16,21 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; -import com.sun.security.ntlm.Server; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; -import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; -import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; -import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2RefreshToken; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ClientRequest; @@ -51,9 +44,7 @@ import java.net.URI; import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.Collection; import java.util.Map; -import java.util.Optional; import java.util.function.Consumer; import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; @@ -88,20 +79,13 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); - private boolean defaultOAuth2AuthorizedClient; - - private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = - new WebClientReactiveClientCredentialsTokenResponseClient(); - - private ReactiveClientRegistrationRepository clientRegistrationRepository; - private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; - public ServerOAuth2AuthorizedClientExchangeFilterFunction() {} + private final OAuth2AuthorizedClientResolver authorizedClientResolver; public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { - this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; + this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository); } /** @@ -142,6 +126,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient); } + private static OAuth2AuthorizedClient oauth2AuthorizedClient(ClientRequest request) { + return (OAuth2AuthorizedClient) request.attributes().get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME); + } /** * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for @@ -166,6 +153,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange); } + private static ServerWebExchange serverWebExchange(ClientRequest request) { + return (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME); + } + /** * Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to * be used to look up the {@link OAuth2AuthorizedClient}. @@ -178,6 +169,14 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId); } + private static String clientRegistrationId(ClientRequest request) { + OAuth2AuthorizedClient authorizedClient = oauth2AuthorizedClient(request); + if (authorizedClient != null) { + return authorizedClient.getClientRegistration().getRegistrationId(); + } + return (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME); + } + /** * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be @@ -186,7 +185,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements * Default is false. */ public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) { - this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient; + this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(defaultOAuth2AuthorizedClient); } /** @@ -196,8 +195,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements */ public void setClientCredentialsTokenResponseClient( ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { - Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + this.authorizedClientResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient); } /** @@ -212,128 +210,59 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - return authorizedClient(request) - .flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, request)) + return authorizedClient(request, next) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) .switchIfEmpty(next.exchange(request)); } - private Mono serverWebExchange(ClientRequest request) { - ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME); - return Mono.justOrEmpty(exchange) - .switchIfEmpty(serverWebExchange()); + private Mono authorizedClient(ClientRequest request, ExchangeFunction next) { + OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request); + return Mono.justOrEmpty(authorizedClientFromAttrs) + .switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(request))) + .flatMap(authorizedClient -> refreshIfNecessary(request, next, authorizedClient)); } - private Mono serverWebExchange() { - return Mono.subscriberContext() - .filter(c -> c.hasKey(ServerWebExchange.class)) - .map(c -> c.get(ServerWebExchange.class)); + private Mono loadAuthorizedClient(ClientRequest request) { + return createRequest(request) + .flatMap(r -> this.authorizedClientResolver.loadAuthorizedClient(r)); } - private Mono authorizedClient(ClientRequest request) { - Optional attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME) - .map(OAuth2AuthorizedClient.class::cast); - return Mono.justOrEmpty(attribute) - .switchIfEmpty(findAuthorizedClientByRegistrationId(request)); + private Mono createRequest(ClientRequest request) { + String clientRegistrationId = clientRegistrationId(request); + Authentication authentication = null; + ServerWebExchange exchange = serverWebExchange(request); + return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, authentication, exchange); } - private Mono findAuthorizedClientByRegistrationId(ClientRequest request) { - if (this.authorizedClientRepository == null) { - return Mono.empty(); - } - - return currentAuthentication() - .flatMap(principal -> clientRegistrationId(request, principal) - .flatMap(clientRegistrationId -> serverWebExchange(request).flatMap(exchange -> loadAuthorizedClient(clientRegistrationId, exchange, principal))) - ); - } - - private Mono clientRegistrationId(ClientRequest request, Authentication authentication) { - return Mono.justOrEmpty(request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME)) - .cast(String.class) - .switchIfEmpty(clientRegistrationId(authentication)); - } - - private Mono clientRegistrationId(Authentication authentication) { - return Mono.justOrEmpty(authentication) - .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) - .cast(OAuth2AuthenticationToken.class) - .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); - } - - private Mono loadAuthorizedClient(String clientRegistrationId, - ServerWebExchange exchange, Authentication principal) { - return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange) - .switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange)); - } - - private Mono authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) { - return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found"))) - .flatMap(clientRegistration -> { - if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { - return clientCredentials(clientRegistration, exchange); - } - return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)); - }); - } - - private Mono clientCredentials( - ClientRegistration clientRegistration, ServerWebExchange exchange) { - OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); - return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest) - .flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange)); - } - - private Mono clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) { - return currentAuthentication() - .flatMap(principal -> { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, (principal != null ? - principal.getName() : - "anonymousUser"), - tokenResponse.getAccessToken()); - - return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null) - .thenReturn(authorizedClient); - }); - } - - private Mono refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ClientRequest request) { + private Mono refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { if (shouldRefresh(authorizedClient)) { - return serverWebExchange(request) - .flatMap(exchange -> refreshAuthorizedClient(next, authorizedClient, exchange)); + return createRequest(request) + .flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r)); } return Mono.just(authorizedClient); } private Mono refreshAuthorizedClient(ExchangeFunction next, - OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) { + OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) { + ServerWebExchange exchange = r.getExchange(); + Authentication authentication = r.getAuthentication(); ClientRegistration clientRegistration = authorizedClient .getClientRegistration(); String tokenUri = clientRegistration .getProviderDetails().getTokenUri(); - ClientRequest request = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri)) + ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri)) .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) .headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret())) .body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue())) .build(); - return next.exchange(request) - .flatMap(response -> response.body(oauth2AccessTokenResponse())) + return next.exchange(refreshRequest) + .flatMap(refreshResponse -> refreshResponse.body(oauth2AccessTokenResponse())) .map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken())) - .flatMap(result -> currentAuthentication() - .defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName())) - .flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange)) + .flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange) .thenReturn(result)); } - private Mono currentAuthentication() { - return ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .defaultIfEmpty(ANONYMOUS_USER_TOKEN); - } - private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) { if (this.authorizedClientRepository == null) { return false; @@ -361,52 +290,4 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements .fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue()) .with("refresh_token", refreshToken); } - - private static class PrincipalNameAuthentication implements Authentication { - private final String username; - - private PrincipalNameAuthentication(String username) { - this.username = username; - } - - @Override - public Collection getAuthorities() { - throw unsupported(); - } - - @Override - public Object getCredentials() { - throw unsupported(); - } - - @Override - public Object getDetails() { - throw unsupported(); - } - - @Override - public Object getPrincipal() { - throw unsupported(); - } - - @Override - public boolean isAuthenticated() { - throw unsupported(); - } - - @Override - public void setAuthenticated(boolean isAuthenticated) - throws IllegalArgumentException { - throw unsupported(); - } - - @Override - public String getName() { - return this.username; - } - - private UnsupportedOperationException unsupported() { - return new UnsupportedOperationException("Not Supported"); - } - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 0012353b82..35c1e2cdd0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -18,15 +18,9 @@ package org.springframework.security.oauth2.client.web.reactive.result.method.an import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotatedElementUtils; -import org.springframework.security.authentication.AnonymousAuthenticationToken; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.core.context.ReactiveSecurityContextHolder; -import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -56,16 +50,18 @@ import reactor.core.publisher.Mono; * @see RegisteredOAuth2AuthorizedClient */ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver { - private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + private final OAuth2AuthorizedClientResolver authorizedClientResolver; /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. * * @param authorizedClientRepository the authorized client repository */ - public OAuth2AuthorizedClientArgumentResolver(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + public OAuth2AuthorizedClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.authorizedClientRepository = authorizedClientRepository; + this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository); + this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(true); } @Override @@ -80,41 +76,11 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils .findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class); - Mono clientRegistrationId = Mono.justOrEmpty(authorizedClientAnnotation.registrationId()) - .filter(id -> !StringUtils.isEmpty(id)) - .switchIfEmpty(clientRegistrationId()) - .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalArgumentException( - "Unable to resolve the Client Registration Identifier. It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").")))); + String clientRegistrationId = StringUtils.hasLength(authorizedClientAnnotation.registrationId()) ? + authorizedClientAnnotation.registrationId() : null; - Mono principal = ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .defaultIfEmpty(new AnonymousAuthenticationToken("key", "anonymous", - AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))); - - Mono authorizedClient = Mono - .zip(clientRegistrationId, principal).switchIfEmpty( - clientRegistrationId.flatMap(id -> Mono.error(new IllegalStateException( - "Unable to resolve the Authorized Client with registration identifier \"" - + id - + "\". An \"authenticated\" or \"unauthenticated\" session is required. To allow for unauthenticated access, ensure ServerHttpSecurity.anonymous() is configured.")))) - .flatMap(zipped -> { - String registrationId = zipped.getT1(); - Authentication authentication = zipped.getT2(); - return this.authorizedClientRepository - .loadAuthorizedClient(registrationId, authentication, exchange).switchIfEmpty(Mono.defer(() -> Mono - .error(new ClientAuthorizationRequiredException( - registrationId)))); - }).cast(OAuth2AuthorizedClient.class); - - return authorizedClient.cast(Object.class); + return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, null, exchange) + .flatMap(this.authorizedClientResolver::loadAuthorizedClient); }); } - - private Mono clientRegistrationId() { - return ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .filter(authentication -> authentication instanceof OAuth2AuthenticationToken) - .cast(OAuth2AuthenticationToken.class) - .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientResolver.java new file mode 100644 index 0000000000..a90f65e9c5 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientResolver.java @@ -0,0 +1,186 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.reactive.result.method.annotation; + +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import java.util.Optional; + +/** + * @author Rob Winch + * @since 5.1 + */ +class OAuth2AuthorizedClientResolver { + + private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", + AuthorityUtils.createAuthorityList("ROLE_USER")); + + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + + private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = new WebClientReactiveClientCredentialsTokenResponseClient(); + + private boolean defaultOAuth2AuthorizedClient; + + public OAuth2AuthorizedClientResolver( + ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + + /** + * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is + * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be + * resolved from the current Authentication. + * @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false. + * Default is false. + */ + public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) { + this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient; + } + + /** + * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for + * client_credentials grant. + * @param clientCredentialsTokenResponseClient the client to use + */ + public void setClientCredentialsTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { + Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); + this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + } + + Mono createDefaultedRequest(String clientRegistrationId, + Authentication authentication, ServerWebExchange exchange) { + Mono defaultedAuthentication = Mono.justOrEmpty(authentication) + .switchIfEmpty(currentAuthentication()); + + Mono defaultedRegistrationId = Mono.justOrEmpty(clientRegistrationId) + .switchIfEmpty(clientRegistrationId(defaultedAuthentication)) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("The clientRegistrationId could not be resolved. Please provide one"))); + + Mono> defaultedExchange = Mono.justOrEmpty(exchange) + .switchIfEmpty(currentServerWebExchange()).map(Optional::of) + .defaultIfEmpty(Optional.empty()); + + return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange) + .map(t3 -> new Request(t3.getT1(), t3.getT2(), t3.getT3().orElse(null))); + } + + Mono loadAuthorizedClient(Request request) { + String clientRegistrationId = request.getClientRegistrationId(); + Authentication authentication = request.getAuthentication(); + ServerWebExchange exchange = request.getExchange(); + return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange) + .switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange)); + } + + private Mono authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) { + return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found"))) + .flatMap(clientRegistration -> { + if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + return clientCredentials(clientRegistration, authentication, exchange); + } + return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)); + }); +} + + private Mono clientCredentials( + ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) { + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest) + .flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, authentication, exchange, tokenResponse)); + } + + private Mono clientCredentialsResponse(ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange, OAuth2AccessTokenResponse tokenResponse) { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, authentication.getName(), tokenResponse.getAccessToken()); + return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, exchange) + .thenReturn(authorizedClient); + } + + /** + * Attempts to load the client registration id from the current {@link Authentication} + * @return + */ + private Mono clientRegistrationId(Mono authentication) { + return authentication + .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) + .cast(OAuth2AuthenticationToken.class) + .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); + } + + private Mono currentAuthentication() { + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + } + + private Mono currentServerWebExchange() { + return Mono.subscriberContext() + .filter(c -> c.hasKey(ServerWebExchange.class)) + .map(c -> c.get(ServerWebExchange.class)); + } + + static class Request { + private final String clientRegistrationId; + private final Authentication authentication; + private final ServerWebExchange exchange; + + public Request(String clientRegistrationId, Authentication authentication, + ServerWebExchange exchange) { + this.clientRegistrationId = clientRegistrationId; + this.authentication = authentication; + this.exchange = exchange; + } + + public String getClientRegistrationId() { + return this.clientRegistrationId; + } + + public Authentication getAuthentication() { + return this.authentication; + } + + public ServerWebExchange getExchange() { + return this.exchange; + } + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 74ae319a8b..9faab5ba88 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -88,7 +89,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private ServerWebExchange serverWebExchange; - private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(); + private ServerOAuth2AuthorizedClientExchangeFilterFunction function; private MockExchangeFunction exchange = new MockExchangeFunction(); @@ -100,6 +101,11 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { Instant.now(), Instant.now().plus(Duration.ofDays(1))); + @Before + public void setup() { + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); + } + @Test public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() { ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) @@ -155,7 +161,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -204,7 +209,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -236,8 +240,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) @@ -258,8 +260,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenNotExpiredThenShouldRefreshFalse() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); @@ -281,8 +281,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); @@ -306,7 +304,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); this.function.setDefaultOAuth2AuthorizedClient(true); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); @@ -337,8 +334,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .build(); @@ -359,8 +354,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 762ef63a57..54b3d45fba 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -29,9 +29,10 @@ import org.springframework.security.oauth2.client.ClientAuthorizationRequiredExc import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.util.ReflectionUtils; -import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.util.context.Context; @@ -50,6 +51,8 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class OAuth2AuthorizedClientArgumentResolverTests { + @Mock + private ReactiveClientRegistrationRepository clientRegistrationRepository; @Mock private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientArgumentResolver argumentResolver; @@ -59,15 +62,14 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Before public void setUp() { - this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository); + this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, this.authorizedClientRepository); this.authorizedClient = mock(OAuth2AuthorizedClient.class); when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.just(this.authorizedClient)); - Hooks.onOperatorDebug(); } @Test public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)) + assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null)) .isInstanceOf(IllegalArgumentException.class); } @@ -94,11 +96,13 @@ public class OAuth2AuthorizedClientArgumentResolverTests { MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); assertThatThrownBy(() -> resolveArgument(methodParameter)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Unable to resolve the Client Registration Identifier. It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); + .hasMessage("The clientRegistrationId could not be resolved. Please provide one"); } @Test public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( + TestClientRegistrations.clientRegistration().build())); this.authentication = mock(OAuth2AuthenticationToken.class); when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1"); MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); @@ -108,18 +112,24 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Test public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() { this.authentication = null; + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( + TestClientRegistrations.clientRegistration().build())); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); } @Test public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( + TestClientRegistrations.clientRegistration().build())); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); } @Test public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( + TestClientRegistrations.clientRegistration().build())); when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.empty()); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThatThrownBy(() -> resolveArgument(methodParameter))