From 69156b741d89dc3e7e7e5e952eddf963639ed226 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 13 Feb 2020 05:36:17 -0500 Subject: [PATCH] Add OAuth2Authorization success/failure handlers Fixes gh-7840 --- ...tServiceOAuth2AuthorizedClientManager.java | 85 ++++- .../OAuth2AuthorizationFailureHandler.java | 48 +++ .../OAuth2AuthorizationSuccessHandler.java | 47 +++ ...tAuthorizationCodeTokenResponseClient.java | 24 +- ...tClientCredentialsTokenResponseClient.java | 24 +- .../DefaultPasswordTokenResponseClient.java | 24 +- ...efaultRefreshTokenTokenResponseClient.java | 24 +- ...sAuthorizationCodeTokenResponseClient.java | 22 +- .../DefaultOAuth2AuthorizedClientManager.java | 118 ++++++- ...ientOAuth2AuthorizationFailureHandler.java | 169 +++++++++ ...ientOAuth2AuthorizationSuccessHandler.java | 77 ++++ ...uthorizedClientExchangeFilterFunction.java | 330 ++++++++++++++++-- ...iceOAuth2AuthorizedClientManagerTests.java | 93 ++++- ...ultOAuth2AuthorizedClientManagerTests.java | 99 +++++- ...izedClientExchangeFilterFunctionTests.java | 266 +++++++++++++- 15 files changed, 1349 insertions(+), 101 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationFailureHandler.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationSuccessHandler.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/SaveAuthorizedClientOAuth2AuthorizationSuccessHandler.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java index 0ceabd2cad..1eebc93a28 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,11 @@ import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.client.web.SaveAuthorizedClientOAuth2AuthorizationSuccessHandler; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -31,20 +36,50 @@ import java.util.function.Function; /** * An implementation of an {@link OAuth2AuthorizedClientManager} - * that is capable of operating outside of a {@code HttpServletRequest} context, + * that is capable of operating outside of the context of a {@code HttpServletRequest}, * e.g. in a scheduled/background thread and/or in the service-tier. * + *

+ * (When operating within the context of a {@code HttpServletRequest}, + * use {@link DefaultOAuth2AuthorizedClientManager} instead.) + * + *

Authorized Client Persistence

+ * + *

+ * This manager utilizes an {@link OAuth2AuthorizedClientService} + * to persist {@link OAuth2AuthorizedClient}s. + * + *

+ * By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient} + * will be saved in the {@link OAuth2AuthorizedClientService}. + * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationSuccessHandler} + * via {@link #setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler)}. + * + *

+ * By default, when an authorization attempt fails due to an + * {@value OAuth2ErrorCodes#INVALID_GRANT} error, + * the previously saved {@link OAuth2AuthorizedClient} + * will be removed from the {@link OAuth2AuthorizedClientService}. + * (The {@value OAuth2ErrorCodes#INVALID_GRANT} error can occur + * when a refresh token that is no longer valid is used to retrieve a new access token.) + * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationFailureHandler} + * via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}. + * * @author Joe Grandja * @since 5.2 * @see OAuth2AuthorizedClientManager * @see OAuth2AuthorizedClientProvider * @see OAuth2AuthorizedClientService + * @see OAuth2AuthorizationSuccessHandler + * @see OAuth2AuthorizationFailureHandler */ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager { private final ClientRegistrationRepository clientRegistrationRepository; private final OAuth2AuthorizedClientService authorizedClientService; private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null; - private Function> contextAttributesMapper = new DefaultContextAttributesMapper(); + private Function> contextAttributesMapper; + private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private OAuth2AuthorizationFailureHandler authorizationFailureHandler; /** * Constructs an {@code AuthorizedClientServiceOAuth2AuthorizedClientManager} using the provided parameters. @@ -58,6 +93,9 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientService = authorizedClientService; + this.contextAttributesMapper = new DefaultContextAttributesMapper(); + this.authorizationSuccessHandler = new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(authorizedClientService); + this.authorizationFailureHandler = new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(authorizedClientService); } @Nullable @@ -92,9 +130,16 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen }) .build(); - authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + try { + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + } catch (OAuth2AuthorizationException ex) { + this.authorizationFailureHandler.onAuthorizationFailure(ex, principal, Collections.emptyMap()); + throw ex; + } + if (authorizedClient != null) { - this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + this.authorizationSuccessHandler.onAuthorizationSuccess( + authorizedClient, principal, Collections.emptyMap()); } else { // In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported. // For these cases, return the provided `authorizationContext.authorizedClient`. @@ -128,6 +173,36 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen this.contextAttributesMapper = contextAttributesMapper; } + /** + * Sets the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations. + * + *

+ * A {@link SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} is used by default. + * + * @param authorizationSuccessHandler the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations + * @see SaveAuthorizedClientOAuth2AuthorizationSuccessHandler + * @since 5.3 + */ + public void setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler authorizationSuccessHandler) { + Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null"); + this.authorizationSuccessHandler = authorizationSuccessHandler; + } + + /** + * Sets the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures. + * + *

+ * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is used by default. + * + * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures + * @see RemoveAuthorizedClientOAuth2AuthorizationFailureHandler + * @since 5.3 + */ + public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) { + Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); + this.authorizationFailureHandler = authorizationFailureHandler; + } + /** * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. */ diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationFailureHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationFailureHandler.java new file mode 100644 index 0000000000..c24141d8b4 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationFailureHandler.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; + +import java.util.Map; + +/** + * Handles when an OAuth 2.0 Client fails to authorize (or re-authorize) + * via the Authorization Server or Resource Server. + * + * @author Joe Grandja + * @since 5.3 + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizedClientManager + */ +@FunctionalInterface +public interface OAuth2AuthorizationFailureHandler { + + /** + * Called when an OAuth 2.0 Client fails to authorize (or re-authorize) + * via the Authorization Server or Resource Server. + * + * @param authorizationException the exception that contains details about what failed + * @param principal the {@code Principal} associated with the attempted authorization + * @param attributes an immutable {@code Map} of (optional) attributes present under certain conditions. + * For example, this might contain a {@code javax.servlet.http.HttpServletRequest} + * and {@code javax.servlet.http.HttpServletResponse} if the authorization was performed + * within the context of a {@code javax.servlet.ServletContext}. + */ + void onAuthorizationFailure(OAuth2AuthorizationException authorizationException, + Authentication principal, Map attributes); +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationSuccessHandler.java new file mode 100644 index 0000000000..b350924ab5 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationSuccessHandler.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.core.Authentication; + +import java.util.Map; + +/** + * Handles when an OAuth 2.0 Client has been successfully + * authorized (or re-authorized) via the Authorization Server. + * + * @author Joe Grandja + * @since 5.3 + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizedClientManager + */ +@FunctionalInterface +public interface OAuth2AuthorizationSuccessHandler { + + /** + * Called when an OAuth 2.0 Client has been successfully + * authorized (or re-authorized) via the Authorization Server. + * + * @param authorizedClient the client that was successfully authorized (or re-authorized) + * @param principal the {@code Principal} associated with the authorized client + * @param attributes an immutable {@code Map} of (optional) attributes present under certain conditions. + * For example, this might contain a {@code javax.servlet.http.HttpServletRequest} + * and {@code javax.servlet.http.HttpServletResponse} if the authorization was performed + * within the context of a {@code javax.servlet.ServletContext}. + */ + void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, + Authentication principal, Map attributes); +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java index 174dc75e43..a8fd73626e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; @@ -30,6 +30,7 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestClientResponseException; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; @@ -74,9 +75,22 @@ public final class DefaultAuthorizationCodeTokenResponseClient implements OAuth2 try { response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); } catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); - throw new OAuth2AuthorizationException(oauth2Error, ex); + int statusCode = 500; + if (ex instanceof RestClientResponseException) { + statusCode = ((RestClientResponseException) ex).getRawStatusCode(); + } + OAuth2Error oauth2Error = new OAuth2Error( + INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), + null); + String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s", + statusCode, + oauth2Error); + throw new ClientAuthorizationException( + oauth2Error, + authorizationCodeGrantRequest.getClientRegistration().getRegistrationId(), + message, + ex); } OAuth2AccessTokenResponse tokenResponse = response.getBody(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java index fdd5eb1e75..0137a6c269 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; @@ -30,6 +30,7 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestClientResponseException; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; @@ -74,9 +75,22 @@ public final class DefaultClientCredentialsTokenResponseClient implements OAuth2 try { response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); } catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); - throw new OAuth2AuthorizationException(oauth2Error, ex); + int statusCode = 500; + if (ex instanceof RestClientResponseException) { + statusCode = ((RestClientResponseException) ex).getRawStatusCode(); + } + OAuth2Error oauth2Error = new OAuth2Error( + INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), + null); + String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s", + statusCode, + oauth2Error); + throw new ClientAuthorizationException( + oauth2Error, + clientCredentialsGrantRequest.getClientRegistration().getRegistrationId(), + message, + ex); } OAuth2AccessTokenResponse tokenResponse = response.getBody(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java index e2f7180d2e..11a78f51ff 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; @@ -30,6 +30,7 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestClientResponseException; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; @@ -74,9 +75,22 @@ public final class DefaultPasswordTokenResponseClient implements OAuth2AccessTok try { response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); } catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); - throw new OAuth2AuthorizationException(oauth2Error, ex); + int statusCode = 500; + if (ex instanceof RestClientResponseException) { + statusCode = ((RestClientResponseException) ex).getRawStatusCode(); + } + OAuth2Error oauth2Error = new OAuth2Error( + INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), + null); + String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s", + statusCode, + oauth2Error); + throw new ClientAuthorizationException( + oauth2Error, + passwordGrantRequest.getClientRegistration().getRegistrationId(), + message, + ex); } OAuth2AccessTokenResponse tokenResponse = response.getBody(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java index 0efd37d8eb..4c0dd961ab 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; @@ -30,6 +30,7 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestClientResponseException; import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; @@ -73,9 +74,22 @@ public final class DefaultRefreshTokenTokenResponseClient implements OAuth2Acces try { response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); } catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); - throw new OAuth2AuthorizationException(oauth2Error, ex); + int statusCode = 500; + if (ex instanceof RestClientResponseException) { + statusCode = ((RestClientResponseException) ex).getRawStatusCode(); + } + OAuth2Error oauth2Error = new OAuth2Error( + INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), + null); + String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s", + statusCode, + oauth2Error); + throw new ClientAuthorizationException( + oauth2Error, + refreshTokenGrantRequest.getClientRegistration().getRegistrationId(), + message, + ex); } OAuth2AccessTokenResponse tokenResponse = response.getBody(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java index cac276aa67..83d63c53d9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,10 +31,10 @@ import com.nimbusds.oauth2.sdk.auth.Secret; import com.nimbusds.oauth2.sdk.http.HTTPRequest; import com.nimbusds.oauth2.sdk.id.ClientID; import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; @@ -100,9 +100,19 @@ public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessT httpRequest.setReadTimeout(30000); tokenResponse = com.nimbusds.oauth2.sdk.TokenResponse.parse(httpRequest.send()); } catch (ParseException | IOException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); - throw new OAuth2AuthorizationException(oauth2Error, ex); + int statusCode = 500; + OAuth2Error oauth2Error = new OAuth2Error( + INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), + null); + String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s", + statusCode, + oauth2Error); + throw new ClientAuthorizationException( + oauth2Error, + clientRegistration.getRegistrationId(), + message, + ex); } if (!tokenResponse.indicatesSuccess()) { @@ -117,7 +127,7 @@ public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessT errorObject.getDescription(), errorObject.getURI() != null ? errorObject.getURI().toString() : null); } - throw new OAuth2AuthorizationException(oauth2Error); + throw new ClientAuthorizationException(oauth2Error, clientRegistration.getRegistrationId()); } AccessTokenResponse accessTokenResponse = (AccessTokenResponse) tokenResponse; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java index e9d875a693..ad00e66a86 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -15,22 +15,20 @@ */ package org.springframework.security.oauth2.client.web; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.client.OAuth2AuthorizationSuccessHandler; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -39,19 +37,57 @@ import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + /** - * The default implementation of an {@link OAuth2AuthorizedClientManager}. + * The default implementation of an {@link OAuth2AuthorizedClientManager} + * for use within the context of a {@code HttpServletRequest}. + * + *

+ * (When operating outside of the context of a {@code HttpServletRequest}, + * use {@link AuthorizedClientServiceOAuth2AuthorizedClientManager} instead.) + * + *

Authorized Client Persistence

+ * + *

+ * This manager utilizes an {@link OAuth2AuthorizedClientRepository} + * to persist {@link OAuth2AuthorizedClient}s. + * + *

+ * By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient} + * will be saved in the {@link OAuth2AuthorizedClientRepository}. + * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationSuccessHandler} + * via {@link #setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler)}. + * + *

+ * By default, when an authorization attempt fails due to an + * {@value OAuth2ErrorCodes#INVALID_GRANT} error, + * the previously saved {@link OAuth2AuthorizedClient} + * will be removed from the {@link OAuth2AuthorizedClientRepository}. + * (The {@value OAuth2ErrorCodes#INVALID_GRANT} error can occur + * when a refresh token that is no longer valid is used to retrieve a new access token.) + * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationFailureHandler} + * via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}. * * @author Joe Grandja * @since 5.2 * @see OAuth2AuthorizedClientManager * @see OAuth2AuthorizedClientProvider + * @see OAuth2AuthorizationSuccessHandler + * @see OAuth2AuthorizationFailureHandler */ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager { private final ClientRegistrationRepository clientRegistrationRepository; private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null; - private Function> contextAttributesMapper = new DefaultContextAttributesMapper(); + private Function> contextAttributesMapper; + private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private OAuth2AuthorizationFailureHandler authorizationFailureHandler; /** * Constructs a {@code DefaultOAuth2AuthorizedClientManager} using the provided parameters. @@ -65,6 +101,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; + this.contextAttributesMapper = new DefaultContextAttributesMapper(); + this.authorizationSuccessHandler = new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(authorizedClientRepository); + this.authorizationFailureHandler = new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(authorizedClientRepository); } @Nullable @@ -105,9 +144,17 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori }) .build(); - authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + try { + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + } catch (OAuth2AuthorizationException ex) { + this.authorizationFailureHandler.onAuthorizationFailure( + ex, principal, createAttributes(servletRequest, servletResponse)); + throw ex; + } + if (authorizedClient != null) { - this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, servletRequest, servletResponse); + this.authorizationSuccessHandler.onAuthorizationSuccess( + authorizedClient, principal, createAttributes(servletRequest, servletResponse)); } else { // In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported. // For these cases, return the provided `authorizationContext.authorizedClient`. @@ -119,12 +166,19 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori return authorizedClient; } + private static Map createAttributes(HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + Map attributes = new HashMap<>(); + attributes.put(HttpServletRequest.class.getName(), servletRequest); + attributes.put(HttpServletResponse.class.getName(), servletResponse); + return attributes; + } + private static HttpServletRequest getHttpServletRequestOrDefault(Map attributes) { HttpServletRequest servletRequest = (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()); if (servletRequest == null) { - RequestAttributes context = RequestContextHolder.getRequestAttributes(); - if (context instanceof ServletRequestAttributes) { - servletRequest = ((ServletRequestAttributes) context).getRequest(); + RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); + if (requestAttributes instanceof ServletRequestAttributes) { + servletRequest = ((ServletRequestAttributes) requestAttributes).getRequest(); } } return servletRequest; @@ -133,9 +187,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori private static HttpServletResponse getHttpServletResponseOrDefault(Map attributes) { HttpServletResponse servletResponse = (HttpServletResponse) attributes.get(HttpServletResponse.class.getName()); if (servletResponse == null) { - RequestAttributes context = RequestContextHolder.getRequestAttributes(); - if (context instanceof ServletRequestAttributes) { - servletResponse = ((ServletRequestAttributes) context).getResponse(); + RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes(); + if (requestAttributes instanceof ServletRequestAttributes) { + servletResponse = ((ServletRequestAttributes) requestAttributes).getResponse(); } } return servletResponse; @@ -163,6 +217,36 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori this.contextAttributesMapper = contextAttributesMapper; } + /** + * Sets the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations. + * + *

+ * A {@link SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} is used by default. + * + * @param authorizationSuccessHandler the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations + * @see SaveAuthorizedClientOAuth2AuthorizationSuccessHandler + * @since 5.3 + */ + public void setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler authorizationSuccessHandler) { + Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null"); + this.authorizationSuccessHandler = authorizationSuccessHandler; + } + + /** + * Sets the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures. + * + *

+ * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is used by default. + * + * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures + * @see RemoveAuthorizedClientOAuth2AuthorizationFailureHandler + * @since 5.3 + */ + public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) { + Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); + this.authorizationFailureHandler = authorizationFailureHandler; + } + /** * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. */ diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java new file mode 100644 index 0000000000..f2c424b9dd --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java @@ -0,0 +1,169 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.ClientAuthorizationException; +import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * An {@link OAuth2AuthorizationFailureHandler} that removes an {@link OAuth2AuthorizedClient} + * from an {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService} + * for a specific set of OAuth 2.0 error codes. + * + * @author Joe Grandja + * @since 5.3 + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizedClientRepository + * @see OAuth2AuthorizedClientService + */ +public class RemoveAuthorizedClientOAuth2AuthorizationFailureHandler implements OAuth2AuthorizationFailureHandler { + + /** + * The default OAuth 2.0 error codes that will trigger removal of an {@link OAuth2AuthorizedClient}. + * @see OAuth2ErrorCodes + */ + public static final Set DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( + /* + * Returned from Resource Servers when an access token provided is expired, revoked, + * malformed, or invalid for other reasons. + * + * Note that this is needed because ServletOAuth2AuthorizedClientExchangeFilterFunction + * delegates this type of failure received from a Resource Server + * to this failure handler. + */ + OAuth2ErrorCodes.INVALID_TOKEN, + + /* + * Returned from Authorization Servers when the authorization grant or refresh token is invalid, expired, revoked, + * does not match the redirection URI used in the authorization request, or was issued to another client. + */ + OAuth2ErrorCodes.INVALID_GRANT + ))); + + /** + * The OAuth 2.0 error codes which will trigger removal of an {@link OAuth2AuthorizedClient}. + * @see OAuth2ErrorCodes + */ + private final Set removeAuthorizedClientErrorCodes; + + /** + * A delegate that removes an {@link OAuth2AuthorizedClient} from a + * {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService} + * if the error code is one of the {@link #removeAuthorizedClientErrorCodes}. + */ + private final OAuth2AuthorizedClientRemover delegate; + + @FunctionalInterface + private interface OAuth2AuthorizedClientRemover { + void removeAuthorizedClient(String clientRegistrationId, Authentication principal, Map attributes); + } + + /** + * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters. + * + * @param authorizedClientRepository the repository from which authorized clients will be removed + * if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}. + */ + public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(OAuth2AuthorizedClientRepository authorizedClientRepository) { + this(authorizedClientRepository, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES); + } + + /** + * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters. + * + * @param authorizedClientRepository the repository from which authorized clients will be removed + * if the error code is one of the {@link #removeAuthorizedClientErrorCodes}. + * @param removeAuthorizedClientErrorCodes the OAuth 2.0 error codes which will trigger removal of an authorized client. + * @see OAuth2ErrorCodes + */ + public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( + OAuth2AuthorizedClientRepository authorizedClientRepository, + Set removeAuthorizedClientErrorCodes) { + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null"); + this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes)); + this.delegate = (clientRegistrationId, principal, attributes) -> + authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, + (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()), + (HttpServletResponse) attributes.get(HttpServletResponse.class.getName())); + } + + /** + * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters. + * + * @param authorizedClientService the service from which authorized clients will be removed + * if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}. + */ + public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(OAuth2AuthorizedClientService authorizedClientService) { + this(authorizedClientService, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES); + } + + /** + * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters. + * + * @param authorizedClientService the service from which authorized clients will be removed + * if the error code is one of the {@link #removeAuthorizedClientErrorCodes}. + * @param removeAuthorizedClientErrorCodes the OAuth 2.0 error codes which will trigger removal of an authorized client. + * @see OAuth2ErrorCodes + */ + public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( + OAuth2AuthorizedClientService authorizedClientService, + Set removeAuthorizedClientErrorCodes) { + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null"); + this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes)); + this.delegate = (clientRegistrationId, principal, attributes) -> + authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName()); + } + + @Override + public void onAuthorizationFailure(OAuth2AuthorizationException authorizationException, + Authentication principal, Map attributes) { + + if (authorizationException instanceof ClientAuthorizationException && + hasRemovalErrorCode(authorizationException)) { + ClientAuthorizationException clientAuthorizationException = (ClientAuthorizationException) authorizationException; + this.delegate.removeAuthorizedClient( + clientAuthorizationException.getClientRegistrationId(), principal, attributes); + } + } + + /** + * Returns true if the given exception has an error code that + * indicates that the authorized client should be removed. + * + * @param authorizationException the exception that caused the authorization failure + * @return true if the given exception has an error code that + * indicates that the authorized client should be removed. + */ + private boolean hasRemovalErrorCode(OAuth2AuthorizationException authorizationException) { + return this.removeAuthorizedClientErrorCodes.contains(authorizationException.getError().getErrorCode()); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/SaveAuthorizedClientOAuth2AuthorizationSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/SaveAuthorizedClientOAuth2AuthorizationSuccessHandler.java new file mode 100644 index 0000000000..91647f568e --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/SaveAuthorizedClientOAuth2AuthorizationSuccessHandler.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizationSuccessHandler; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.Map; + +/** + * An {@link OAuth2AuthorizationSuccessHandler} that saves an {@link OAuth2AuthorizedClient} + * in an {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}. + * + * @author Joe Grandja + * @since 5.3 + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizedClientRepository + * @see OAuth2AuthorizedClientService + */ +public class SaveAuthorizedClientOAuth2AuthorizationSuccessHandler implements OAuth2AuthorizationSuccessHandler { + + /** + * A delegate that saves an {@link OAuth2AuthorizedClient} in an + * {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}. + */ + private final OAuth2AuthorizationSuccessHandler delegate; + + /** + * Constructs a {@code SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} using the provided parameters. + * + * @param authorizedClientRepository The repository in which authorized clients will be saved. + */ + public SaveAuthorizedClientOAuth2AuthorizationSuccessHandler( + OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.delegate = (authorizedClient, principal, attributes) -> + authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, + (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()), + (HttpServletResponse) attributes.get(HttpServletResponse.class.getName())); + } + + /** + * Constructs a {@code SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} using the provided parameters. + * + * @param authorizedClientService The service in which authorized clients will be saved. + */ + public SaveAuthorizedClientOAuth2AuthorizationSuccessHandler( + OAuth2AuthorizedClientService authorizedClientService) { + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + this.delegate = (authorizedClient, principal, attributes) -> + authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + } + + @Override + public void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, + Authentication principal, Map attributes) { + this.delegate.onAuthorizationSuccess(authorizedClient, principal, attributes); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 22d488c965..d6b6371a87 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -13,15 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.client.web.reactive.function.client; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; @@ -35,7 +38,13 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; @@ -44,6 +53,7 @@ import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.util.context.Context; @@ -52,18 +62,25 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.time.Duration; import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** - * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the - * token as a Bearer Token. It also provides mechanisms for looking up the {@link OAuth2AuthorizedClient}. This class is - * intended to be used in a servlet environment. + * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth 2.0 requests + * by including the {@link OAuth2AuthorizedClient#getAccessToken() access token} as a bearer token. * + *

+ * NOTE:This class is intended to be used in a {@code Servlet} environment. + * + *

* Example usage: * *

- * ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository);
+ * ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);
  * WebClient webClient = WebClient.builder()
  *    .apply(oauth2.oauth2Configuration())
  *    .build();
@@ -76,23 +93,35 @@ import java.util.function.Consumer;
  *    .bodyToMono(String.class);
  * 
* - * An attempt to automatically refresh the token will be made if all of the following - * are true: + *

Authentication and Authorization Failures

* - * + *

+ * Since 5.3, this filter function has the ability to forward authentication (HTTP 401 Unauthorized) + * and authorization (HTTP 403 Forbidden) failures from an OAuth 2.0 Resource Server + * to a {@link OAuth2AuthorizationFailureHandler}. + * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} can be used + * to remove the cached {@link OAuth2AuthorizedClient}, so that future requests will result + * in a new token being retrieved from an Authorization Server, and sent to the Resource Server. + * + *

+ * If the {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)} + * constructor is used, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} + * will be configured automatically. + * + *

+ * If the {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)} + * constructor is used, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} + * will NOT be configured automatically. + * It is recommended that you configure one via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}. * * @author Rob Winch * @author Joe Grandja * @author Roman Matiushchenko * @since 5.1 * @see OAuth2AuthorizedClientManager + * @see DefaultOAuth2AuthorizedClientManager + * @see OAuth2AuthorizedClientProvider + * @see OAuth2AuthorizedClientProviderBuilder */ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { @@ -103,6 +132,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement * The request attribute name used to locate the {@link OAuth2AuthorizedClient}. */ private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName(); + private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID"); private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName(); private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); @@ -125,35 +155,75 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private String defaultClientRegistrationId; + private ClientResponseHandler clientResponseHandler; + + @FunctionalInterface + private interface ClientResponseHandler { + Mono handleResponse(ClientRequest request, Mono response); + } + public ServletOAuth2AuthorizedClientExchangeFilterFunction() { } /** * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. * + *

+ * When this constructor is used, authentication (HTTP 401) and authorization (HTTP 403) + * failures returned from an OAuth 2.0 Resource Server will NOT be forwarded to an + * {@link OAuth2AuthorizationFailureHandler}. + * Therefore, future requests to the Resource Server will most likely use the same (likely invalid) token, + * resulting in the same errors returned from the Resource Server. + * It is recommended to configure a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} + * via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)} + * so that authentication and authorization failures returned from a Resource Server + * will result in removing the authorized client, so that a new token is retrieved for future requests. + * * @since 5.2 * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s) */ public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager authorizedClientManager) { Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); this.authorizedClientManager = authorizedClientManager; + this.clientResponseHandler = (request, responseMono) -> responseMono; } /** * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. * + *

+ * Since 5.3, when this constructor is used, authentication (HTTP 401) + * and authorization (HTTP 403) failures returned from an OAuth 2.0 Resource Server + * will be forwarded to a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}, + * which will potentially remove the {@link OAuth2AuthorizedClient} from the given + * {@link OAuth2AuthorizedClientRepository}, depending on the OAuth 2.0 error code returned. + * Authentication failures returned from an OAuth 2.0 Resource Server typically indicate + * that the token is invalid, and should not be used in future requests. + * Removing the authorized client from the repository will ensure that the existing + * token will not be sent for future requests to the Resource Server, + * and a new token is retrieved from the Authorization Server and used for + * future requests to the Resource Server. + * * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the repository of authorized clients */ public ServletOAuth2AuthorizedClientExchangeFilterFunction( ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { - this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); + + OAuth2AuthorizationFailureHandler authorizationFailureHandler = + new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(authorizedClientRepository); + + this.authorizedClientManager = createDefaultAuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository, authorizationFailureHandler); this.defaultAuthorizedClientManager = true; + this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); } private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManager( - ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { + ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository, + OAuth2AuthorizationFailureHandler authorizationFailureHandler) { OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() @@ -165,6 +235,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( clientRegistrationRepository, authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + authorizedClientManager.setAuthorizationFailureHandler(authorizationFailureHandler); return authorizedClientManager; } @@ -333,19 +404,47 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement updateDefaultAuthorizedClientManager(); } + /** + * Sets the {@link OAuth2AuthorizationFailureHandler} that handles + * authentication and authorization failures when communicating + * to the OAuth 2.0 Resource Server. + * + *

+ * For example, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} + * is typically used to remove the cached {@link OAuth2AuthorizedClient}, + * so that the same token is no longer used in future requests to the Resource Server. + * + *

+ * The failure handler used by default depends on which constructor was used + * to construct this {@link ServletOAuth2AuthorizedClientExchangeFilterFunction}. + * See the constructors for more details. + * + * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authentication and authorization failures + * @since 5.3 + */ + public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) { + Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); + this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); + } + @Override public Mono filter(ClientRequest request, ExchangeFunction next) { return mergeRequestAttributesIfNecessary(request) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) - .flatMap(req -> authorizedClient(getOAuth2AuthorizedClient(req.attributes()), req)) + .flatMap(req -> reauthorizeClient(getOAuth2AuthorizedClient(req.attributes()), req)) .switchIfEmpty(Mono.defer(() -> mergeRequestAttributesIfNecessary(request) .filter(req -> resolveClientRegistrationId(req) != null) .flatMap(req -> authorizeClient(resolveClientRegistrationId(req), req)) )) .map(authorizedClient -> bearer(request, authorizedClient)) - .flatMap(next::exchange) - .switchIfEmpty(Mono.defer(() -> next.exchange(request))); + .flatMap(requestWithBearer -> exchangeAndHandleResponse(requestWithBearer, next)) + .switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next))); + } + + private Mono exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) { + return next.exchange(request) + .transform(responseMono -> this.clientResponseHandler.handleResponse(request, responseMono)); } private Mono mergeRequestAttributesIfNecessary(ClientRequest request) { @@ -443,13 +542,14 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement }); OAuth2AuthorizeRequest authorizeRequest = builder.build(); - // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.boundedElastic()) - // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.boundedElastic()) + // NOTE: + // 'authorizedClientManager.authorize()' needs to be executed + // on a dedicated thread via subscribeOn(Schedulers.boundedElastic()) // since it performs a blocking I/O operation using RestTemplate internally return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(authorizeRequest)).subscribeOn(Schedulers.boundedElastic()); } - private Mono authorizedClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) { + private Mono reauthorizeClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) { if (this.authorizedClientManager == null) { return Mono.just(authorizedClient); } @@ -472,7 +572,9 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement }); OAuth2AuthorizeRequest reauthorizeRequest = builder.build(); - // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.boundedElastic()) + // NOTE: + // 'authorizedClientManager.authorize()' needs to be executed + // on a dedicated thread via subscribeOn(Schedulers.boundedElastic()) // since it performs a blocking I/O operation using RestTemplate internally return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)).subscribeOn(Schedulers.boundedElastic()); } @@ -480,6 +582,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { return ClientRequest.from(request) .headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue())) + .attributes(oauth2AuthorizedClient(authorizedClient)) .build(); } @@ -550,4 +653,183 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement return new UnsupportedOperationException("Not Supported"); } } + + /** + * Forwards authentication and authorization failures to an + * {@link OAuth2AuthorizationFailureHandler}. + * + * @since 5.3 + */ + private static class AuthorizationFailureForwarder implements ClientResponseHandler { + + /** + * A map of HTTP status code to OAuth 2.0 error code for + * HTTP status codes that should be interpreted as + * authentication or authorization failures. + */ + private final Map httpStatusToOAuth2ErrorCodeMap; + + /** + * The {@link OAuth2AuthorizationFailureHandler} to notify + * when an authentication/authorization failure occurs. + */ + private final OAuth2AuthorizationFailureHandler authorizationFailureHandler; + + private AuthorizationFailureForwarder(OAuth2AuthorizationFailureHandler authorizationFailureHandler) { + Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); + this.authorizationFailureHandler = authorizationFailureHandler; + + Map httpStatusToOAuth2Error = new HashMap<>(); + httpStatusToOAuth2Error.put(HttpStatus.UNAUTHORIZED.value(), OAuth2ErrorCodes.INVALID_TOKEN); + httpStatusToOAuth2Error.put(HttpStatus.FORBIDDEN.value(), OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + this.httpStatusToOAuth2ErrorCodeMap = Collections.unmodifiableMap(httpStatusToOAuth2Error); + } + + @Override + public Mono handleResponse(ClientRequest request, Mono responseMono) { + return responseMono + .flatMap(response -> handleResponse(request, response) + .thenReturn(response)) + .onErrorResume(WebClientResponseException.class, e -> handleWebClientResponseException(request, e) + .then(Mono.error(e))) + .onErrorResume(OAuth2AuthorizationException.class, e -> handleAuthorizationException(request, e) + .then(Mono.error(e))); + } + + private Mono handleResponse(ClientRequest request, ClientResponse response) { + return Mono.justOrEmpty(resolveErrorIfPossible(response)) + .flatMap(oauth2Error -> { + Map attrs = request.attributes(); + OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); + if (authorizedClient == null) { + return Mono.empty(); + } + + ClientAuthorizationException authorizationException = new ClientAuthorizationException( + oauth2Error, authorizedClient.getClientRegistration().getRegistrationId()); + + Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); + HttpServletRequest servletRequest = getRequest(attrs); + HttpServletResponse servletResponse = getResponse(attrs); + + return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); + }); + } + + private OAuth2Error resolveErrorIfPossible(ClientResponse response) { + // Try to resolve from 'WWW-Authenticate' header + if (!response.headers().header(HttpHeaders.WWW_AUTHENTICATE).isEmpty()) { + String wwwAuthenticateHeader = response.headers().header(HttpHeaders.WWW_AUTHENTICATE).get(0); + Map authParameters = parseAuthParameters(wwwAuthenticateHeader); + if (authParameters.containsKey(OAuth2ParameterNames.ERROR)) { + return new OAuth2Error( + authParameters.get(OAuth2ParameterNames.ERROR), + authParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION), + authParameters.get(OAuth2ParameterNames.ERROR_URI)); + } + } + return resolveErrorIfPossible(response.rawStatusCode()); + } + + private OAuth2Error resolveErrorIfPossible(int statusCode) { + if (this.httpStatusToOAuth2ErrorCodeMap.containsKey(statusCode)) { + return new OAuth2Error( + this.httpStatusToOAuth2ErrorCodeMap.get(statusCode), + null, + "https://tools.ietf.org/html/rfc6750#section-3.1"); + } + return null; + } + + private Map parseAuthParameters(String wwwAuthenticateHeader) { + return Stream.of(wwwAuthenticateHeader) + .filter(header -> !StringUtils.isEmpty(header)) + .filter(header -> header.toLowerCase().startsWith("bearer")) + .map(header -> header.substring("bearer".length())) + .map(header -> header.split(",")) + .flatMap(Stream::of) + .map(parameter -> parameter.split("=")) + .filter(parameter -> parameter.length > 1) + .collect(Collectors.toMap( + parameters -> parameters[0].trim(), + parameters -> parameters[1].trim().replace("\"", ""))); + } + + /** + * Handles the given http status code returned from a resource server + * by notifying the authorization failure handler if the http status + * code is in the {@link #httpStatusToOAuth2ErrorCodeMap}. + * + * @param request the request being processed + * @param exception The root cause exception for the failure + * @return a {@link Mono} that completes empty after the authorization failure handler completes + */ + private Mono handleWebClientResponseException(ClientRequest request, WebClientResponseException exception) { + return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode())) + .flatMap(oauth2Error -> { + Map attrs = request.attributes(); + OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); + if (authorizedClient == null) { + return Mono.empty(); + } + + ClientAuthorizationException authorizationException = new ClientAuthorizationException( + oauth2Error, authorizedClient.getClientRegistration().getRegistrationId(), exception); + + Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); + HttpServletRequest servletRequest = getRequest(attrs); + HttpServletResponse servletResponse = getResponse(attrs); + + return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); + }); + } + + /** + * Handles the given {@link OAuth2AuthorizationException} that occurred downstream + * by notifying the authorization failure handler. + * + * @param request the request being processed + * @param authorizationException the authorization exception to include in the failure event + * @return a {@link Mono} that completes empty after the authorization failure handler completes + */ + private Mono handleAuthorizationException(ClientRequest request, OAuth2AuthorizationException authorizationException) { + return Mono.justOrEmpty(request) + .flatMap(req -> { + Map attrs = req.attributes(); + OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); + if (authorizedClient == null) { + return Mono.empty(); + } + + Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); + HttpServletRequest servletRequest = getRequest(attrs); + HttpServletResponse servletResponse = getResponse(attrs); + + return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); + }); + } + + /** + * Delegates the failed authorization to the {@link OAuth2AuthorizationFailureHandler}. + * + * @param exception the {@link OAuth2AuthorizationException} to include in the failure event + * @param principal the principal associated with the failed authorization attempt + * @param servletRequest the currently active {@code HttpServletRequest} + * @param servletResponse the currently active {@code HttpServletResponse} + * @return a {@link Mono} that completes empty after the {@link OAuth2AuthorizationFailureHandler} completes + */ + private Mono handleAuthorizationFailure(OAuth2AuthorizationException exception, + Authentication principal, HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + Runnable runnable = () -> this.authorizationFailureHandler.onAuthorizationFailure( + exception, principal, createAttributes(servletRequest, servletResponse)); + return Mono.fromRunnable(runnable).subscribeOn(Schedulers.boundedElastic()).then(); + } + + private static Map createAttributes(HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + Map attributes = new HashMap<>(); + attributes.put(HttpServletRequest.class.getName(), servletRequest); + attributes.put(HttpServletResponse.class.getName(), servletResponse); + return attributes; + } + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java index 6d70e5e880..31830190df 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,10 @@ import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.client.web.SaveAuthorizedClientOAuth2AuthorizationSuccessHandler; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -30,10 +34,16 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; /** * Tests for {@link AuthorizedClientServiceOAuth2AuthorizedClientManager}. @@ -45,6 +55,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { private OAuth2AuthorizedClientService authorizedClientService; private OAuth2AuthorizedClientProvider authorizedClientProvider; private Function contextAttributesMapper; + private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private OAuth2AuthorizationFailureHandler authorizationFailureHandler; private AuthorizedClientServiceOAuth2AuthorizedClientManager authorizedClientManager; private ClientRegistration clientRegistration; private Authentication principal; @@ -58,10 +70,14 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { this.authorizedClientService = mock(OAuth2AuthorizedClientService.class); this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); this.contextAttributesMapper = mock(Function.class); + this.authorizationSuccessHandler = spy(new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(this.authorizedClientService)); + this.authorizationFailureHandler = spy(new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(this.authorizedClientService)); this.authorizedClientManager = new AuthorizedClientServiceOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientService); this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper); + this.authorizedClientManager.setAuthorizationSuccessHandler(this.authorizationSuccessHandler); + this.authorizedClientManager.setAuthorizationFailureHandler(this.authorizationFailureHandler); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); this.principal = new TestingAuthenticationToken("principal", "password"); this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), @@ -97,6 +113,20 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { .hasMessage("contextAttributesMapper cannot be null"); } + @Test + public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationSuccessHandler cannot be null"); + } + + @Test + public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationFailureHandler cannot be null"); + } + @Test public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) @@ -134,8 +164,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isNull(); - verify(this.authorizedClientService, never()).saveAuthorizedClient( - any(OAuth2AuthorizedClient.class), eq(this.principal)); + verifyNoInteractions(this.authorizationSuccessHandler); + verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); } @SuppressWarnings("unchecked") @@ -160,6 +190,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess( + eq(this.authorizedClient), eq(this.principal), any()); verify(this.authorizedClientService).saveAuthorizedClient( eq(this.authorizedClient), eq(this.principal)); } @@ -192,6 +224,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess( + eq(reauthorizedClient), eq(this.principal), any()); verify(this.authorizedClientService).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal)); } @@ -213,8 +247,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isSameAs(this.authorizedClient); - verify(this.authorizedClientService, never()).saveAuthorizedClient( - any(OAuth2AuthorizedClient.class), eq(this.principal)); + verifyNoInteractions(this.authorizationSuccessHandler); + verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); } @SuppressWarnings("unchecked") @@ -240,6 +274,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess( + eq(reauthorizedClient), eq(this.principal), any()); verify(this.authorizedClientService).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal)); } @@ -274,7 +310,52 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests { assertThat(requestScopeAttribute).contains("read", "write"); assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess( + eq(reauthorizedClient), eq(this.principal), any()); verify(this.authorizedClientService).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal)); } + + @Test + public void reauthorizeWhenErrorCodeMatchThenRemoveAuthorizedClient() { + ClientAuthorizationException authorizationException = new ClientAuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null), + this.clientRegistration.getRegistrationId()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .thenThrow(authorizationException); + + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + + assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) + .isEqualTo(authorizationException); + + verify(this.authorizationFailureHandler).onAuthorizationFailure( + eq(authorizationException), eq(this.principal), any()); + verify(this.authorizedClientService).removeAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName())); + } + + @Test + public void reauthorizeWhenErrorCodeDoesNotMatchThenDoNotRemoveAuthorizedClient() { + ClientAuthorizationException authorizationException = new ClientAuthorizationException( + new OAuth2Error("non-matching-error-code", null, null), + this.clientRegistration.getRegistrationId()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .thenThrow(authorizationException); + + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + + assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) + .isEqualTo(authorizationException); + + verify(this.authorizationFailureHandler).onAuthorizationFailure( + eq(authorizationException), eq(this.principal), any()); + verifyNoInteractions(this.authorizedClientService); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java index 627e88ab7d..03b2cfc063 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,13 +22,18 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.client.OAuth2AuthorizationSuccessHandler; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -41,8 +46,16 @@ import java.util.Map; import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; /** * Tests for {@link DefaultOAuth2AuthorizedClientManager}. @@ -54,6 +67,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests { private OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientProvider authorizedClientProvider; private Function contextAttributesMapper; + private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private OAuth2AuthorizationFailureHandler authorizationFailureHandler; private DefaultOAuth2AuthorizedClientManager authorizedClientManager; private ClientRegistration clientRegistration; private Authentication principal; @@ -69,10 +84,14 @@ public class DefaultOAuth2AuthorizedClientManagerTests { this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); this.contextAttributesMapper = mock(Function.class); + this.authorizationSuccessHandler = spy(new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(this.authorizedClientRepository)); + this.authorizationFailureHandler = spy(new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(this.authorizedClientRepository)); this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper); + this.authorizedClientManager.setAuthorizationSuccessHandler(this.authorizationSuccessHandler); + this.authorizedClientManager.setAuthorizationFailureHandler(this.authorizationFailureHandler); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); this.principal = new TestingAuthenticationToken("principal", "password"); this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), @@ -110,6 +129,20 @@ public class DefaultOAuth2AuthorizedClientManagerTests { .hasMessage("contextAttributesMapper cannot be null"); } + @Test + public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationSuccessHandler cannot be null"); + } + + @Test + public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationFailureHandler cannot be null"); + } + @Test public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) @@ -176,8 +209,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isNull(); - verify(this.authorizedClientRepository, never()).saveAuthorizedClient( - any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response)); + verifyNoInteractions(this.authorizationSuccessHandler); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any(), any()); } @SuppressWarnings("unchecked") @@ -206,6 +239,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess( + eq(this.authorizedClient), eq(this.principal), any()); verify(this.authorizedClientRepository).saveAuthorizedClient( eq(this.authorizedClient), eq(this.principal), eq(this.request), eq(this.response)); } @@ -242,6 +277,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess( + eq(reauthorizedClient), eq(this.principal), any()); verify(this.authorizedClientRepository).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); } @@ -308,6 +345,7 @@ public class DefaultOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isSameAs(this.authorizedClient); + verifyNoInteractions(this.authorizationSuccessHandler); verify(this.authorizedClientRepository, never()).saveAuthorizedClient( any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response)); } @@ -339,6 +377,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests { assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizationSuccessHandler).onAuthorizationSuccess( + eq(reauthorizedClient), eq(this.principal), any()); verify(this.authorizedClientRepository).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); } @@ -372,4 +412,55 @@ public class DefaultOAuth2AuthorizedClientManagerTests { String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); assertThat(requestScopeAttribute).contains("read", "write"); } + + @Test + public void reauthorizeWhenErrorCodeMatchThenRemoveAuthorizedClient() { + ClientAuthorizationException authorizationException = new ClientAuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null), + this.clientRegistration.getRegistrationId()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .thenThrow(authorizationException); + + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attributes(attrs -> { + attrs.put(HttpServletRequest.class.getName(), this.request); + attrs.put(HttpServletResponse.class.getName(), this.response); + }) + .build(); + + assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) + .isEqualTo(authorizationException); + + verify(this.authorizationFailureHandler).onAuthorizationFailure( + eq(authorizationException), eq(this.principal), any()); + verify(this.authorizedClientRepository).removeAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.request), eq(this.response)); + } + + @Test + public void reauthorizeWhenErrorCodeDoesNotMatchThenDoNotRemoveAuthorizedClient() { + ClientAuthorizationException authorizationException = new ClientAuthorizationException( + new OAuth2Error("non-matching-error-code", null, null), + this.clientRegistration.getRegistrationId()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) + .thenThrow(authorizationException); + + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attributes(attrs -> { + attrs.put(HttpServletRequest.class.getName(), this.request); + attrs.put(HttpServletResponse.class.getName(), this.response); + }) + .build(); + + assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) + .isEqualTo(authorizationException); + + verify(this.authorizationFailureHandler).onAuthorizationFailure( + eq(authorizationException), eq(this.principal), any()); + verifyNoInteractions(this.authorizedClientRepository); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index d60031b043..2a0c9cf841 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,18 +15,6 @@ */ package org.springframework.security.oauth2.client.web.reactive.function.client; -import java.net.URI; -import java.time.Duration; -import java.time.Instant; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Consumer; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -35,8 +23,6 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import reactor.util.context.Context; - import org.springframework.core.codec.ByteBufferEncoder; import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.http.HttpHeaders; @@ -60,7 +46,9 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; @@ -78,6 +66,9 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -89,16 +80,37 @@ import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFunction; import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.entry; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.springframework.http.HttpMethod.GET; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; @@ -128,6 +140,14 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private OAuth2AccessTokenResponseClient passwordTokenResponseClient; @Mock + private OAuth2AuthorizationFailureHandler authorizationFailureHandler; + @Captor + private ArgumentCaptor authorizationExceptionCaptor; + @Captor + private ArgumentCaptor authenticationCaptor; + @Captor + private ArgumentCaptor> attributesCaptor; + @Mock private WebClient.RequestHeadersSpec spec; @Captor private ArgumentCaptor>> attrs; @@ -167,7 +187,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); this.authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager); } @After @@ -233,7 +253,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { SecurityContextHolder.getContext().setAuthentication(this.authentication); Map attrs = getDefaultRequestAttributes(); assertThat(getAuthentication(attrs)).isEqualTo(this.authentication); - verifyZeroInteractions(this.authorizedClientRepository); + verifyNoInteractions(this.authorizedClientRepository); } private Map getDefaultRequestAttributes() { @@ -647,6 +667,215 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { assertThat(getBody(request)).isEmpty(); } + @Test + public void filterWhenUnauthorizedThenInvokeFailureHandler() { + assertHttpStatusInvokesFailureHandler(HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN); + } + + @Test + public void filterWhenForbiddenThenInvokeFailureHandler() { + assertHttpStatusInvokesFailureHandler(HttpStatus.FORBIDDEN, OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + } + + private void assertHttpStatusInvokesFailureHandler(HttpStatus httpStatus, String expectedErrorCode) { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration, "principalName", this.accessToken); + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(servletRequest)) + .attributes(httpServletResponse(servletResponse)) + .build(); + + when(this.exchange.getResponse().rawStatusCode()).thenReturn(httpStatus.value()); + when(this.exchange.getResponse().headers()).thenReturn(mock(ClientResponse.Headers.class)); + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); + + this.function.filter(request, this.exchange).block(); + + verify(this.authorizationFailureHandler).onAuthorizationFailure( + this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), + this.attributesCaptor.capture()); + + assertThat(this.authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { + assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(e.getError().getErrorCode()).isEqualTo(expectedErrorCode); + assertThat(e).hasNoCause(); + assertThat(e).hasMessageContaining(expectedErrorCode); + }); + assertThat(this.authenticationCaptor.getValue().getName()) + .isEqualTo(authorizedClient.getPrincipalName()); + assertThat(this.attributesCaptor.getValue()) + .containsExactly( + entry(HttpServletRequest.class.getName(), servletRequest), + entry(HttpServletResponse.class.getName(), servletResponse)); + } + + @Test + public void filterWhenWWWAuthenticateHeaderIncludesErrorThenInvokeFailureHandler() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration, "principalName", this.accessToken); + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(servletRequest)) + .attributes(httpServletResponse(servletResponse)) + .build(); + + String wwwAuthenticateHeader = "Bearer error=\"insufficient_scope\", " + + "error_description=\"The request requires higher privileges than provided by the access token.\", " + + "error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\""; + ClientResponse.Headers headers = mock(ClientResponse.Headers.class); + when(headers.header(eq(HttpHeaders.WWW_AUTHENTICATE))) + .thenReturn(Collections.singletonList(wwwAuthenticateHeader)); + when(this.exchange.getResponse().headers()).thenReturn(headers); + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); + + this.function.filter(request, this.exchange).block(); + + verify(this.authorizationFailureHandler).onAuthorizationFailure( + this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), + this.attributesCaptor.capture()); + + assertThat(this.authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { + assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(e.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + assertThat(e.getError().getDescription()).isEqualTo("The request requires higher privileges than provided by the access token."); + assertThat(e.getError().getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1"); + assertThat(e).hasNoCause(); + assertThat(e).hasMessageContaining(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + }); + assertThat(this.authenticationCaptor.getValue().getName()) + .isEqualTo(authorizedClient.getPrincipalName()); + assertThat(this.attributesCaptor.getValue()) + .containsExactly( + entry(HttpServletRequest.class.getName(), servletRequest), + entry(HttpServletResponse.class.getName(), servletResponse)); + } + + @Test + public void filterWhenUnauthorizedWithWebClientExceptionThenInvokeFailureHandler() { + assertHttpStatusWithWebClientExceptionInvokesFailureHandler( + HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN); + } + + @Test + public void filterWhenForbiddenWithWebClientExceptionThenInvokeFailureHandler() { + assertHttpStatusWithWebClientExceptionInvokesFailureHandler( + HttpStatus.FORBIDDEN, OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + } + + private void assertHttpStatusWithWebClientExceptionInvokesFailureHandler( + HttpStatus httpStatus, String expectedErrorCode) { + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration, "principalName", this.accessToken); + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(servletRequest)) + .attributes(httpServletResponse(servletResponse)) + .build(); + + WebClientResponseException exception = WebClientResponseException.create( + httpStatus.value(), + httpStatus.getReasonPhrase(), + HttpHeaders.EMPTY, + new byte[0], + StandardCharsets.UTF_8); + ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception); + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); + + assertThatCode(() -> this.function.filter(request, throwingExchangeFunction).block()) + .isEqualTo(exception); + + verify(this.authorizationFailureHandler).onAuthorizationFailure( + this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), + this.attributesCaptor.capture()); + + assertThat(this.authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { + assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); + assertThat(e.getError().getErrorCode()).isEqualTo(expectedErrorCode); + assertThat(e).hasCause(exception); + assertThat(e).hasMessageContaining(expectedErrorCode); + }); + assertThat(this.authenticationCaptor.getValue().getName()) + .isEqualTo(authorizedClient.getPrincipalName()); + assertThat(this.attributesCaptor.getValue()) + .containsExactly( + entry(HttpServletRequest.class.getName(), servletRequest), + entry(HttpServletResponse.class.getName(), servletResponse)); + } + + @Test + public void filterWhenAuthorizationExceptionThenInvokeFailureHandler() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration, "principalName", this.accessToken); + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(servletRequest)) + .attributes(httpServletResponse(servletResponse)) + .build(); + + OAuth2AuthorizationException authorizationException = new OAuth2AuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN)); + ExchangeFunction throwingExchangeFunction = r -> Mono.error(authorizationException); + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); + + assertThatCode(() -> this.function.filter(request, throwingExchangeFunction).block()) + .isEqualTo(authorizationException); + + verify(this.authorizationFailureHandler).onAuthorizationFailure( + this.authorizationExceptionCaptor.capture(), + this.authenticationCaptor.capture(), + this.attributesCaptor.capture()); + + assertThat(this.authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> { + assertThat(e.getError().getErrorCode()).isEqualTo(authorizationException.getError().getErrorCode()); + assertThat(e).hasNoCause(); + assertThat(e).hasMessageContaining(OAuth2ErrorCodes.INVALID_TOKEN); + }); + assertThat(this.authenticationCaptor.getValue().getName()) + .isEqualTo(authorizedClient.getPrincipalName()); + assertThat(this.attributesCaptor.getValue()) + .containsExactly( + entry(HttpServletRequest.class.getName(), servletRequest), + entry(HttpServletResponse.class.getName(), servletResponse)); + } + + @Test + public void filterWhenOtherHttpStatusThenDoesNotInvokeFailureHandler() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.registration, "principalName", this.accessToken); + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(servletRequest)) + .attributes(httpServletResponse(servletResponse)) + .build(); + + when(this.exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.BAD_REQUEST.value()); + when(this.exchange.getResponse().headers()).thenReturn(mock(ClientResponse.Headers.class)); + this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler); + + this.function.filter(request, this.exchange).block(); + + verifyNoInteractions(this.authorizationFailureHandler); + } + private Context context(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Authentication authentication) { Map contextAttributes = new HashMap<>(); contextAttributes.put(HttpServletRequest.class, servletRequest); @@ -688,5 +917,4 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { request.body().insert(body, context).block(); return body.getBodyAsString().block(); } - }