From 7a715f908601a5ee72b7d62dc5eba4b0bfc47f2b Mon Sep 17 00:00:00 2001 From: Phillip Webb Date: Fri, 31 Jul 2020 19:24:40 -0700 Subject: [PATCH] Polish spring-security-oauth2-client main code Manually polish `spring-security-oauth-cleint` following the formatting and checkstyle fixes. Issue gh-8945 --- ...ionCodeOAuth2AuthorizedClientProvider.java | 1 - ...eactiveOAuth2AuthorizedClientProvider.java | 1 - ...tServiceOAuth2AuthorizedClientManager.java | 26 ++--- ...ReactiveOAuth2AuthorizedClientManager.java | 1 - ...entialsOAuth2AuthorizedClientProvider.java | 16 ++- ...eactiveOAuth2AuthorizedClientProvider.java | 8 +- .../JdbcOAuth2AuthorizedClientService.java | 18 ---- .../oauth2/client/OAuth2AuthorizeRequest.java | 3 +- ...asswordOAuth2AuthorizedClientProvider.java | 23 ++--- ...eactiveOAuth2AuthorizedClientProvider.java | 7 -- ...shTokenOAuth2AuthorizedClientProvider.java | 15 ++- ...eactiveOAuth2AuthorizedClientProvider.java | 4 - ...ientOAuth2AuthorizationFailureHandler.java | 38 +++---- ...tiveOAuth2AuthorizationFailureHandler.java | 40 +++----- ...thorizationCodeAuthenticationProvider.java | 6 -- ...tionCodeReactiveAuthenticationManager.java | 4 - .../OAuth2LoginAuthenticationProvider.java | 11 +-- .../OAuth2LoginAuthenticationToken.java | 1 - ...th2LoginReactiveAuthenticationManager.java | 8 +- ...tAuthorizationCodeTokenResponseClient.java | 30 +++--- ...tClientCredentialsTokenResponseClient.java | 30 +++--- .../DefaultPasswordTokenResponseClient.java | 30 +++--- ...efaultRefreshTokenTokenResponseClient.java | 51 +++++----- ...sAuthorizationCodeTokenResponseClient.java | 99 +++++++++---------- ...zationCodeGrantRequestEntityConverter.java | 4 - ...redentialsGrantRequestEntityConverter.java | 4 - ...h2PasswordGrantRequestEntityConverter.java | 4 - ...freshTokenGrantRequestEntityConverter.java | 4 - ...activeRefreshTokenTokenResponseClient.java | 2 - .../http/OAuth2ErrorResponseErrorHandler.java | 22 ++--- .../ClientRegistrationDeserializer.java | 6 +- .../oauth2/client/jackson2/JsonNodeUtils.java | 26 ++--- ...Auth2AuthorizationRequestDeserializer.java | 48 ++++----- .../oauth2/client/jackson2/StdConverters.java | 14 +-- ...thorizationCodeAuthenticationProvider.java | 91 ++++++++--------- ...tionCodeReactiveAuthenticationManager.java | 28 +++--- .../OidcIdTokenDecoderFactory.java | 49 +++++---- .../authentication/OidcIdTokenValidator.java | 13 --- .../ReactiveOidcIdTokenDecoderFactory.java | 49 +++++---- .../OidcReactiveOAuth2UserService.java | 6 +- .../client/oidc/userinfo/OidcUserRequest.java | 2 - .../oidc/userinfo/OidcUserRequestUtils.java | 4 - .../client/oidc/userinfo/OidcUserService.java | 51 ++++------ ...dcClientInitiatedLogoutSuccessHandler.java | 20 ++-- ...ntInitiatedServerLogoutSuccessHandler.java | 10 +- .../registration/ClientRegistration.java | 37 ++++--- .../registration/ClientRegistrations.java | 26 ++--- .../InMemoryClientRegistrationRepository.java | 5 +- .../CustomUserTypesOAuth2UserService.java | 14 +-- .../userinfo/DefaultOAuth2UserService.java | 26 +++-- .../DefaultReactiveOAuth2UserService.java | 56 +++++------ .../OAuth2UserRequestEntityConverter.java | 15 +-- ...cipalOAuth2AuthorizedClientRepository.java | 5 +- ...ultOAuth2AuthorizationRequestResolver.java | 60 ++++++----- .../DefaultOAuth2AuthorizedClientManager.java | 6 -- ...ReactiveOAuth2AuthorizedClientManager.java | 3 - .../OAuth2AuthorizationCodeGrantFilter.java | 15 --- ...th2AuthorizationRequestRedirectFilter.java | 22 ++--- .../web/OAuth2AuthorizationResponseUtils.java | 11 +-- .../web/OAuth2LoginAuthenticationFilter.java | 8 -- ...Auth2AuthorizedClientArgumentResolver.java | 21 ++-- ...uthorizedClientExchangeFilterFunction.java | 94 ++++++++---------- ...uthorizedClientExchangeFilterFunction.java | 75 ++++++-------- ...Auth2AuthorizedClientArgumentResolver.java | 10 +- ...erverOAuth2AuthorizedClientRepository.java | 15 +-- ...verOAuth2AuthorizationRequestResolver.java | 40 ++++---- ...OAuth2AuthorizationCodeGrantWebFilter.java | 5 +- ...AuthorizationRequestRedirectWebFilter.java | 5 +- .../OAuth2AuthorizationResponseUtils.java | 11 +-- ...erverOAuth2AuthorizedClientRepository.java | 1 - ...2ServerAuthorizationRequestRepository.java | 3 - 71 files changed, 603 insertions(+), 914 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java index d99b1a7384..2911fd90de 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java @@ -46,7 +46,6 @@ public final class AuthorizationCodeOAuth2AuthorizedClientProvider implements OA @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals( context.getClientRegistration().getAuthorizationGrantType()) && context.getAuthorizedClient() == null) { // ClientAuthorizationRequiredException is caught by diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java index acff2900ba..ab15fe304a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.java @@ -47,7 +47,6 @@ public final class AuthorizationCodeReactiveOAuth2AuthorizedClientProvider @Override public Mono authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals( context.getClientRegistration().getAuthorizationGrantType()) && context.getAuthorizedClient() == null) { // ClientAuthorizationRequiredException is caught by diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java index e2a4d6a753..2d8490a2da 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java @@ -115,11 +115,9 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen @Override public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - String clientRegistrationId = authorizeRequest.getClientRegistrationId(); OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient(); Authentication principal = authorizeRequest.getPrincipal(); - OAuth2AuthorizationContext.Builder contextBuilder; if (authorizedClient != null) { contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); @@ -138,14 +136,8 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration); } } - OAuth2AuthorizationContext authorizationContext = contextBuilder.principal(principal) - .attributes((attributes) -> { - Map contextAttributes = this.contextAttributesMapper.apply(authorizeRequest); - if (!CollectionUtils.isEmpty(contextAttributes)) { - attributes.putAll(contextAttributes); - } - }).build(); - + OAuth2AuthorizationContext authorizationContext = buildAuthorizationContext(authorizeRequest, principal, + contextBuilder); try { authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); } @@ -153,7 +145,6 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen this.authorizationFailureHandler.onAuthorizationFailure(ex, principal, Collections.emptyMap()); throw ex; } - if (authorizedClient != null) { this.authorizationSuccessHandler.onAuthorizationSuccess(authorizedClient, principal, Collections.emptyMap()); @@ -167,10 +158,21 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen return authorizationContext.getAuthorizedClient(); } } - return authorizedClient; } + private OAuth2AuthorizationContext buildAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest, + Authentication principal, OAuth2AuthorizationContext.Builder contextBuilder) { + OAuth2AuthorizationContext authorizationContext = contextBuilder.principal(principal) + .attributes((attributes) -> { + Map contextAttributes = this.contextAttributesMapper.apply(authorizeRequest); + if (!CollectionUtils.isEmpty(contextAttributes)) { + attributes.putAll(contextAttributes); + } + }).build(); + return authorizationContext; + } + /** * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or * re-authorizing) an OAuth 2.0 Client. diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java index c1a6193e6f..2724d5b5e7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java @@ -120,7 +120,6 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager @Override public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - return createAuthorizationContext(authorizeRequest) .flatMap((authorizationContext) -> authorize(authorizationContext, authorizeRequest.getPrincipal())); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java index c6186fb34e..527b7bfd9f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -64,39 +64,37 @@ public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OA @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - ClientRegistration clientRegistration = context.getClientRegistration(); if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { return null; } - OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { // If client is already authorized but access token is NOT expired than no // need for re-authorization return null; } - // As per spec, in section 4.4.3 Access Token Response // https://tools.ietf.org/html/rfc6749#section-4.4.3 // A refresh token SHOULD NOT be included. // // Therefore, renewing an expired access token (re-authorization) // is the same as acquiring a new access token (authorization). - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( clientRegistration); + OAuth2AccessTokenResponse tokenResponse = getTokenResponse(clientRegistration, clientCredentialsGrantRequest); + return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + tokenResponse.getAccessToken()); + } - OAuth2AccessTokenResponse tokenResponse; + private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration, + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { try { - tokenResponse = this.accessTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + return this.accessTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); } catch (OAuth2AuthorizationException ex) { throw new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex); } - - return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), - tokenResponse.getAccessToken()); } private boolean hasTokenExpired(AbstractOAuth2Token token) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java index 9a22665efa..e8ec38f7cd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsReactiveOAuth2AuthorizedClientProvider.java @@ -64,31 +64,27 @@ public final class ClientCredentialsReactiveOAuth2AuthorizedClientProvider @Override public Mono authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - ClientRegistration clientRegistration = context.getClientRegistration(); if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { return Mono.empty(); } - OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { // If client is already authorized but access token is NOT expired than no // need for re-authorization return Mono.empty(); } - // As per spec, in section 4.4.3 Access Token Response // https://tools.ietf.org/html/rfc6749#section-4.4.3 // A refresh token SHOULD NOT be included. // // Therefore, renewing an expired access token (re-authorization) // is the same as acquiring a new access token (authorization). - return Mono.just(new OAuth2ClientCredentialsGrantRequest(clientRegistration)) .flatMap(this.accessTokenResponseClient::getTokenResponse) .onErrorMap(OAuth2AuthorizationException.class, - (e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), - e)) + (ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), + ex)) .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken())); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java index 0927195e79..a31c655c4c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java @@ -99,7 +99,6 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient */ public JdbcOAuth2AuthorizedClientService(JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) { - Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.jdbcOperations = jdbcOperations; @@ -113,15 +112,12 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient String principalName) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); - SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, clientRegistrationId), new SqlParameterValue(Types.VARCHAR, principalName) }; PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); - List result = this.jdbcOperations.query(LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper); - return !result.isEmpty() ? (T) result.get(0) : null; } @@ -129,10 +125,8 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { Assert.notNull(authorizedClient, "authorizedClient cannot be null"); Assert.notNull(principal, "principal cannot be null"); - boolean existsAuthorizedClient = null != this.loadAuthorizedClient( authorizedClient.getClientRegistration().getRegistrationId(), principal.getName()); - if (existsAuthorizedClient) { updateAuthorizedClient(authorizedClient, principal); } @@ -149,14 +143,11 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient private void updateAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) { List parameters = this.authorizedClientParametersMapper .apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)); - SqlParameterValue clientRegistrationIdParameter = parameters.remove(0); SqlParameterValue principalNameParameter = parameters.remove(0); parameters.add(clientRegistrationIdParameter); parameters.add(principalNameParameter); - PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); - this.jdbcOperations.update(UPDATE_AUTHORIZED_CLIENT_SQL, pss); } @@ -164,7 +155,6 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient List parameters = this.authorizedClientParametersMapper .apply(new OAuth2AuthorizedClientHolder(authorizedClient, principal)); PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); - this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss); } @@ -172,12 +162,10 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient public void removeAuthorizedClient(String clientRegistrationId, String principalName) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); - SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, clientRegistrationId), new SqlParameterValue(Types.VARCHAR, principalName) }; PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); - this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss); } @@ -229,7 +217,6 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient "The ClientRegistration with id '" + clientRegistrationId + "' exists in the data source, " + "however, it was not found in the ClientRegistrationRepository."); } - OAuth2AccessToken.TokenType tokenType = null; if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("access_token_type"))) { tokenType = OAuth2AccessToken.TokenType.BEARER; @@ -243,7 +230,6 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); } OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, issuedAt, expiresAt, scopes); - OAuth2RefreshToken refreshToken = null; byte[] refreshTokenValue = rs.getBytes("refresh_token_value"); if (refreshTokenValue != null) { @@ -255,9 +241,7 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient } refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt); } - String principalName = rs.getString("principal_name"); - return new OAuth2AuthorizedClient(clientRegistration, principalName, accessToken, refreshToken); } @@ -277,7 +261,6 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); OAuth2AccessToken accessToken = authorizedClient.getAccessToken(); OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken(); - List parameters = new ArrayList<>(); parameters.add(new SqlParameterValue(Types.VARCHAR, clientRegistration.getRegistrationId())); parameters.add(new SqlParameterValue(Types.VARCHAR, principal.getName())); @@ -301,7 +284,6 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient } parameters.add(new SqlParameterValue(Types.BLOB, refreshTokenValue)); parameters.add(new SqlParameterValue(Types.TIMESTAMP, refreshTokenIssuedAt)); - return parameters; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java index b4ef828a13..58a4ad3531 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java @@ -157,8 +157,8 @@ public final class OAuth2AuthorizeRequest { private static Authentication createAuthentication(final String principalName) { Assert.hasText(principalName, "principalName cannot be empty"); - return new AbstractAuthenticationToken(null) { + @Override public Object getCredentials() { return ""; @@ -168,6 +168,7 @@ public final class OAuth2AuthorizeRequest { public Object getPrincipal() { return principalName; } + }; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProvider.java index a45b697dbf..931e862b79 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordOAuth2AuthorizedClientProvider.java @@ -77,48 +77,43 @@ public final class PasswordOAuth2AuthorizedClientProvider implements OAuth2Autho @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - ClientRegistration clientRegistration = context.getClientRegistration(); OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); - if (!AuthorizationGrantType.PASSWORD.equals(clientRegistration.getAuthorizationGrantType())) { return null; } - String username = context.getAttribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME); String password = context.getAttribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME); if (!StringUtils.hasText(username) || !StringUtils.hasText(password)) { return null; } - if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { // If client is already authorized and access token is NOT expired than no // need for re-authorization return null; } - if (authorizedClient != null && hasTokenExpired(authorizedClient.getAccessToken()) && authorizedClient.getRefreshToken() != null) { // If client is already authorized and access token is expired and a refresh - // token is available, - // than return and allow RefreshTokenOAuth2AuthorizedClientProvider to handle - // the refresh + // token is available, than return and allow + // RefreshTokenOAuth2AuthorizedClientProvider to handle the refresh return null; } - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, username, password); + OAuth2AccessTokenResponse tokenResponse = getTokenResponse(clientRegistration, passwordGrantRequest); + return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); + } - OAuth2AccessTokenResponse tokenResponse; + private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration, + OAuth2PasswordGrantRequest passwordGrantRequest) { try { - tokenResponse = this.accessTokenResponseClient.getTokenResponse(passwordGrantRequest); + return this.accessTokenResponseClient.getTokenResponse(passwordGrantRequest); } catch (OAuth2AuthorizationException ex) { throw new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex); } - - return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), - tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); } private boolean hasTokenExpired(AbstractOAuth2Token token) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProvider.java index 3b34cf85d9..7240fef0cc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/PasswordReactiveOAuth2AuthorizedClientProvider.java @@ -77,26 +77,21 @@ public final class PasswordReactiveOAuth2AuthorizedClientProvider implements Rea @Override public Mono authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - ClientRegistration clientRegistration = context.getClientRegistration(); OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); - if (!AuthorizationGrantType.PASSWORD.equals(clientRegistration.getAuthorizationGrantType())) { return Mono.empty(); } - String username = context.getAttribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME); String password = context.getAttribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME); if (!StringUtils.hasText(username) || !StringUtils.hasText(password)) { return Mono.empty(); } - if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { // If client is already authorized and access token is NOT expired than no // need for re-authorization return Mono.empty(); } - if (authorizedClient != null && hasTokenExpired(authorizedClient.getAccessToken()) && authorizedClient.getRefreshToken() != null) { // If client is already authorized and access token is expired and a refresh @@ -105,10 +100,8 @@ public final class PasswordReactiveOAuth2AuthorizedClientProvider implements Rea // handle the refresh return Mono.empty(); } - OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, username, password); - return Mono.just(passwordGrantRequest).flatMap(this.accessTokenResponseClient::getTokenResponse) .onErrorMap(OAuth2AuthorizationException.class, (e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index 794ce315df..04962922d9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -75,13 +75,11 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); if (authorizedClient == null || authorizedClient.getRefreshToken() == null || !hasTokenExpired(authorizedClient.getAccessToken())) { return null; } - Object requestScope = context.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); Set scopes = Collections.emptySet(); if (requestScope != null) { @@ -89,22 +87,23 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); scopes = new HashSet<>(Arrays.asList((String[]) requestScope)); } - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(), authorizedClient.getRefreshToken(), scopes); + OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest); + return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(), + context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); + } - OAuth2AccessTokenResponse tokenResponse; + private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient, + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { try { - tokenResponse = this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + return this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest); } catch (OAuth2AuthorizationException ex) { throw new ClientAuthorizationException(ex.getError(), authorizedClient.getClientRegistration().getRegistrationId(), ex); } - - return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(), - context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); } private boolean hasTokenExpired(AbstractOAuth2Token token) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java index d93f36e485..5f6e16369d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenReactiveOAuth2AuthorizedClientProvider.java @@ -77,13 +77,11 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider @Override public Mono authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); if (authorizedClient == null || authorizedClient.getRefreshToken() == null || !hasTokenExpired(authorizedClient.getAccessToken())) { return Mono.empty(); } - Object requestScope = context.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); Set scopes = Collections.emptySet(); if (requestScope != null) { @@ -92,10 +90,8 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider scopes = new HashSet<>(Arrays.asList((String[]) requestScope)); } ClientRegistration clientRegistration = context.getClientRegistration(); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, authorizedClient.getAccessToken(), authorizedClient.getRefreshToken(), scopes); - return Mono.just(refreshTokenGrantRequest).flatMap(this.accessTokenResponseClient::getTokenResponse) .onErrorMap(OAuth2AuthorizationException.class, (e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java index aafaaa422f..3701e8457c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java @@ -16,9 +16,9 @@ package org.springframework.security.oauth2.client; -import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; @@ -47,26 +47,20 @@ public class RemoveAuthorizedClientOAuth2AuthorizationFailureHandler implements * {@link OAuth2AuthorizedClient}. * @see OAuth2ErrorCodes */ - public static final Set DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections - .unmodifiableSet(new HashSet<>(Arrays.asList( - /* - * Returned from Resource Servers when an access token provided is - * expired, revoked, malformed, or invalid for other reasons. - * - * Note that this is needed because - * ServletOAuth2AuthorizedClientExchangeFilterFunction delegates this - * type of failure received from a Resource Server to this failure - * handler. - */ - OAuth2ErrorCodes.INVALID_TOKEN, - - /* - * Returned from Authorization Servers when the authorization grant or - * refresh token is invalid, expired, revoked, does not match the - * redirection URI used in the authorization request, or was issued to - * another client. - */ - OAuth2ErrorCodes.INVALID_GRANT))); + public static final Set DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES; + static { + Set codes = new LinkedHashSet<>(); + // Returned from Resource Servers when an access token provided is expired, + // revoked, malformed, or invalid for other reasons. Note that this is needed + // because ServletOAuth2AuthorizedClientExchangeFilterFunction delegates this type + // of failure received from a Resource Server to this failure handler. + codes.add(OAuth2ErrorCodes.INVALID_TOKEN); + // Returned from Authorization Servers when the authorization grant or refresh + // token is invalid, expired, revoked, does not match the redirection URI used in + // the authorization request, or was issued to another client. + codes.add(OAuth2ErrorCodes.INVALID_GRANT); + DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(codes); + } /** * The OAuth 2.0 error codes which will trigger removal of an @@ -116,10 +110,8 @@ public class RemoveAuthorizedClientOAuth2AuthorizationFailureHandler implements @Override public void onAuthorizationFailure(OAuth2AuthorizationException authorizationException, Authentication principal, Map attributes) { - if (authorizationException instanceof ClientAuthorizationException && hasRemovalErrorCode(authorizationException)) { - ClientAuthorizationException clientAuthorizationException = (ClientAuthorizationException) authorizationException; this.delegate.removeAuthorizedClient(clientAuthorizationException.getClientRegistrationId(), principal, attributes); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java index 0c7b295fb2..0e7edd67d7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java @@ -16,9 +16,9 @@ package org.springframework.security.oauth2.client; -import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; @@ -47,24 +47,20 @@ public class RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler * client. * @see OAuth2ErrorCodes */ - public static final Set DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections - .unmodifiableSet(new HashSet<>(Arrays.asList( - /* - * Returned from resource servers when an access token provided is - * expired, revoked, malformed, or invalid for other reasons. - * - * Note that this is needed because the - * ServerOAuth2AuthorizedClientExchangeFilterFunction delegates this - * type of failure received from a resource server to this failure - * handler. - */ - OAuth2ErrorCodes.INVALID_TOKEN, - /* - * Returned from authorization servers when a refresh token is - * invalid, expired, revoked, does not match the redirection URI used - * in the authorization request, or was issued to another client. - */ - OAuth2ErrorCodes.INVALID_GRANT))); + public static final Set DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES; + static { + Set codes = new LinkedHashSet<>(); + // Returned from resource servers when an access token provided is expired, + // revoked, malformed, or invalid for other reasons. Note that this is needed + // because the ServerOAuth2AuthorizedClientExchangeFilterFunction delegates this + // type of failure received from a resource server to this failure handler. + codes.add(OAuth2ErrorCodes.INVALID_TOKEN); + // Returned from authorization servers when a refresh token is invalid, expired, + // revoked, does not match the redirection URI used in the authorization request, + // or was issued to another client. + codes.add(OAuth2ErrorCodes.INVALID_GRANT); + DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(codes); + } /** * A delegate that removes an {@link OAuth2AuthorizedClient} from a @@ -116,17 +112,13 @@ public class RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler @Override public Mono onAuthorizationFailure(OAuth2AuthorizationException authorizationException, Authentication principal, Map attributes) { - if (authorizationException instanceof ClientAuthorizationException && hasRemovalErrorCode(authorizationException)) { - ClientAuthorizationException clientAuthorizationException = (ClientAuthorizationException) authorizationException; return this.delegate.removeAuthorizedClient(clientAuthorizationException.getClientRegistrationId(), principal, attributes); } - else { - return Mono.empty(); - } + return Mono.empty(); } /** diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index 64d870d790..efcf42a19a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -64,7 +64,6 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica */ public OAuth2AuthorizationCodeAuthenticationProvider( OAuth2AccessTokenResponseClient accessTokenResponseClient) { - Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; } @@ -72,30 +71,25 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = (OAuth2AuthorizationCodeAuthenticationToken) authentication; - OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication.getAuthorizationExchange() .getAuthorizationResponse(); if (authorizationResponse.statusError()) { throw new OAuth2AuthorizationException(authorizationResponse.getError()); } - OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication.getAuthorizationExchange() .getAuthorizationRequest(); if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); throw new OAuth2AuthorizationException(oauth2Error); } - OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenResponseClient.getTokenResponse( new OAuth2AuthorizationCodeGrantRequest(authorizationCodeAuthentication.getClientRegistration(), authorizationCodeAuthentication.getAuthorizationExchange())); - OAuth2AuthorizationCodeAuthenticationToken authenticationResult = new OAuth2AuthorizationCodeAuthenticationToken( authorizationCodeAuthentication.getClientRegistration(), authorizationCodeAuthentication.getAuthorizationExchange(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken(), accessTokenResponse.getAdditionalParameters()); authenticationResult.setDetails(authorizationCodeAuthentication.getDetails()); - return authenticationResult; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java index 034ed3efb7..a2497f5978 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java @@ -84,23 +84,19 @@ public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements Rea public Mono authenticate(Authentication authentication) { return Mono.defer(() -> { OAuth2AuthorizationCodeAuthenticationToken token = (OAuth2AuthorizationCodeAuthenticationToken) authentication; - OAuth2AuthorizationResponse authorizationResponse = token.getAuthorizationExchange() .getAuthorizationResponse(); if (authorizationResponse.statusError()) { return Mono.error(new OAuth2AuthorizationException(authorizationResponse.getError())); } - OAuth2AuthorizationRequest authorizationRequest = token.getAuthorizationExchange() .getAuthorizationRequest(); if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); return Mono.error(new OAuth2AuthorizationException(oauth2Error)); } - OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest( token.getClientRegistration(), token.getAuthorizationExchange()); - return this.accessTokenResponseClient.getTokenResponse(authzRequest).map(onSuccess(token)); }); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java index e20e0006e9..a82d320e53 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java @@ -83,7 +83,6 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider public OAuth2LoginAuthenticationProvider( OAuth2AccessTokenResponseClient accessTokenResponseClient, OAuth2UserService userService) { - Assert.notNull(userService, "userService cannot be null"); this.authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( accessTokenResponseClient); @@ -93,10 +92,8 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { OAuth2LoginAuthenticationToken loginAuthenticationToken = (OAuth2LoginAuthenticationToken) authentication; - // Section 3.1.2.1 Authentication Request - - // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // scope + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest scope // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. if (loginAuthenticationToken.getAuthorizationExchange().getAuthorizationRequest().getScopes() .contains("openid")) { @@ -104,7 +101,6 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider // and let OidcAuthorizationCodeAuthenticationProvider handle it instead return null; } - OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthenticationToken; try { authorizationCodeAuthenticationToken = (OAuth2AuthorizationCodeAuthenticationToken) this.authorizationCodeAuthenticationProvider @@ -116,21 +112,16 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider OAuth2Error oauth2Error = ex.getError(); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - OAuth2AccessToken accessToken = authorizationCodeAuthenticationToken.getAccessToken(); Map additionalParameters = authorizationCodeAuthenticationToken.getAdditionalParameters(); - OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest( loginAuthenticationToken.getClientRegistration(), accessToken, additionalParameters)); - Collection mappedAuthorities = this.authoritiesMapper .mapAuthorities(oauth2User.getAuthorities()); - OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( loginAuthenticationToken.getClientRegistration(), loginAuthenticationToken.getAuthorizationExchange(), oauth2User, mappedAuthorities, accessToken, authorizationCodeAuthenticationToken.getRefreshToken()); authenticationResult.setDetails(loginAuthenticationToken.getDetails()); - return authenticationResult; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationToken.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationToken.java index cbfe8b709f..afbe15784f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationToken.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationToken.java @@ -66,7 +66,6 @@ public class OAuth2LoginAuthenticationToken extends AbstractAuthenticationToken */ public OAuth2LoginAuthenticationToken(ClientRegistration clientRegistration, OAuth2AuthorizationExchange authorizationExchange) { - super(Collections.emptyList()); Assert.notNull(clientRegistration, "clientRegistration cannot be null"); Assert.notNull(authorizationExchange, "authorizationExchange cannot be null"); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java index 1d3dae2fbe..e4b72951b5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java @@ -88,18 +88,15 @@ public class OAuth2LoginReactiveAuthenticationManager implements ReactiveAuthent public Mono authenticate(Authentication authentication) { return Mono.defer(() -> { OAuth2AuthorizationCodeAuthenticationToken token = (OAuth2AuthorizationCodeAuthenticationToken) authentication; - // Section 3.1.2.1 Authentication Request - - // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // scope REQUIRED. OpenID Connect requests MUST contain the "openid" scope - // value. + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest scope + // REQUIRED. OpenID Connect requests MUST contain the "openid" scope value. if (token.getAuthorizationExchange().getAuthorizationRequest().getScopes().contains("openid")) { // This is an OpenID Connect Authentication Request so return null // and let OidcAuthorizationCodeReactiveAuthenticationManager handle it // instead once one is created return Mono.empty(); } - return this.authorizationCodeManager.authenticate(token) .onErrorMap(OAuth2AuthorizationException.class, (e) -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())) @@ -128,7 +125,6 @@ public class OAuth2LoginReactiveAuthenticationManager implements ReactiveAuthent return this.userService.loadUser(userRequest).map((oauth2User) -> { Collection mappedAuthorities = this.authoritiesMapper .mapAuthorities(oauth2User.getAuthorities()); - OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( authentication.getClientRegistration(), authentication.getAuthorizationExchange(), oauth2User, mappedAuthorities, accessToken, authentication.getRefreshToken()); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java index 46ea462465..064c7362d6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java @@ -74,23 +74,9 @@ public final class DefaultAuthorizationCodeTokenResponseClient public OAuth2AccessTokenResponse getTokenResponse( OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { Assert.notNull(authorizationCodeGrantRequest, "authorizationCodeGrantRequest cannot be null"); - RequestEntity request = this.requestEntityConverter.convert(authorizationCodeGrantRequest); - - ResponseEntity response; - try { - response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); - } - catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " - + ex.getMessage(), - null); - throw new OAuth2AuthorizationException(oauth2Error, ex); - } - + ResponseEntity response = getResponse(request); OAuth2AccessTokenResponse tokenResponse = response.getBody(); - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { // As per spec, in Section 5.1 Successful Access Token Response // https://tools.ietf.org/html/rfc6749#section-5.1 @@ -99,10 +85,22 @@ public final class DefaultAuthorizationCodeTokenResponseClient tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) .scopes(authorizationCodeGrantRequest.getClientRegistration().getScopes()).build(); } - return tokenResponse; } + private ResponseEntity getResponse(RequestEntity request) { + try { + return this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } + catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + } + /** * Sets the {@link Converter} used for converting the * {@link OAuth2AuthorizationCodeGrantRequest} to a {@link RequestEntity} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java index b74fcad2fa..a216d75b09 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java @@ -74,23 +74,9 @@ public final class DefaultClientCredentialsTokenResponseClient public OAuth2AccessTokenResponse getTokenResponse( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { Assert.notNull(clientCredentialsGrantRequest, "clientCredentialsGrantRequest cannot be null"); - RequestEntity request = this.requestEntityConverter.convert(clientCredentialsGrantRequest); - - ResponseEntity response; - try { - response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); - } - catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " - + ex.getMessage(), - null); - throw new OAuth2AuthorizationException(oauth2Error, ex); - } - + ResponseEntity response = getResponse(request); OAuth2AccessTokenResponse tokenResponse = response.getBody(); - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { // As per spec, in Section 5.1 Successful Access Token Response // https://tools.ietf.org/html/rfc6749#section-5.1 @@ -99,10 +85,22 @@ public final class DefaultClientCredentialsTokenResponseClient tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) .scopes(clientCredentialsGrantRequest.getClientRegistration().getScopes()).build(); } - return tokenResponse; } + private ResponseEntity getResponse(RequestEntity request) { + try { + return this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } + catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + } + /** * Sets the {@link Converter} used for converting the * {@link OAuth2ClientCredentialsGrantRequest} to a {@link RequestEntity} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java index cea3eee2f9..047e787885 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java @@ -73,23 +73,9 @@ public final class DefaultPasswordTokenResponseClient @Override public OAuth2AccessTokenResponse getTokenResponse(OAuth2PasswordGrantRequest passwordGrantRequest) { Assert.notNull(passwordGrantRequest, "passwordGrantRequest cannot be null"); - RequestEntity request = this.requestEntityConverter.convert(passwordGrantRequest); - - ResponseEntity response; - try { - response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); - } - catch (RestClientException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " - + ex.getMessage(), - null); - throw new OAuth2AuthorizationException(oauth2Error, ex); - } - + ResponseEntity response = getResponse(request); OAuth2AccessTokenResponse tokenResponse = response.getBody(); - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { // As per spec, in Section 5.1 Successful Access Token Response // https://tools.ietf.org/html/rfc6749#section-5.1 @@ -98,10 +84,22 @@ public final class DefaultPasswordTokenResponseClient tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) .scopes(passwordGrantRequest.getClientRegistration().getScopes()).build(); } - return tokenResponse; } + private ResponseEntity getResponse(RequestEntity request) { + try { + return this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } + catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + } + /** * Sets the {@link Converter} used for converting the * {@link OAuth2PasswordGrantRequest} to a {@link RequestEntity} representation of the diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java index 657ab72c7d..8550d077c0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java @@ -69,12 +69,32 @@ public final class DefaultRefreshTokenTokenResponseClient @Override public OAuth2AccessTokenResponse getTokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { Assert.notNull(refreshTokenGrantRequest, "refreshTokenGrantRequest cannot be null"); - RequestEntity request = this.requestEntityConverter.convert(refreshTokenGrantRequest); + ResponseEntity response = getResponse(request); + OAuth2AccessTokenResponse tokenResponse = response.getBody(); + if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes()) + || tokenResponse.getRefreshToken() == null) { + OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse + .withResponse(tokenResponse); + if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { + // As per spec, in Section 5.1 Successful Access Token Response + // https://tools.ietf.org/html/rfc6749#section-5.1 + // If AccessTokenResponse.scope is empty, then default to the scope + // originally requested by the client in the Token Request + tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAccessToken().getScopes()); + } + if (tokenResponse.getRefreshToken() == null) { + // Reuse existing refresh token + tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getRefreshToken().getTokenValue()); + } + tokenResponse = tokenResponseBuilder.build(); + } + return tokenResponse; + } - ResponseEntity response; + private ResponseEntity getResponse(RequestEntity request) { try { - response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + return this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); } catch (RestClientException ex) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, @@ -83,31 +103,6 @@ public final class DefaultRefreshTokenTokenResponseClient null); throw new OAuth2AuthorizationException(oauth2Error, ex); } - - OAuth2AccessTokenResponse tokenResponse = response.getBody(); - - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes()) - || tokenResponse.getRefreshToken() == null) { - OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse - .withResponse(tokenResponse); - - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { - // As per spec, in Section 5.1 Successful Access Token Response - // https://tools.ietf.org/html/rfc6749#section-5.1 - // If AccessTokenResponse.scope is empty, then default to the scope - // originally requested by the client in the Token Request - tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAccessToken().getScopes()); - } - - if (tokenResponse.getRefreshToken() == null) { - // Reuse existing refresh token - tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getRefreshToken().getTokenValue()); - } - - tokenResponse = tokenResponseBuilder.build(); - } - - return tokenResponse; } /** diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java index 83c38dbe43..0f8b6a54ef 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java @@ -81,7 +81,6 @@ public class NimbusAuthorizationCodeTokenResponseClient @Override public OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) { ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); - // Build the authorization code grant request for the token endpoint AuthorizationCode authorizationCode = new AuthorizationCode( authorizationGrantRequest.getAuthorizationExchange().getAuthorizationResponse().getCode()); @@ -89,19 +88,43 @@ public class NimbusAuthorizationCodeTokenResponseClient authorizationGrantRequest.getAuthorizationExchange().getAuthorizationRequest().getRedirectUri()); AuthorizationGrant authorizationCodeGrant = new AuthorizationCodeGrant(authorizationCode, redirectUri); URI tokenUri = toURI(clientRegistration.getProviderDetails().getTokenUri()); - // Set the credentials to authenticate the client at the token endpoint ClientID clientId = new ClientID(clientRegistration.getClientId()); Secret clientSecret = new Secret(clientRegistration.getClientSecret()); - ClientAuthentication clientAuthentication; - if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - clientAuthentication = new ClientSecretPost(clientId, clientSecret); + boolean isPost = ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod()); + ClientAuthentication clientAuthentication = isPost ? new ClientSecretPost(clientId, clientSecret) + : new ClientSecretBasic(clientId, clientSecret); + com.nimbusds.oauth2.sdk.TokenResponse tokenResponse = getTokenResponse(authorizationCodeGrant, tokenUri, + clientAuthentication); + if (!tokenResponse.indicatesSuccess()) { + TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse; + ErrorObject errorObject = tokenErrorResponse.getErrorObject(); + throw new OAuth2AuthorizationException(getOAuthError(errorObject)); } - else { - clientAuthentication = new ClientSecretBasic(clientId, clientSecret); + AccessTokenResponse accessTokenResponse = (AccessTokenResponse) tokenResponse; + String accessToken = accessTokenResponse.getTokens().getAccessToken().getValue(); + OAuth2AccessToken.TokenType accessTokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue() + .equalsIgnoreCase(accessTokenResponse.getTokens().getAccessToken().getType().getValue())) { + accessTokenType = OAuth2AccessToken.TokenType.BEARER; } + long expiresIn = accessTokenResponse.getTokens().getAccessToken().getLifetime(); + // As per spec, in section 5.1 Successful Access Token Response + // https://tools.ietf.org/html/rfc6749#section-5.1 + // If AccessTokenResponse.scope is empty, then default to the scope + // originally requested by the client in the Authorization Request + Set scopes = getScopes(authorizationGrantRequest, accessTokenResponse); + String refreshToken = null; + if (accessTokenResponse.getTokens().getRefreshToken() != null) { + refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue(); + } + Map additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters()); + return OAuth2AccessTokenResponse.withToken(accessToken).tokenType(accessTokenType).expiresIn(expiresIn) + .scopes(scopes).refreshToken(refreshToken).additionalParameters(additionalParameters).build(); + } - com.nimbusds.oauth2.sdk.TokenResponse tokenResponse; + private com.nimbusds.oauth2.sdk.TokenResponse getTokenResponse(AuthorizationGrant authorizationCodeGrant, + URI tokenUri, ClientAuthentication clientAuthentication) { try { // Send the Access Token request TokenRequest tokenRequest = new TokenRequest(tokenUri, clientAuthentication, authorizationCodeGrant); @@ -109,7 +132,7 @@ public class NimbusAuthorizationCodeTokenResponseClient httpRequest.setAccept(MediaType.APPLICATION_JSON_VALUE); httpRequest.setConnectTimeout(30000); httpRequest.setReadTimeout(30000); - tokenResponse = com.nimbusds.oauth2.sdk.TokenResponse.parse(httpRequest.send()); + return com.nimbusds.oauth2.sdk.TokenResponse.parse(httpRequest.send()); } catch (ParseException | IOException ex) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, @@ -118,55 +141,25 @@ public class NimbusAuthorizationCodeTokenResponseClient null); throw new OAuth2AuthorizationException(oauth2Error, ex); } + } - if (!tokenResponse.indicatesSuccess()) { - TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse; - ErrorObject errorObject = tokenErrorResponse.getErrorObject(); - OAuth2Error oauth2Error; - if (errorObject == null) { - oauth2Error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR); - } - else { - oauth2Error = new OAuth2Error( - (errorObject.getCode() != null) ? errorObject.getCode() : OAuth2ErrorCodes.SERVER_ERROR, - errorObject.getDescription(), - (errorObject.getURI() != null) ? errorObject.getURI().toString() : null); - } - throw new OAuth2AuthorizationException(oauth2Error); - } - - AccessTokenResponse accessTokenResponse = (AccessTokenResponse) tokenResponse; - - String accessToken = accessTokenResponse.getTokens().getAccessToken().getValue(); - OAuth2AccessToken.TokenType accessTokenType = null; - if (OAuth2AccessToken.TokenType.BEARER.getValue() - .equalsIgnoreCase(accessTokenResponse.getTokens().getAccessToken().getType().getValue())) { - accessTokenType = OAuth2AccessToken.TokenType.BEARER; - } - long expiresIn = accessTokenResponse.getTokens().getAccessToken().getLifetime(); - - // As per spec, in section 5.1 Successful Access Token Response - // https://tools.ietf.org/html/rfc6749#section-5.1 - // If AccessTokenResponse.scope is empty, then default to the scope - // originally requested by the client in the Authorization Request - Set scopes; + private Set getScopes(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest, + AccessTokenResponse accessTokenResponse) { if (CollectionUtils.isEmpty(accessTokenResponse.getTokens().getAccessToken().getScope())) { - scopes = new LinkedHashSet<>( + return new LinkedHashSet<>( authorizationGrantRequest.getAuthorizationExchange().getAuthorizationRequest().getScopes()); } - else { - scopes = new LinkedHashSet<>(accessTokenResponse.getTokens().getAccessToken().getScope().toStringList()); + return new LinkedHashSet<>(accessTokenResponse.getTokens().getAccessToken().getScope().toStringList()); + } + + private OAuth2Error getOAuthError(ErrorObject errorObject) { + if (errorObject == null) { + return new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR); } - - String refreshToken = null; - if (accessTokenResponse.getTokens().getRefreshToken() != null) { - refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue(); - } - - Map additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters()); - - return OAuth2AccessTokenResponse.withToken(accessToken).tokenType(accessTokenType).expiresIn(expiresIn) - .scopes(scopes).refreshToken(refreshToken).additionalParameters(additionalParameters).build(); + String errorCode = (errorObject.getCode() != null) ? errorObject.getCode() : OAuth2ErrorCodes.SERVER_ERROR; + String description = errorObject.getDescription(); + String uri = (errorObject.getURI() != null) ? errorObject.getURI().toString() : null; + return new OAuth2Error(errorCode, description, uri); } private static URI toURI(String uriStr) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java index 15eae1f504..a77470c0de 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java @@ -53,12 +53,10 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverter @Override public RequestEntity convert(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); MultiValueMap formParameters = this.buildFormParameters(authorizationCodeGrantRequest); URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); } @@ -73,7 +71,6 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverter OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) { ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration(); OAuth2AuthorizationExchange authorizationExchange = authorizationCodeGrantRequest.getAuthorizationExchange(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); formParameters.add(OAuth2ParameterNames.GRANT_TYPE, authorizationCodeGrantRequest.getGrantType().getValue()); formParameters.add(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode()); @@ -92,7 +89,6 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverter if (codeVerifier != null) { formParameters.add(PkceParameterNames.CODE_VERIFIER, codeVerifier); } - return formParameters; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java index 81467775da..b555aabd1b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestEntityConverter.java @@ -53,12 +53,10 @@ public class OAuth2ClientCredentialsGrantRequestEntityConverter @Override public RequestEntity convert(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); MultiValueMap formParameters = this.buildFormParameters(clientCredentialsGrantRequest); URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); } @@ -72,7 +70,6 @@ public class OAuth2ClientCredentialsGrantRequestEntityConverter private MultiValueMap buildFormParameters( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) { ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); formParameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue()); if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { @@ -83,7 +80,6 @@ public class OAuth2ClientCredentialsGrantRequestEntityConverter formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); } - return formParameters; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java index 0a9caee61f..1bef0f8404 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2PasswordGrantRequestEntityConverter.java @@ -53,12 +53,10 @@ public class OAuth2PasswordGrantRequestEntityConverter @Override public RequestEntity convert(OAuth2PasswordGrantRequest passwordGrantRequest) { ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); MultiValueMap formParameters = buildFormParameters(passwordGrantRequest); URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); } @@ -71,7 +69,6 @@ public class OAuth2PasswordGrantRequestEntityConverter */ private MultiValueMap buildFormParameters(OAuth2PasswordGrantRequest passwordGrantRequest) { ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); formParameters.add(OAuth2ParameterNames.GRANT_TYPE, passwordGrantRequest.getGrantType().getValue()); formParameters.add(OAuth2ParameterNames.USERNAME, passwordGrantRequest.getUsername()); @@ -84,7 +81,6 @@ public class OAuth2PasswordGrantRequestEntityConverter formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); } - return formParameters; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java index 3012797171..bd22022bf8 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java @@ -53,12 +53,10 @@ public class OAuth2RefreshTokenGrantRequestEntityConverter @Override public RequestEntity convert(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); - HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); MultiValueMap formParameters = buildFormParameters(refreshTokenGrantRequest); URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build() .toUri(); - return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); } @@ -71,7 +69,6 @@ public class OAuth2RefreshTokenGrantRequestEntityConverter */ private MultiValueMap buildFormParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); - MultiValueMap formParameters = new LinkedMultiValueMap<>(); formParameters.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue()); formParameters.add(OAuth2ParameterNames.REFRESH_TOKEN, @@ -84,7 +81,6 @@ public class OAuth2RefreshTokenGrantRequestEntityConverter formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); } - return formParameters; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java index 6fb61850ff..ee09608f20 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java @@ -68,12 +68,10 @@ public final class WebClientReactiveRefreshTokenTokenResponseClient @Override OAuth2AccessTokenResponse populateTokenResponse(OAuth2RefreshTokenGrantRequest grantRequest, OAuth2AccessTokenResponse accessTokenResponse) { - if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes()) && accessTokenResponse.getRefreshToken() != null) { return accessTokenResponse; } - OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse .withResponse(accessTokenResponse); if (CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java index 95464eeb74..2b50b967ac 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java @@ -55,14 +55,12 @@ public class OAuth2ErrorResponseErrorHandler implements ResponseErrorHandler { if (!HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) { this.defaultErrorHandler.handleError(response); } - // A Bearer Token Error may be in the WWW-Authenticate response header // See https://tools.ietf.org/html/rfc6750#section-3 OAuth2Error oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders()); if (oauth2Error == null) { oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response); } - throw new OAuth2AuthorizationException(oauth2Error); } @@ -71,21 +69,21 @@ public class OAuth2ErrorResponseErrorHandler implements ResponseErrorHandler { if (!StringUtils.hasText(wwwAuthenticateHeader)) { return null; } - - BearerTokenError bearerTokenError; - try { - bearerTokenError = BearerTokenError.parse(wwwAuthenticateHeader); - } - catch (Exception ex) { - return null; - } - + BearerTokenError bearerTokenError = getBearerToken(wwwAuthenticateHeader); String errorCode = (bearerTokenError.getCode() != null) ? bearerTokenError.getCode() : OAuth2ErrorCodes.SERVER_ERROR; String errorDescription = bearerTokenError.getDescription(); String errorUri = (bearerTokenError.getURI() != null) ? bearerTokenError.getURI().toString() : null; - return new OAuth2Error(errorCode, errorDescription, errorUri); } + private BearerTokenError getBearerToken(String wwwAuthenticateHeader) { + try { + return BearerTokenError.parse(wwwAuthenticateHeader); + } + catch (Exception ex) { + return null; + } + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationDeserializer.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationDeserializer.java index 9b8faa2ec5..d8cfde2efc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationDeserializer.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/ClientRegistrationDeserializer.java @@ -52,7 +52,6 @@ final class ClientRegistrationDeserializer extends JsonDeserializer> SET_TYPE_REFERENCE = new TypeReference>() { + static final TypeReference> STRING_SET = new TypeReference>() { }; - static final TypeReference> MAP_TYPE_REFERENCE = new TypeReference>() { + + static final TypeReference> STRING_OBJECT_MAP = new TypeReference>() { }; static String findStringValue(JsonNode jsonNode, String fieldName) { if (jsonNode == null) { return null; } - JsonNode nodeValue = jsonNode.findValue(fieldName); - if (nodeValue != null && nodeValue.isTextual()) { - return nodeValue.asText(); - } - return null; + JsonNode value = jsonNode.findValue(fieldName); + return (value != null && value.isTextual()) ? value.asText() : null; } static T findValue(JsonNode jsonNode, String fieldName, TypeReference valueTypeReference, @@ -52,22 +50,16 @@ abstract class JsonNodeUtils { if (jsonNode == null) { return null; } - JsonNode nodeValue = jsonNode.findValue(fieldName); - if (nodeValue != null && nodeValue.isContainerNode()) { - return mapper.convertValue(nodeValue, valueTypeReference); - } - return null; + JsonNode value = jsonNode.findValue(fieldName); + return (value != null && value.isContainerNode()) ? mapper.convertValue(value, valueTypeReference) : null; } static JsonNode findObjectNode(JsonNode jsonNode, String fieldName) { if (jsonNode == null) { return null; } - JsonNode nodeValue = jsonNode.findValue(fieldName); - if (nodeValue != null && nodeValue.isObject()) { - return nodeValue; - } - return null; + JsonNode value = jsonNode.findValue(fieldName); + return (value != null && value.isObject()) ? value : null; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestDeserializer.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestDeserializer.java index 01a31eb502..00e717bb8f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestDeserializer.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestDeserializer.java @@ -28,6 +28,7 @@ import com.fasterxml.jackson.databind.util.StdConverter; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest.Builder; /** * A {@code JsonDeserializer} for {@link OAuth2AuthorizationRequest}. @@ -45,35 +46,36 @@ final class OAuth2AuthorizationRequestDeserializer extends JsonDeserializer accessTokenResponseClient, OAuth2UserService userService) { - Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); Assert.notNull(userService, "userService cannot be null"); this.accessTokenResponseClient = accessTokenResponseClient; @@ -120,7 +119,6 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { OAuth2LoginAuthenticationToken authorizationCodeAuthentication = (OAuth2LoginAuthenticationToken) authentication; - // Section 3.1.2.1 Authentication Request - // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest // scope @@ -131,35 +129,20 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati // and let OAuth2LoginAuthenticationProvider handle it instead return null; } - OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication.getAuthorizationExchange() .getAuthorizationRequest(); OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication.getAuthorizationExchange() .getAuthorizationResponse(); - if (authorizationResponse.statusError()) { throw new OAuth2AuthenticationException(authorizationResponse.getError(), authorizationResponse.getError().toString()); } - if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - - OAuth2AccessTokenResponse accessTokenResponse; - try { - accessTokenResponse = this.accessTokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(authorizationCodeAuthentication.getClientRegistration(), - authorizationCodeAuthentication.getAuthorizationExchange())); - } - catch (OAuth2AuthorizationException ex) { - OAuth2Error oauth2Error = ex.getError(); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - + OAuth2AccessTokenResponse accessTokenResponse = getResponse(authorizationCodeAuthentication); ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration(); - Map additionalParameters = accessTokenResponse.getAdditionalParameters(); if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) { OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, @@ -169,39 +152,54 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()); } OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse); - - // Validate nonce - String requestNonce = authorizationRequest.getAttribute(OidcParameterNames.NONCE); - if (requestNonce != null) { - String nonceHash; - try { - nonceHash = createHash(requestNonce); - } - catch (NoSuchAlgorithmException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - String nonceHashClaim = idToken.getNonce(); - if (nonceHashClaim == null || !nonceHashClaim.equals(nonceHash)) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - } - + validateNonce(authorizationRequest, idToken); OidcUser oidcUser = this.userService.loadUser(new OidcUserRequest(clientRegistration, accessTokenResponse.getAccessToken(), idToken, additionalParameters)); Collection mappedAuthorities = this.authoritiesMapper .mapAuthorities(oidcUser.getAuthorities()); - OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken( authorizationCodeAuthentication.getClientRegistration(), authorizationCodeAuthentication.getAuthorizationExchange(), oidcUser, mappedAuthorities, accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()); authenticationResult.setDetails(authorizationCodeAuthentication.getDetails()); - return authenticationResult; } + private OAuth2AccessTokenResponse getResponse(OAuth2LoginAuthenticationToken authorizationCodeAuthentication) { + try { + return this.accessTokenResponseClient.getTokenResponse( + new OAuth2AuthorizationCodeGrantRequest(authorizationCodeAuthentication.getClientRegistration(), + authorizationCodeAuthentication.getAuthorizationExchange())); + } + catch (OAuth2AuthorizationException ex) { + OAuth2Error oauth2Error = ex.getError(); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private void validateNonce(OAuth2AuthorizationRequest authorizationRequest, OidcIdToken idToken) { + String requestNonce = authorizationRequest.getAttribute(OidcParameterNames.NONCE); + if (requestNonce == null) { + return; + } + String nonceHash = getNonceHash(requestNonce); + String nonceHashClaim = idToken.getNonce(); + if (nonceHashClaim == null || !nonceHashClaim.equals(nonceHash)) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + + private String getNonceHash(String requestNonce) { + try { + return createHash(requestNonce); + } + catch (NoSuchAlgorithmException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + /** * Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature * verification. The factory returns a {@link JwtDecoder} associated to the provided @@ -235,18 +233,21 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati private OidcIdToken createOidcToken(ClientRegistration clientRegistration, OAuth2AccessTokenResponse accessTokenResponse) { JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); - Jwt jwt; + Jwt jwt = getJwt(accessTokenResponse, jwtDecoder); + OidcIdToken idToken = new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), + jwt.getClaims()); + return idToken; + } + + private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) { try { - jwt = jwtDecoder - .decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN)); + Map parameters = accessTokenResponse.getAdditionalParameters(); + return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN)); } catch (JwtException ex) { OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null); throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex); } - OidcIdToken idToken = new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), - jwt.getClaims()); - return idToken; } static String createHash(String nonce) throws NoSuchAlgorithmException { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java index 02f5fdfe97..a328b08b13 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java @@ -114,7 +114,6 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements React public Mono authenticate(Authentication authentication) { return Mono.defer(() -> { OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = (OAuth2AuthorizationCodeAuthenticationToken) authentication; - // Section 3.1.2.1 Authentication Request - // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest // scope REQUIRED. OpenID Connect requests MUST contain the "openid" scope @@ -125,26 +124,21 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements React // and let OAuth2LoginReactiveAuthenticationManager handle it instead return Mono.empty(); } - OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication.getAuthorizationExchange() .getAuthorizationRequest(); OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication .getAuthorizationExchange().getAuthorizationResponse(); - if (authorizationResponse.statusError()) { return Mono.error(new OAuth2AuthenticationException(authorizationResponse.getError(), authorizationResponse.getError().toString())); } - if (!authorizationResponse.getState().equals(authorizationRequest.getState())) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE); return Mono.error(new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString())); } - OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest( authorizationCodeAuthentication.getClientRegistration(), authorizationCodeAuthentication.getAuthorizationExchange()); - return this.accessTokenResponseClient.getTokenResponse(authzRequest).flatMap( (accessTokenResponse) -> authenticationResult(authorizationCodeAuthentication, accessTokenResponse)) .onErrorMap(OAuth2AuthorizationException.class, @@ -190,7 +184,6 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements React OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken(); ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration(); Map additionalParameters = accessTokenResponse.getAdditionalParameters(); - if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) { OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Missing (required) ID Token in Token Response for Client Registration: " @@ -198,14 +191,12 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements React null); return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString())); } - return createOidcToken(clientRegistration, accessTokenResponse) .doOnNext((idToken) -> validateNonce(authorizationCodeAuthentication, idToken)) .map((idToken) -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters)) .flatMap(this.userService::loadUser).map((oauth2User) -> { Collection mappedAuthorities = this.authoritiesMapper .mapAuthorities(oauth2User.getAuthorities()); - return new OAuth2LoginAuthenticationToken(authorizationCodeAuthentication.getClientRegistration(), authorizationCodeAuthentication.getAuthorizationExchange(), oauth2User, mappedAuthorities, accessToken, accessTokenResponse.getRefreshToken()); @@ -225,14 +216,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements React String requestNonce = authorizationCodeAuthentication.getAuthorizationExchange().getAuthorizationRequest() .getAttribute(OidcParameterNames.NONCE); if (requestNonce != null) { - String nonceHash; - try { - nonceHash = createHash(requestNonce); - } - catch (NoSuchAlgorithmException ex) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } + String nonceHash = getNonceHash(requestNonce); String nonceHashClaim = idToken.getNonce(); if (nonceHashClaim == null || !nonceHashClaim.equals(nonceHash)) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); @@ -243,6 +227,16 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements React return Mono.just(idToken); } + private static String getNonceHash(String requestNonce) { + try { + return createHash(requestNonce); + } + catch (NoSuchAlgorithmException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + } + static String createHash(String nonce) throws NoSuchAlgorithmException { MessageDigest md = MessageDigest.getInstance("SHA-256"); byte[] digest = md.digest(nonce.getBytes(StandardCharsets.US_ASCII)); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java index 5eedbdc2b9..a349ddbe0e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java @@ -20,6 +20,7 @@ import java.net.URL; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -66,15 +67,16 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory jcaAlgorithmMappings = new HashMap() { - { - put(MacAlgorithm.HS256, "HmacSHA256"); - put(MacAlgorithm.HS384, "HmacSHA384"); - put(MacAlgorithm.HS512, "HmacSHA512"); - } + private static final Map JCA_ALGORITHM_MAPPINGS; + static { + Map mappings = new HashMap<>(); + mappings.put(MacAlgorithm.HS256, "HmacSHA256"); + mappings.put(MacAlgorithm.HS384, "HmacSHA384"); + mappings.put(MacAlgorithm.HS512, "HmacSHA512"); + JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings); }; - private static final Converter, Map> DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( + private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( createDefaultClaimTypeConverters()); private final Map jwtDecoders = new ConcurrentHashMap<>(); @@ -100,23 +102,22 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory stringConverter = getConverter(TypeDescriptor.valueOf(String.class)); Converter collectionStringConverter = getConverter( TypeDescriptor.collection(Collection.class, TypeDescriptor.valueOf(String.class))); - - Map> claimTypeConverters = new HashMap<>(); - claimTypeConverters.put(IdTokenClaimNames.ISS, urlConverter); - claimTypeConverters.put(IdTokenClaimNames.AUD, collectionStringConverter); - claimTypeConverters.put(IdTokenClaimNames.NONCE, stringConverter); - claimTypeConverters.put(IdTokenClaimNames.EXP, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.IAT, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.AUTH_TIME, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.AMR, collectionStringConverter); - claimTypeConverters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); - claimTypeConverters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); - claimTypeConverters.put(StandardClaimNames.UPDATED_AT, instantConverter); - return claimTypeConverters; + Map> converters = new HashMap<>(); + converters.put(IdTokenClaimNames.ISS, urlConverter); + converters.put(IdTokenClaimNames.AUD, collectionStringConverter); + converters.put(IdTokenClaimNames.NONCE, stringConverter); + converters.put(IdTokenClaimNames.EXP, instantConverter); + converters.put(IdTokenClaimNames.IAT, instantConverter); + converters.put(IdTokenClaimNames.AUTH_TIME, instantConverter); + converters.put(IdTokenClaimNames.AMR, collectionStringConverter); + converters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); + converters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); + converters.put(StandardClaimNames.UPDATED_AT, instantConverter); + return converters; } private static Converter getConverter(TypeDescriptor targetDescriptor) { - final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); + TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); return (source) -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, targetDescriptor); } @@ -165,7 +166,7 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory { public OAuth2TokenValidatorResult validate(Jwt idToken) { // 3.1.3.7 ID Token Validation // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation - Map invalidClaims = validateRequiredClaims(idToken); if (!invalidClaims.isEmpty()) { return OAuth2TokenValidatorResult.failure(invalidIdToken(invalidClaims)); } - // 2. The Issuer Identifier for the OpenID Provider (which is typically obtained // during Discovery) // MUST exactly match the value of the iss (issuer) Claim. String metadataIssuer = this.clientRegistration.getProviderDetails().getIssuerUri(); - if (metadataIssuer != null && !Objects.equals(metadataIssuer, idToken.getIssuer().toExternalForm())) { invalidClaims.put(IdTokenClaimNames.ISS, idToken.getIssuer()); } - // 3. The Client MUST validate that the aud (audience) Claim contains its // client_id value // registered at the Issuer identified by the iss (issuer) Claim as an audience. @@ -93,31 +89,26 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator { if (!idToken.getAudience().contains(this.clientRegistration.getClientId())) { invalidClaims.put(IdTokenClaimNames.AUD, idToken.getAudience()); } - // 4. If the ID Token contains multiple audiences, // the Client SHOULD verify that an azp Claim is present. String authorizedParty = idToken.getClaimAsString(IdTokenClaimNames.AZP); if (idToken.getAudience().size() > 1 && authorizedParty == null) { invalidClaims.put(IdTokenClaimNames.AZP, authorizedParty); } - // 5. If an azp (authorized party) Claim is present, // the Client SHOULD verify that its client_id is the Claim Value. if (authorizedParty != null && !authorizedParty.equals(this.clientRegistration.getClientId())) { invalidClaims.put(IdTokenClaimNames.AZP, authorizedParty); } - // 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the // Client // in the id_token_signed_response_alg parameter during Registration. // TODO Depends on gh-4413 - // 9. The current time MUST be before the time represented by the exp Claim. Instant now = Instant.now(this.clock); if (now.minus(this.clockSkew).isAfter(idToken.getExpiresAt())) { invalidClaims.put(IdTokenClaimNames.EXP, idToken.getExpiresAt()); } - // 10. The iat Claim can be used to reject tokens that were issued too far away // from the current time, // limiting the amount of time that nonces need to be stored to prevent attacks. @@ -125,11 +116,9 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator { if (now.plus(this.clockSkew).isBefore(idToken.getIssuedAt())) { invalidClaims.put(IdTokenClaimNames.IAT, idToken.getIssuedAt()); } - if (!invalidClaims.isEmpty()) { return OAuth2TokenValidatorResult.failure(invalidIdToken(invalidClaims)); } - return OAuth2TokenValidatorResult.success(); } @@ -164,7 +153,6 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator { private static Map validateRequiredClaims(Jwt idToken) { Map requiredClaims = new HashMap<>(); - URL issuer = idToken.getIssuer(); if (issuer == null) { requiredClaims.put(IdTokenClaimNames.ISS, issuer); @@ -185,7 +173,6 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator { if (issuedAt == null) { requiredClaims.put(IdTokenClaimNames.IAT, issuedAt); } - return requiredClaims; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java index a231bb0d57..ec4b0bcbfa 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java @@ -20,6 +20,7 @@ import java.net.URL; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -66,15 +67,16 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier"; - private static Map jcaAlgorithmMappings = new HashMap() { - { - put(MacAlgorithm.HS256, "HmacSHA256"); - put(MacAlgorithm.HS384, "HmacSHA384"); - put(MacAlgorithm.HS512, "HmacSHA512"); - } - }; + private static final Map JCA_ALGORITHM_MAPPINGS; + static { + Map mappings = new HashMap(); + mappings.put(MacAlgorithm.HS256, "HmacSHA256"); + mappings.put(MacAlgorithm.HS384, "HmacSHA384"); + mappings.put(MacAlgorithm.HS512, "HmacSHA512"); + JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings); + } - private static final Converter, Map> DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( + private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( createDefaultClaimTypeConverters()); private final Map jwtDecoders = new ConcurrentHashMap<>(); @@ -100,19 +102,18 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod Converter stringConverter = getConverter(TypeDescriptor.valueOf(String.class)); Converter collectionStringConverter = getConverter( TypeDescriptor.collection(Collection.class, TypeDescriptor.valueOf(String.class))); - - Map> claimTypeConverters = new HashMap<>(); - claimTypeConverters.put(IdTokenClaimNames.ISS, urlConverter); - claimTypeConverters.put(IdTokenClaimNames.AUD, collectionStringConverter); - claimTypeConverters.put(IdTokenClaimNames.NONCE, stringConverter); - claimTypeConverters.put(IdTokenClaimNames.EXP, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.IAT, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.AUTH_TIME, instantConverter); - claimTypeConverters.put(IdTokenClaimNames.AMR, collectionStringConverter); - claimTypeConverters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); - claimTypeConverters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); - claimTypeConverters.put(StandardClaimNames.UPDATED_AT, instantConverter); - return claimTypeConverters; + Map> converters = new HashMap<>(); + converters.put(IdTokenClaimNames.ISS, urlConverter); + converters.put(IdTokenClaimNames.AUD, collectionStringConverter); + converters.put(IdTokenClaimNames.NONCE, stringConverter); + converters.put(IdTokenClaimNames.EXP, instantConverter); + converters.put(IdTokenClaimNames.IAT, instantConverter); + converters.put(IdTokenClaimNames.AUTH_TIME, instantConverter); + converters.put(IdTokenClaimNames.AMR, collectionStringConverter); + converters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); + converters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); + converters.put(StandardClaimNames.UPDATED_AT, instantConverter); + return converters; } private static Converter getConverter(TypeDescriptor targetDescriptor) { @@ -153,7 +154,6 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod // 7. The alg value SHOULD be the default of RS256 or the algorithm sent by // the Client // in the id_token_signed_response_alg parameter during Registration. - String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri(); if (!StringUtils.hasText(jwkSetUri)) { OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE, @@ -166,7 +166,7 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod return NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm) .build(); } - else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { + if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation // // 8. If the JWT alg Header Parameter uses a MAC based algorithm such as @@ -188,11 +188,10 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8), - jcaAlgorithmMappings.get(jwsAlgorithm)); + JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm)); return NimbusReactiveJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm) .build(); } - OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE, "Failed to find a Signature Verifier for Client Registration: '" + clientRegistration.getRegistrationId() diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java index fced8acd0a..28a7f8c92c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java @@ -81,7 +81,6 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< public static Map> createDefaultClaimTypeConverters() { Converter booleanConverter = getConverter(TypeDescriptor.valueOf(Boolean.class)); Converter instantConverter = getConverter(TypeDescriptor.valueOf(Instant.class)); - Map> claimTypeConverters = new HashMap<>(); claimTypeConverters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); claimTypeConverters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); @@ -113,9 +112,7 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName); } - else { - return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); - } + return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); }); } @@ -123,7 +120,6 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService< if (!OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest)) { return Mono.empty(); } - return this.oauth2UserService.loadUser(userRequest).map(OAuth2User::getAttributes) .map((claims) -> convertClaims(claims, userRequest.getClientRegistration())).map(OidcUserInfo::new) .doOnNext((userInfo) -> { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java index d7ef47081b..19f78658d6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java @@ -47,7 +47,6 @@ public class OidcUserRequest extends OAuth2UserRequest { * @param idToken the ID Token */ public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, OidcIdToken idToken) { - this(clientRegistration, accessToken, idToken, Collections.emptyMap()); } @@ -61,7 +60,6 @@ public class OidcUserRequest extends OAuth2UserRequest { */ public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, OidcIdToken idToken, Map additionalParameters) { - super(clientRegistration, accessToken, additionalParameters); Assert.notNull(idToken, "idToken cannot be null"); this.idToken = idToken; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java index 6045a44fa9..e8e6362479 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java @@ -46,10 +46,8 @@ final class OidcUserRequestUtils { // Auto-disabled if UserInfo Endpoint URI is not provided ClientRegistration clientRegistration = userRequest.getClientRegistration(); if (StringUtils.isEmpty(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri())) { - return false; } - // The Claims requested by the profile, email, address, and phone scope values // are returned from the UserInfo Endpoint (as described in Section 5.3.2), // when a response_type value is used that results in an Access Token being @@ -60,13 +58,11 @@ final class OidcUserRequestUtils { // The Authorization Code Grant Flow, which is response_type=code, results in an // Access Token being issued. if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { - // Return true if there is at least one match between the authorized scope(s) // and UserInfo scope(s) return CollectionUtils.containsAny(userRequest.getAccessToken().getScopes(), userRequest.getClientRegistration().getScopes()); } - return false; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index 6f83c525fa..31f181213a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -30,6 +30,7 @@ import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; @@ -87,7 +88,6 @@ public class OidcUserService implements OAuth2UserService> createDefaultClaimTypeConverters() { Converter booleanConverter = getConverter(TypeDescriptor.valueOf(Boolean.class)); Converter instantConverter = getConverter(TypeDescriptor.valueOf(Instant.class)); - Map> claimTypeConverters = new HashMap<>(); claimTypeConverters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); claimTypeConverters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); @@ -96,7 +96,7 @@ public class OidcUserService implements OAuth2UserService getConverter(TypeDescriptor targetDescriptor) { - final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); + TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); return (source) -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, targetDescriptor); } @@ -107,26 +107,14 @@ public class OidcUserService implements OAuth2UserService claims; - Converter, Map> claimTypeConverter = this.claimTypeConverterFactory - .apply(userRequest.getClientRegistration()); - if (claimTypeConverter != null) { - claims = claimTypeConverter.convert(oauth2User.getAttributes()); - } - else { - claims = DEFAULT_CLAIM_TYPE_CONVERTER.convert(oauth2User.getAttributes()); - } + Map claims = getClaims(userRequest, oauth2User); userInfo = new OidcUserInfo(claims); - // https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse - // 1) The sub (subject) Claim MUST always be returned in the UserInfo Response if (userInfo.getSubject() == null) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - // 2) Due to the possibility of token substitution attacks (see Section // 16.11), // the UserInfo Response is not guaranteed to be about the End-User @@ -139,36 +127,39 @@ public class OidcUserService implements OAuth2UserService authorities = new LinkedHashSet<>(); authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo)); OAuth2AccessToken token = userRequest.getAccessToken(); for (String authority : token.getScopes()) { authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); } + return getUser(userRequest, userInfo, authorities); + } - OidcUser user; + private Map getClaims(OidcUserRequest userRequest, OAuth2User oauth2User) { + Converter, Map> converter = this.claimTypeConverterFactory + .apply(userRequest.getClientRegistration()); + if (converter != null) { + return converter.convert(oauth2User.getAttributes()); + } + return DEFAULT_CLAIM_TYPE_CONVERTER.convert(oauth2User.getAttributes()); + } - String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint() - .getUserNameAttributeName(); + private OidcUser getUser(OidcUserRequest userRequest, OidcUserInfo userInfo, Set authorities) { + ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails(); + String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName(); if (StringUtils.hasText(userNameAttributeName)) { - user = new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName); + return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName); } - else { - user = new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); - } - - return user; + return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); } private boolean shouldRetrieveUserInfo(OidcUserRequest userRequest) { // Auto-disabled if UserInfo Endpoint URI is not provided - if (StringUtils - .isEmpty(userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri())) { - + ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails(); + if (StringUtils.isEmpty(providerDetails.getUserInfoEndpoint().getUri())) { return false; } - // The Claims requested by the profile, email, address, and phone scope values // are returned from the UserInfo Endpoint (as described in Section 5.3.2), // when a response_type value is used that results in an Access Token being @@ -180,13 +171,11 @@ public class OidcUserService implements OAuth2UserService onLogoutSuccess(WebFilterExchange exchange, Authentication authentication) { return Mono.just(authentication).filter(OAuth2AuthenticationToken.class::isInstance) @@ -95,16 +91,14 @@ public class OidcClientInitiatedServerLogoutSuccessHandler implements ServerLogo } private URI endSessionEndpoint(ClientRegistration clientRegistration) { - URI result = null; if (clientRegistration != null) { Object endSessionEndpoint = clientRegistration.getProviderDetails().getConfigurationMetadata() .get("end_session_endpoint"); if (endSessionEndpoint != null) { - result = URI.create(endSessionEndpoint.toString()); + return URI.create(endSessionEndpoint.toString()); } } - - return result; + return null; } private URI endpointUri(URI endSessionEndpoint, String idToken, URI postLogoutRedirectUri) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java index 353049076a..10815a03f4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java @@ -620,26 +620,29 @@ public final class ClientRegistration implements Serializable { private ClientRegistration create() { ClientRegistration clientRegistration = new ClientRegistration(); - clientRegistration.registrationId = this.registrationId; clientRegistration.clientId = this.clientId; clientRegistration.clientSecret = StringUtils.hasText(this.clientSecret) ? this.clientSecret : ""; - if (this.clientAuthenticationMethod != null) { - clientRegistration.clientAuthenticationMethod = this.clientAuthenticationMethod; - } - else { - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType) - && !StringUtils.hasText(this.clientSecret)) { - clientRegistration.clientAuthenticationMethod = ClientAuthenticationMethod.NONE; - } - else { - clientRegistration.clientAuthenticationMethod = ClientAuthenticationMethod.BASIC; - } - } + clientRegistration.clientAuthenticationMethod = (this.clientAuthenticationMethod != null) + ? this.clientAuthenticationMethod : deduceClientAuthenticationMethod(clientRegistration); clientRegistration.authorizationGrantType = this.authorizationGrantType; clientRegistration.redirectUri = this.redirectUri; clientRegistration.scopes = this.scopes; + clientRegistration.providerDetails = createProviderDetails(clientRegistration); + clientRegistration.clientName = StringUtils.hasText(this.clientName) ? this.clientName + : this.registrationId; + return clientRegistration; + } + private ClientAuthenticationMethod deduceClientAuthenticationMethod(ClientRegistration clientRegistration) { + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType) + && !StringUtils.hasText(this.clientSecret)) { + return ClientAuthenticationMethod.NONE; + } + return ClientAuthenticationMethod.BASIC; + } + + private ProviderDetails createProviderDetails(ClientRegistration clientRegistration) { ProviderDetails providerDetails = clientRegistration.new ProviderDetails(); providerDetails.authorizationUri = this.authorizationUri; providerDetails.tokenUri = this.tokenUri; @@ -649,12 +652,7 @@ public final class ClientRegistration implements Serializable { providerDetails.jwkSetUri = this.jwkSetUri; providerDetails.issuerUri = this.issuerUri; providerDetails.configurationMetadata = Collections.unmodifiableMap(this.configurationMetadata); - clientRegistration.providerDetails = providerDetails; - - clientRegistration.clientName = StringUtils.hasText(this.clientName) ? this.clientName - : this.registrationId; - - return clientRegistration; + return providerDetails; } private void validateAuthorizationCodeGrantType() { @@ -696,7 +694,6 @@ public final class ClientRegistration implements Serializable { if (this.scopes == null) { return; } - for (String scope : this.scopes) { Assert.isTrue(validateScope(scope), "scope \"" + scope + "\" contains invalid characters"); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java index 875c7ca257..95e817f0d3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistrations.java @@ -150,7 +150,6 @@ public final class ClientRegistrations { private static Supplier oidc(URI issuer) { URI uri = UriComponentsBuilder.fromUri(issuer).replacePath(issuer.getPath() + OIDC_METADATA_PATH) .build(Collections.emptyMap()); - return () -> { RequestEntity request = RequestEntity.get(uri).build(); Map configuration = rest.exchange(request, typeReference).getBody(); @@ -182,12 +181,10 @@ public final class ClientRegistrations { Map configuration = rest.exchange(request, typeReference).getBody(); AuthorizationServerMetadata metadata = parse(configuration, AuthorizationServerMetadata::parse); ClientRegistration.Builder builder = withProviderConfiguration(metadata, issuer.toASCIIString()); - URI jwkSetUri = metadata.getJWKSetURI(); if (jwkSetUri != null) { builder.jwkSetUri(jwkSetUri.toASCIIString()); } - String userinfoEndpoint = (String) configuration.get("userinfo_endpoint"); if (userinfoEndpoint != null) { builder.userInfoUri(userinfoEndpoint); @@ -221,7 +218,6 @@ public final class ClientRegistrations { } private static T parse(Map body, ThrowingFunction parser) { - try { return parser.apply(new JSONObject(body)); } @@ -233,25 +229,19 @@ public final class ClientRegistrations { private static ClientRegistration.Builder withProviderConfiguration(AuthorizationServerMetadata metadata, String issuer) { String metadataIssuer = metadata.getIssuer().getValue(); - if (!issuer.equals(metadataIssuer)) { - throw new IllegalStateException( - "The Issuer \"" + metadataIssuer + "\" provided in the configuration metadata did " - + "not match the requested issuer \"" + issuer + "\""); - } - + Assert.state(issuer.equals(metadataIssuer), + () -> "The Issuer \"" + metadataIssuer + "\" provided in the configuration metadata did " + + "not match the requested issuer \"" + issuer + "\""); String name = URI.create(issuer).getHost(); ClientAuthenticationMethod method = getClientAuthenticationMethod(issuer, metadata.getTokenEndpointAuthMethods()); List grantTypes = metadata.getGrantTypes(); // If null, the default includes authorization_code - if (grantTypes != null && !grantTypes.contains(GrantType.AUTHORIZATION_CODE)) { - throw new IllegalArgumentException( - "Only AuthorizationGrantType.AUTHORIZATION_CODE is supported. The issuer \"" + issuer - + "\" returned a configuration of " + grantTypes); - } + Assert.isTrue(grantTypes == null || grantTypes.contains(GrantType.AUTHORIZATION_CODE), + "Only AuthorizationGrantType.AUTHORIZATION_CODE is supported. The issuer \"" + issuer + + "\" returned a configuration of " + grantTypes); List scopes = getScopes(metadata); Map configurationMetadata = new LinkedHashMap<>(metadata.toJSONObject()); - return ClientRegistration.withRegistrationId(name).userNameAttributeName(IdTokenClaimNames.SUB).scope(scopes) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).clientAuthenticationMethod(method) .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") @@ -284,9 +274,7 @@ public final class ClientRegistrations { // If null, default to "openid" which must be supported return Collections.singletonList(OidcScopes.OPENID); } - else { - return scope.toStringList(); - } + return scope.toStringList(); } private interface ThrowingFunction { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java index 67f656e963..f0092368f3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java @@ -67,9 +67,8 @@ public final class InMemoryClientRegistrationRepository private static Map toUnmodifiableConcurrentMap(List registrations) { ConcurrentHashMap result = new ConcurrentHashMap<>(); for (ClientRegistration registration : registrations) { - if (result.containsKey(registration.getRegistrationId())) { - throw new IllegalStateException(String.format("Duplicate key %s", registration.getRegistrationId())); - } + Assert.state(!result.containsKey(registration.getRegistrationId()), + () -> String.format("Duplicate key %s", registration.getRegistrationId())); result.put(registration.getRegistrationId(), registration); } return Collections.unmodifiableMap(result); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java index 5416809838..d5b8d0ab85 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java @@ -88,22 +88,22 @@ public class CustomUserTypesOAuth2UserService implements OAuth2UserService request = this.requestEntityConverter.convert(userRequest); + ResponseEntity response = getResponse(customUserType, request); + OAuth2User oauth2User = response.getBody(); + return oauth2User; + } - ResponseEntity response; + private ResponseEntity getResponse(Class customUserType, + RequestEntity request) { try { - response = this.restOperations.exchange(request, customUserType); + return this.restOperations.exchange(request, customUserType); } catch (RestClientException ex) { OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, "An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); } - - OAuth2User oauth2User = response.getBody(); - - return oauth2User; } /** diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java index 09dd2f4045..78d797d38f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java @@ -87,7 +87,6 @@ public class DefaultOAuth2UserService implements OAuth2UserService request = this.requestEntityConverter.convert(userRequest); + ResponseEntity> response = getResponse(userRequest, request); + Map userAttributes = response.getBody(); + Set authorities = new LinkedHashSet<>(); + authorities.add(new OAuth2UserAuthority(userAttributes)); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String authority : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); + } + return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName); + } - ResponseEntity> response; + private ResponseEntity> getResponse(OAuth2UserRequest userRequest, RequestEntity request) { try { - response = this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE); + return this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE); } catch (OAuth2AuthorizationException ex) { OAuth2Error oauth2Error = ex.getError(); @@ -145,16 +153,6 @@ public class DefaultOAuth2UserService implements OAuth2UserService userAttributes = response.getBody(); - Set authorities = new LinkedHashSet<>(); - authorities.add(new OAuth2UserAuthority(userAttributes)); - OAuth2AccessToken token = userRequest.getAccessToken(); - for (String authority : token.getScopes()) { - authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); - } - - return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName); } /** diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java index 9d7212c9e7..cd195f29b6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java @@ -74,13 +74,18 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi private static final String MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE = "missing_user_name_attribute"; + private static final ParameterizedTypeReference> STRING_OBJECT_MAP = new ParameterizedTypeReference>() { + }; + + private static final ParameterizedTypeReference> STRING_STRING_MAP = new ParameterizedTypeReference>() { + }; + private WebClient webClient = WebClient.create(); @Override public Mono loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException { return Mono.defer(() -> { Assert.notNull(userRequest, "userRequest cannot be null"); - String userInfoUri = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint() .getUri(); if (!StringUtils.hasText(userInfoUri)) { @@ -99,32 +104,17 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - - ParameterizedTypeReference> typeReference = new ParameterizedTypeReference>() { - }; - AuthenticationMethod authenticationMethod = userRequest.getClientRegistration().getProviderDetails() .getUserInfoEndpoint().getAuthenticationMethod(); - WebClient.RequestHeadersSpec requestHeadersSpec; - if (AuthenticationMethod.FORM.equals(authenticationMethod)) { - requestHeadersSpec = this.webClient.post().uri(userInfoUri) - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) - .syncBody("access_token=" + userRequest.getAccessToken().getTokenValue()); - } - else { - requestHeadersSpec = this.webClient.get().uri(userInfoUri) - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .headers((headers) -> headers.setBearerAuth(userRequest.getAccessToken().getTokenValue())); - } + WebClient.RequestHeadersSpec requestHeadersSpec = getRequestHeaderSpec(userRequest, userInfoUri, + authenticationMethod); Mono> userAttributes = requestHeadersSpec.retrieve() .onStatus((s) -> s != HttpStatus.OK, (response) -> parse(response).map((userInfoErrorResponse) -> { String description = userInfoErrorResponse.getErrorObject().getDescription(); OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, description, null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - })).bodyToMono(typeReference); - + })).bodyToMono(DefaultReactiveOAuth2UserService.STRING_OBJECT_MAP); return userAttributes.map((attrs) -> { GrantedAuthority authority = new OAuth2UserAuthority(attrs); Set authorities = new HashSet<>(); @@ -136,13 +126,13 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi return new DefaultOAuth2User(authorities, attrs, userNameAttributeName); }).onErrorMap(IOException.class, - (e) -> new AuthenticationServiceException("Unable to access the userInfoEndpoint " + userInfoUri, - e)) - .onErrorMap(UnsupportedMediaTypeException.class, (e) -> { + (ex) -> new AuthenticationServiceException("Unable to access the userInfoEndpoint " + userInfoUri, + ex)) + .onErrorMap(UnsupportedMediaTypeException.class, (ex) -> { String errorMessage = "An error occurred while attempting to retrieve the UserInfo Resource from '" + userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint() .getUri() - + "': response contains invalid content type '" + e.getContentType().toString() + "'. " + + "': response contains invalid content type '" + ex.getContentType().toString() + "'. " + "The UserInfo Response should return a JSON object (content type 'application/json') " + "that contains a collection of name and value pairs of the claims about the authenticated End-User. " + "Please ensure the UserInfo Uri in UserInfoEndpoint for Client Registration '" @@ -151,7 +141,7 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi + "as defined in OpenID Connect 1.0: 'https://openid.net/specs/openid-connect-core-1_0.html#UserInfo'"; OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, errorMessage, null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), e); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); }).onErrorMap((t) -> !(t instanceof AuthenticationServiceException), (t) -> { OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, "An error occurred reading the UserInfo Success response: " + t.getMessage(), null); @@ -160,6 +150,17 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi }); } + private WebClient.RequestHeadersSpec getRequestHeaderSpec(OAuth2UserRequest userRequest, String userInfoUri, + AuthenticationMethod authenticationMethod) { + if (AuthenticationMethod.FORM.equals(authenticationMethod)) { + return this.webClient.post().uri(userInfoUri).header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + .syncBody("access_token=" + userRequest.getAccessToken().getTokenValue()); + } + return this.webClient.get().uri(userInfoUri).header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) + .headers((headers) -> headers.setBearerAuth(userRequest.getAccessToken().getTokenValue())); + } + /** * Sets the {@link WebClient} used for retrieving the user endpoint * @param webClient the client to use @@ -170,18 +171,13 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi } private static Mono parse(ClientResponse httpResponse) { - String wwwAuth = httpResponse.headers().asHttpHeaders().getFirst(HttpHeaders.WWW_AUTHENTICATE); - if (!StringUtils.isEmpty(wwwAuth)) { // Bearer token error? return Mono.fromCallable(() -> UserInfoErrorResponse.parse(wwwAuth)); } - - ParameterizedTypeReference> typeReference = new ParameterizedTypeReference>() { - }; // Other error? - return httpResponse.bodyToMono(typeReference) + return httpResponse.bodyToMono(STRING_STRING_MAP) .map((body) -> new UserInfoErrorResponse(ErrorObject.parse(new JSONObject(body)))); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java index a27dc7b37e..9a7a3c8dd8 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java @@ -54,12 +54,7 @@ public class OAuth2UserRequestEntityConverter implements Converter convert(OAuth2UserRequest userRequest) { ClientRegistration clientRegistration = userRequest.getClientRegistration(); - - HttpMethod httpMethod = HttpMethod.GET; - if (AuthenticationMethod.FORM - .equals(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod())) { - httpMethod = HttpMethod.POST; - } + HttpMethod httpMethod = getHttpMethod(clientRegistration); HttpHeaders headers = new HttpHeaders(); headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); URI uri = UriComponentsBuilder @@ -80,4 +75,12 @@ public class OAuth2UserRequestEntityConverter implements Converter attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); - - OAuth2AuthorizationRequest.Builder builder; - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { - builder = OAuth2AuthorizationRequest.authorizationCode(); - Map additionalParameters = new HashMap<>(); - if (!CollectionUtils.isEmpty(clientRegistration.getScopes()) - && clientRegistration.getScopes().contains(OidcScopes.OPENID)) { - // Section 3.1.2.1 Authentication Request - - // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // scope - // REQUIRED. OpenID Connect requests MUST contain the "openid" scope - // value. - addNonceParameters(attributes, additionalParameters); - } - if (ClientAuthenticationMethod.NONE.equals(clientRegistration.getClientAuthenticationMethod())) { - addPkceParameters(attributes, additionalParameters); - } - builder.additionalParameters(additionalParameters); - } - else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) { - builder = OAuth2AuthorizationRequest.implicit(); - } - else { - throw new IllegalArgumentException( - "Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue() - + ") for Client Registration with Id: " + clientRegistration.getRegistrationId()); - } + OAuth2AuthorizationRequest.Builder builder = getBuilder(clientRegistration, attributes); String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction); @@ -191,6 +163,33 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au return builder.build(); } + private OAuth2AuthorizationRequest.Builder getBuilder(ClientRegistration clientRegistration, + Map attributes) { + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { + OAuth2AuthorizationRequest.Builder builder = OAuth2AuthorizationRequest.authorizationCode(); + Map additionalParameters = new HashMap<>(); + if (!CollectionUtils.isEmpty(clientRegistration.getScopes()) + && clientRegistration.getScopes().contains(OidcScopes.OPENID)) { + // Section 3.1.2.1 Authentication Request - + // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest scope + // REQUIRED. OpenID Connect requests MUST contain the "openid" scope + // value. + addNonceParameters(attributes, additionalParameters); + } + if (ClientAuthenticationMethod.NONE.equals(clientRegistration.getClientAuthenticationMethod())) { + addPkceParameters(attributes, additionalParameters); + } + builder.additionalParameters(additionalParameters); + return builder; + } + if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) { + return OAuth2AuthorizationRequest.implicit(); + } + throw new IllegalArgumentException( + "Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue() + + ") for Client Registration with Id: " + clientRegistration.getRegistrationId()); + } + private String resolveRegistrationId(HttpServletRequest request) { if (this.authorizationRequestMatcher.matches(request)) { return this.authorizationRequestMatcher.matcher(request).getVariables() @@ -220,7 +219,6 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au String action) { Map uriVariables = new HashMap<>(); uriVariables.put("registrationId", clientRegistration.getRegistrationId()); - UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) .replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build(); String scheme = uriComponents.getScheme(); @@ -238,9 +236,7 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au } uriVariables.put("basePath", (path != null) ? path : ""); uriVariables.put("baseUrl", uriComponents.toUriString()); - uriVariables.put("action", (action != null) ? action : ""); - return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUri()).buildAndExpand(uriVariables) .toUriString(); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java index 6d21a1a243..5f4b3111fe 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -131,16 +131,13 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori @Override public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - String clientRegistrationId = authorizeRequest.getClientRegistrationId(); OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient(); Authentication principal = authorizeRequest.getPrincipal(); - HttpServletRequest servletRequest = getHttpServletRequestOrDefault(authorizeRequest.getAttributes()); Assert.notNull(servletRequest, "servletRequest cannot be null"); HttpServletResponse servletResponse = getHttpServletResponseOrDefault(authorizeRequest.getAttributes()); Assert.notNull(servletResponse, "servletResponse cannot be null"); - OAuth2AuthorizationContext.Builder contextBuilder; if (authorizedClient != null) { contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); @@ -166,7 +163,6 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori attributes.putAll(contextAttributes); } }).build(); - try { authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); } @@ -175,7 +171,6 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori createAttributes(servletRequest, servletResponse)); throw ex; } - if (authorizedClient != null) { this.authorizationSuccessHandler.onAuthorizationSuccess(authorizedClient, principal, createAttributes(servletRequest, servletResponse)); @@ -189,7 +184,6 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori return authorizationContext.getAuthorizedClient(); } } - return authorizedClient; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java index 7ee459ad6a..69dbf4e9bc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java @@ -136,10 +136,8 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React @Override public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - String clientRegistrationId = authorizeRequest.getClientRegistrationId(); Authentication principal = authorizeRequest.getPrincipal(); - return Mono.justOrEmpty(authorizeRequest.getAttribute(ServerWebExchange.class.getName())) .switchIfEmpty(currentServerWebExchangeMono) .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("serverWebExchange cannot be null"))) @@ -183,7 +181,6 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React */ private Mono authorize(OAuth2AuthorizationContext authorizationContext, Authentication principal, ServerWebExchange serverWebExchange) { - return this.authorizedClientProvider.authorize(authorizationContext) // Delegate to the authorizationSuccessHandler of the successful // authorization diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java index f86c761178..f8bc3b13b2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java @@ -161,12 +161,10 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - if (matchesAuthorizationResponse(request)) { processAuthorizationResponse(request, response); return; } - filterChain.doFilter(request, response); } @@ -180,7 +178,6 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { if (authorizationRequest == null) { return false; } - // Compare redirect_uri UriComponents requestUri = UriComponentsBuilder.fromUriString(UrlUtils.buildFullRequestUrl(request)).build(); UriComponents redirectUri = UriComponentsBuilder.fromUriString(authorizationRequest.getRedirectUri()).build(); @@ -193,7 +190,6 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { // before doing an exact comparison with the authorizationRequest.getRedirectUri() // parameters (if any) requestUriParameters.retainAll(redirectUriParameters); - if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) && Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) && Objects.equals(requestUri.getHost(), redirectUri.getHost()) @@ -207,24 +203,18 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { private void processAuthorizationResponse(HttpServletRequest request, HttpServletResponse response) throws IOException { - OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository .removeAuthorizationRequest(request, response); - String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID); ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); - MultiValueMap params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap()); String redirectUri = UrlUtils.buildFullRequestUrl(request); OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri); - OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken( clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); authenticationRequest.setDetails(this.authenticationDetailsSource.buildDetails(request)); - OAuth2AuthorizationCodeAuthenticationToken authenticationResult; - try { authenticationResult = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationManager .authenticate(authenticationRequest); @@ -242,24 +232,19 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { this.redirectStrategy.sendRedirect(request, response, uriBuilder.build().encode().toString()); return; } - Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication(); String principalName = (currentAuthentication != null) ? currentAuthentication.getName() : "anonymousUser"; - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( authenticationResult.getClientRegistration(), principalName, authenticationResult.getAccessToken(), authenticationResult.getRefreshToken()); - this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, response); - String redirectUrl = authorizationRequest.getRedirectUri(); SavedRequest savedRequest = this.requestCache.getRequest(request, response); if (savedRequest != null) { redirectUrl = savedRequest.getRedirectUrl(); this.requestCache.removeRequest(request, response); } - this.redirectStrategy.sendRedirect(request, response, redirectUrl); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java index 670ca8006d..2e9c62c993 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java @@ -23,6 +23,7 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.core.log.LogMessage; import org.springframework.http.HttpStatus; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -162,7 +163,6 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - try { OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request); if (authorizationRequest != null) { @@ -170,11 +170,10 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt return; } } - catch (Exception failed) { - this.unsuccessfulRedirectForAuthorization(request, response, failed); + catch (Exception ex) { + this.unsuccessfulRedirectForAuthorization(request, response, ex); return; } - try { filterChain.doFilter(request, response); } @@ -201,22 +200,18 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt } return; } - if (ex instanceof ServletException) { throw (ServletException) ex; } - else if (ex instanceof RuntimeException) { + if (ex instanceof RuntimeException) { throw (RuntimeException) ex; } - else { - throw new RuntimeException(ex); - } + throw new RuntimeException(ex); } } private void sendRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response, OAuth2AuthorizationRequest authorizationRequest) throws IOException { - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationRequest.getGrantType())) { this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); } @@ -225,11 +220,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt } private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response, - Exception failed) throws IOException { - - if (this.logger.isErrorEnabled()) { - this.logger.error("Authorization Request failed: " + failed.toString(), failed); - } + Exception ex) throws IOException { + this.logger.error(LogMessage.format("Authorization Request failed: %s", ex, ex)); response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase()); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java index b4020ee9e3..8fdbf17b0f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java @@ -66,16 +66,13 @@ final class OAuth2AuthorizationResponseUtils { String code = request.getFirst(OAuth2ParameterNames.CODE); String errorCode = request.getFirst(OAuth2ParameterNames.ERROR); String state = request.getFirst(OAuth2ParameterNames.STATE); - if (StringUtils.hasText(code)) { return OAuth2AuthorizationResponse.success(code).redirectUri(redirectUri).state(state).build(); } - else { - String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION); - String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI); - return OAuth2AuthorizationResponse.error(errorCode).redirectUri(redirectUri) - .errorDescription(errorDescription).errorUri(errorUri).state(state).build(); - } + String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION); + String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI); + return OAuth2AuthorizationResponse.error(errorCode).redirectUri(redirectUri).errorDescription(errorDescription) + .errorUri(errorUri).state(state).build(); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java index 24103d8097..2583b1a9a2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java @@ -158,20 +158,17 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce @Override public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException { - MultiValueMap params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap()); if (!OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) { OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository .removeAuthorizationRequest(request, response); if (authorizationRequest == null) { OAuth2Error oauth2Error = new OAuth2Error(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID); ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); if (clientRegistration == null) { @@ -183,26 +180,21 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce .build().toUriString(); OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri); - Object authenticationDetails = this.authenticationDetailsSource.buildDetails(request); OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken(clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); authenticationRequest.setDetails(authenticationDetails); - OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) this .getAuthenticationManager().authenticate(authenticationRequest); - OAuth2AuthenticationToken oauth2Authentication = new OAuth2AuthenticationToken( authenticationResult.getPrincipal(), authenticationResult.getAuthorities(), authenticationResult.getClientRegistration().getRegistrationId()); oauth2Authentication.setDetails(authenticationDetails); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( authenticationResult.getClientRegistration(), oauth2Authentication.getName(), authenticationResult.getAccessToken(), authenticationResult.getRefreshToken()); this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, oauth2Authentication, request, response); - return oauth2Authentication; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 4a5680a752..66f81abc45 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -114,46 +114,38 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth @Override public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer, NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) { - String clientRegistrationId = this.resolveClientRegistrationId(parameter); if (StringUtils.isEmpty(clientRegistrationId)) { throw new IllegalArgumentException("Unable to resolve the Client Registration Identifier. " + "It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or " + "@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); } - Authentication principal = SecurityContextHolder.getContext().getAuthentication(); if (principal == null) { principal = ANONYMOUS_AUTHENTICATION; } HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId) .principal(principal).attribute(HttpServletRequest.class.getName(), servletRequest) .attribute(HttpServletResponse.class.getName(), servletResponse).build(); - return this.authorizedClientManager.authorize(authorizeRequest); } private String resolveClientRegistrationId(MethodParameter parameter) { RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils .findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class); - Authentication principal = SecurityContextHolder.getContext().getAuthentication(); - - String clientRegistrationId = null; if (!StringUtils.isEmpty(authorizedClientAnnotation.registrationId())) { - clientRegistrationId = authorizedClientAnnotation.registrationId(); + return authorizedClientAnnotation.registrationId(); } - else if (!StringUtils.isEmpty(authorizedClientAnnotation.value())) { - clientRegistrationId = authorizedClientAnnotation.value(); + if (!StringUtils.isEmpty(authorizedClientAnnotation.value())) { + return authorizedClientAnnotation.value(); } - else if (principal != null && OAuth2AuthenticationToken.class.isAssignableFrom(principal.getClass())) { - clientRegistrationId = ((OAuth2AuthenticationToken) principal).getAuthorizedClientRegistrationId(); + if (principal != null && OAuth2AuthenticationToken.class.isAssignableFrom(principal.getClass())) { + return ((OAuth2AuthenticationToken) principal).getAuthorizedClientRegistrationId(); } - - return clientRegistrationId; + return null; } /** @@ -184,7 +176,6 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth private void updateDefaultAuthorizedClientManager( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { - OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode().refreshToken() .clientCredentials( diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index 6e9b5f09f4..8cf8ae34b9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -206,12 +206,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements public ServerOAuth2AuthorizedClientExchangeFilterFunction( ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { - ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler = new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler( (clientRegistrationId, principal, attributes) -> authorizedClientRepository.removeAuthorizedClient( clientRegistrationId, principal, (ServerWebExchange) attributes.get(ServerWebExchange.class.getName()))); - this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository, authorizationFailureHandler); this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); @@ -222,7 +220,6 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository, ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { - // gh-7544 if (authorizedClientRepository instanceof UnAuthenticatedServerOAuth2AuthorizedClientRepository) { UnAuthenticatedReactiveOAuth2AuthorizedClientManager unauthenticatedAuthorizedClientManager = new UnAuthenticatedReactiveOAuth2AuthorizedClientManager( @@ -234,11 +231,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements .authorizationCode().refreshToken().clientCredentials().password().build()); return unauthenticatedAuthorizedClientManager; } - DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( clientRegistrationRepository, authorizedClientRepository); authorizedClientManager.setAuthorizationFailureHandler(authorizationFailureHandler); - return authorizedClientManager; } @@ -444,9 +439,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private Mono authorizeRequest(ClientRequest request) { Mono clientRegistrationId = effectiveClientRegistrationId(request); - Mono> serverWebExchange = effectiveServerWebExchange(request); - return Mono.zip(clientRegistrationId, this.currentAuthenticationMono, serverWebExchange).map((t3) -> { OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(t3.getT1()) .principal(t3.getT2()); @@ -488,7 +481,6 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private Mono reauthorizeRequest(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { Mono> serverWebExchange = effectiveServerWebExchange(request); - return Mono.zip(this.currentAuthenticationMono, serverWebExchange).map((t2) -> { OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient) .principal(t2.getT1()); @@ -561,27 +553,39 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @Override public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - String clientRegistrationId = authorizeRequest.getClientRegistrationId(); Authentication principal = authorizeRequest.getPrincipal(); + return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) + .switchIfEmpty(loadAuthorizedClient(clientRegistrationId, principal)) + .flatMap((authorizedClient) -> reauthorize(authorizedClient, authorizeRequest, principal)) + .switchIfEmpty(findAndAuthorize(clientRegistrationId, principal)); + } - return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()).switchIfEmpty(Mono.defer( - () -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, null))) - .flatMap((authorizedClient) -> // Re-authorize - Mono.just(OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient).principal(principal) - .build()).flatMap((authorizationContext) -> authorize(authorizationContext, principal)) - // Default to the existing authorizedClient if the client - // was not re-authorized - .defaultIfEmpty((authorizeRequest.getAuthorizedClient() != null) - ? authorizeRequest.getAuthorizedClient() : authorizedClient)) - .switchIfEmpty(Mono.defer(() -> - // Authorize - this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( - "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) - .flatMap((clientRegistration) -> Mono.just(OAuth2AuthorizationContext - .withClientRegistration(clientRegistration).principal(principal).build())) - .flatMap((authorizationContext) -> authorize(authorizationContext, principal)))); + private Mono loadAuthorizedClient(String clientRegistrationId, + Authentication principal) { + return Mono.defer( + () -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, null)); + } + + private Mono reauthorize(OAuth2AuthorizedClient authorizedClient, + OAuth2AuthorizeRequest authorizeRequest, Authentication principal) { + return Mono + .just(OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient).principal(principal) + .build()) + .flatMap((authorizationContext) -> authorize(authorizationContext, principal)) + // Default to the existing authorizedClient if the client was not + // re-authorized + .defaultIfEmpty((authorizeRequest.getAuthorizedClient() != null) + ? authorizeRequest.getAuthorizedClient() : authorizedClient); + } + + private Mono findAndAuthorize(String clientRegistrationId, Authentication principal) { + return Mono.defer(() -> this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( + "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) + .flatMap((clientRegistration) -> Mono.just(OAuth2AuthorizationContext + .withClientRegistration(clientRegistration).principal(principal).build())) + .flatMap((authorizationContext) -> authorize(authorizationContext, principal))); } /** @@ -597,7 +601,6 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements */ private Mono authorize(OAuth2AuthorizationContext authorizationContext, Authentication principal) { - return this.authorizedClientProvider.authorize(authorizationContext) // Delegates to the authorizationSuccessHandler of the successful // authorization @@ -642,7 +645,6 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private AuthorizationFailureForwarder(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); this.authorizationFailureHandler = authorizationFailureHandler; - Map httpStatusToOAuth2Error = new HashMap<>(); httpStatusToOAuth2Error.put(HttpStatus.UNAUTHORIZED.value(), OAuth2ErrorCodes.INVALID_TOKEN); httpStatusToOAuth2Error.put(HttpStatus.FORBIDDEN.value(), OAuth2ErrorCodes.INSUFFICIENT_SCOPE); @@ -661,17 +663,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private Mono handleResponse(ClientRequest request, ClientResponse response) { return Mono.justOrEmpty(resolveErrorIfPossible(response)).flatMap((oauth2Error) -> { Mono> serverWebExchange = effectiveServerWebExchange(request); - Mono clientRegistrationId = effectiveClientRegistrationId(request); - return Mono .zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono, serverWebExchange, clientRegistrationId) - .flatMap((tuple3) -> handleAuthorizationFailure(tuple3.getT1(), // Authentication - // principal - tuple3.getT2().orElse(null), // ServerWebExchange exchange - new ClientAuthorizationException(oauth2Error, tuple3.getT3()))); // String - // clientRegistrationId + .flatMap((zipped) -> handleAuthorizationFailure(zipped.getT1(), zipped.getT2(), + new ClientAuthorizationException(oauth2Error, zipped.getT3()))); }); } @@ -720,18 +717,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements WebClientResponseException exception) { return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode())).flatMap((oauth2Error) -> { Mono> serverWebExchange = effectiveServerWebExchange(request); - Mono clientRegistrationId = effectiveClientRegistrationId(request); - return Mono .zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono, serverWebExchange, clientRegistrationId) - .flatMap((tuple3) -> handleAuthorizationFailure(tuple3.getT1(), // Authentication - // principal - tuple3.getT2().orElse(null), // ServerWebExchange exchange - new ClientAuthorizationException(oauth2Error, tuple3.getT3(), // String - // clientRegistrationId - exception))); + .flatMap((zipped) -> handleAuthorizationFailure(zipped.getT1(), zipped.getT2(), + new ClientAuthorizationException(oauth2Error, zipped.getT3(), exception))); }); } @@ -745,14 +736,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements */ private Mono handleAuthorizationException(ClientRequest request, OAuth2AuthorizationException exception) { Mono> serverWebExchange = effectiveServerWebExchange(request); - - return Mono.zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono, - serverWebExchange).flatMap( - (tuple2) -> handleAuthorizationFailure(tuple2.getT1(), // Authentication - // principal - tuple2.getT2().orElse(null), // ServerWebExchange - // exchange - exception)); + return Mono + .zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono, + serverWebExchange) + .flatMap((zipped) -> handleAuthorizationFailure(zipped.getT1(), zipped.getT2(), exception)); } /** @@ -763,11 +750,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements * @return a {@link Mono} that completes empty after the authorization failure * handler completes. */ - private Mono handleAuthorizationFailure(Authentication principal, ServerWebExchange exchange, + private Mono handleAuthorizationFailure(Authentication principal, Optional exchange, OAuth2AuthorizationException exception) { - return this.authorizationFailureHandler.onAuthorizationFailure(exception, principal, - createAttributes(exchange)); + createAttributes(exchange.orElse(null))); } private Map createAttributes(ServerWebExchange exchange) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index aa6a237fdb..4fe3bf5fa3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -218,12 +218,9 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement public ServletOAuth2AuthorizedClientExchangeFilterFunction( ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { - OAuth2AuthorizationFailureHandler authorizationFailureHandler = new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler( - (clientRegistrationId, principal, attributes) -> authorizedClientRepository.removeAuthorizedClient( - clientRegistrationId, principal, - (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()), - (HttpServletResponse) attributes.get(HttpServletResponse.class.getName()))); + (clientRegistrationId, principal, attributes) -> removeAuthorizedClient(authorizedClientRepository, + clientRegistrationId, principal, attributes)); DefaultOAuth2AuthorizedClientManager defaultAuthorizedClientManager = new DefaultOAuth2AuthorizedClientManager( clientRegistrationRepository, authorizedClientRepository); defaultAuthorizedClientManager.setAuthorizationFailureHandler(authorizationFailureHandler); @@ -232,6 +229,13 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); } + private void removeAuthorizedClient(OAuth2AuthorizedClientRepository authorizedClientRepository, + String clientRegistrationId, Authentication principal, Map attributes) { + HttpServletRequest request = getRequest(attributes); + HttpServletResponse response = getResponse(attributes); + authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, request, response); + } + /** * Sets the {@link OAuth2AccessTokenResponseClient} used for getting an * {@link OAuth2AuthorizedClient} for the client_credentials grant. @@ -453,9 +457,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement || !request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) { return mergeRequestAttributesFromContext(request); } - else { - return Mono.just(request); - } + return Mono.just(request); } private Mono mergeRequestAttributesFromContext(ClientRequest request) { @@ -530,23 +532,13 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement } HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId) .principal(authentication); - builder.attributes((attributes) -> { - if (servletRequest != null) { - attributes.put(HttpServletRequest.class.getName(), servletRequest); - } - if (servletResponse != null) { - attributes.put(HttpServletResponse.class.getName(), servletResponse); - } - }); + builder.attributes((attributes) -> addToAttributes(attributes, servletRequest, servletResponse)); OAuth2AuthorizeRequest authorizeRequest = builder.build(); - - // NOTE: - // 'authorizedClientManager.authorize()' needs to be executed - // on a dedicated thread via subscribeOn(Schedulers.boundedElastic()) - // since it performs a blocking I/O operation using RestTemplate internally + // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated + // thread via subscribeOn(Schedulers.boundedElastic()) since it performs a + // blocking I/O operation using RestTemplate internally return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(authorizeRequest)) .subscribeOn(Schedulers.boundedElastic()); } @@ -563,27 +555,27 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement } HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient) .principal(authentication); - builder.attributes((attributes) -> { - if (servletRequest != null) { - attributes.put(HttpServletRequest.class.getName(), servletRequest); - } - if (servletResponse != null) { - attributes.put(HttpServletResponse.class.getName(), servletResponse); - } - }); + builder.attributes((attributes) -> addToAttributes(attributes, servletRequest, servletResponse)); OAuth2AuthorizeRequest reauthorizeRequest = builder.build(); - - // NOTE: - // 'authorizedClientManager.authorize()' needs to be executed - // on a dedicated thread via subscribeOn(Schedulers.boundedElastic()) - // since it performs a blocking I/O operation using RestTemplate internally + // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated + // thread via subscribeOn(Schedulers.boundedElastic()) since it performs a + // blocking I/O operation using RestTemplate internally return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) .subscribeOn(Schedulers.boundedElastic()); } + private void addToAttributes(Map attributes, HttpServletRequest servletRequest, + HttpServletResponse servletResponse) { + if (servletRequest != null) { + attributes.put(HTTP_SERVLET_REQUEST_ATTR_NAME, servletRequest); + } + if (servletResponse != null) { + attributes.put(HTTP_SERVLET_RESPONSE_ATTR_NAME, servletResponse); + } + } + private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { return ClientRequest.from(request) .headers((headers) -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue())) @@ -612,8 +604,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private static Authentication createAuthentication(final String principalName) { Assert.hasText(principalName, "principalName cannot be empty"); - return new AbstractAuthenticationToken(null) { + @Override public Object getCredentials() { return ""; @@ -656,7 +648,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private AuthorizationFailureForwarder(OAuth2AuthorizationFailureHandler authorizationFailureHandler) { Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); this.authorizationFailureHandler = authorizationFailureHandler; - Map httpStatusToOAuth2Error = new HashMap<>(); httpStatusToOAuth2Error.put(HttpStatus.UNAUTHORIZED.value(), OAuth2ErrorCodes.INVALID_TOKEN); httpStatusToOAuth2Error.put(HttpStatus.FORBIDDEN.value(), OAuth2ErrorCodes.INSUFFICIENT_SCOPE); @@ -679,14 +670,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement if (authorizedClient == null) { return Mono.empty(); } - ClientAuthorizationException authorizationException = new ClientAuthorizationException(oauth2Error, authorizedClient.getClientRegistration().getRegistrationId()); - Authentication principal = createAuthentication(authorizedClient.getPrincipalName()); HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); }); } @@ -740,14 +728,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement if (authorizedClient == null) { return Mono.empty(); } - ClientAuthorizationException authorizationException = new ClientAuthorizationException(oauth2Error, authorizedClient.getClientRegistration().getRegistrationId(), exception); - Authentication principal = createAuthentication(authorizedClient.getPrincipalName()); HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); }); } @@ -769,11 +754,9 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement if (authorizedClient == null) { return Mono.empty(); } - Authentication principal = createAuthentication(authorizedClient.getPrincipalName()); HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); }); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index b014eb2dd6..eed58f7918 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -105,28 +105,24 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth return Mono.defer(() -> { RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils .findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class); - String clientRegistrationId = StringUtils.hasLength(authorizedClientAnnotation.registrationId()) ? authorizedClientAnnotation.registrationId() : null; - return authorizeRequest(clientRegistrationId, exchange).flatMap(this.authorizedClientManager::authorize); }); } private Mono authorizeRequest(String registrationId, ServerWebExchange exchange) { Mono defaultedAuthentication = currentAuthentication(); - Mono defaultedRegistrationId = Mono.justOrEmpty(registrationId) .switchIfEmpty(clientRegistrationId(defaultedAuthentication)) .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( "The clientRegistrationId could not be resolved. Please provide one"))); - Mono defaultedExchange = Mono.justOrEmpty(exchange) .switchIfEmpty(currentServerWebExchange()); - return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange) - .map((t3) -> OAuth2AuthorizeRequest.withClientRegistrationId(t3.getT1()).principal(t3.getT2()) - .attribute(ServerWebExchange.class.getName(), t3.getT3()).build()); + .map((zipped) -> OAuth2AuthorizeRequest.withClientRegistrationId(zipped.getT1()) + .principal(zipped.getT2()).attribute(ServerWebExchange.class.getName(), zipped.getT3()) + .build()); } private Mono currentAuthentication() { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java index 0c6e56a17c..56c180e0f0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository.java @@ -83,10 +83,7 @@ public final class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository if (this.isPrincipalAuthenticated(principal)) { return this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()); } - else { - return this.anonymousAuthorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, - exchange); - } + return this.anonymousAuthorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange); } @Override @@ -95,9 +92,7 @@ public final class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository if (this.isPrincipalAuthenticated(principal)) { return this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); } - else { - return this.anonymousAuthorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange); - } + return this.anonymousAuthorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange); } @Override @@ -106,10 +101,8 @@ public final class AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository if (this.isPrincipalAuthenticated(principal)) { return this.authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName()); } - else { - return this.anonymousAuthorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, - exchange); - } + return this.anonymousAuthorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, + exchange); } private boolean isPrincipalAuthenticated(Authentication authentication) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java index af0df46d5c..e5a9e60362 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java @@ -153,13 +153,23 @@ public class DefaultServerOAuth2AuthorizationRequestResolver implements ServerOA private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchange, ClientRegistration clientRegistration) { String redirectUriStr = expandRedirectUri(exchange.getRequest(), clientRegistration); - Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); + OAuth2AuthorizationRequest.Builder builder = getBuilder(clientRegistration, attributes); + builder.clientId(clientRegistration.getClientId()) + .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) + .redirectUri(redirectUriStr).scopes(clientRegistration.getScopes()) + .state(this.stateGenerator.generateKey()).attributes(attributes); - OAuth2AuthorizationRequest.Builder builder; + this.authorizationRequestCustomizer.accept(builder); + + return builder.build(); + } + + private OAuth2AuthorizationRequest.Builder getBuilder(ClientRegistration clientRegistration, + Map attributes) { if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { - builder = OAuth2AuthorizationRequest.authorizationCode(); + OAuth2AuthorizationRequest.Builder builder = OAuth2AuthorizationRequest.authorizationCode(); Map additionalParameters = new HashMap<>(); if (!CollectionUtils.isEmpty(clientRegistration.getScopes()) && clientRegistration.getScopes().contains(OidcScopes.OPENID)) { @@ -174,23 +184,14 @@ public class DefaultServerOAuth2AuthorizationRequestResolver implements ServerOA addPkceParameters(attributes, additionalParameters); } builder.additionalParameters(additionalParameters); + return builder; } - else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) { - builder = OAuth2AuthorizationRequest.implicit(); + if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) { + return OAuth2AuthorizationRequest.implicit(); } - else { - throw new IllegalArgumentException( - "Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue() - + ") for Client Registration with Id: " + clientRegistration.getRegistrationId()); - } - builder.clientId(clientRegistration.getClientId()) - .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) - .redirectUri(redirectUriStr).scopes(clientRegistration.getScopes()) - .state(this.stateGenerator.generateKey()).attributes(attributes); - - this.authorizationRequestCustomizer.accept(builder); - - return builder.build(); + throw new IllegalArgumentException( + "Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue() + + ") for Client Registration with Id: " + clientRegistration.getRegistrationId()); } /** @@ -213,7 +214,6 @@ public class DefaultServerOAuth2AuthorizationRequestResolver implements ServerOA private static String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) { Map uriVariables = new HashMap<>(); uriVariables.put("registrationId", clientRegistration.getRegistrationId()); - UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI()) .replacePath(request.getPath().contextPath().value()).replaceQuery(null).fragment(null).build(); String scheme = uriComponents.getScheme(); @@ -231,13 +231,11 @@ public class DefaultServerOAuth2AuthorizationRequestResolver implements ServerOA } uriVariables.put("basePath", (path != null) ? path : ""); uriVariables.put("baseUrl", uriComponents.toUriString()); - String action = ""; if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { action = "login"; } uriVariables.put("action", action); - return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUri()).buildAndExpand(uriVariables) .toUriString(); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java index b7e68fbe8a..7406bb76fe 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java @@ -206,7 +206,7 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { .filter(ServerWebExchangeMatcher.MatchResult::isMatch) .flatMap((matchResult) -> this.authenticationConverter.convert(exchange).onErrorMap( OAuth2AuthorizationException.class, - (e) -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()))) + (ex) -> new OAuth2AuthenticationException(ex.getError(), ex.getError().toString()))) .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) .flatMap((token) -> authenticate(exchange, chain, token)) .onErrorResume(AuthenticationException.class, (e) -> this.authenticationFailureHandler @@ -217,7 +217,7 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain); return this.authenticationManager.authenticate(token) .onErrorMap(OAuth2AuthorizationException.class, - (e) -> new OAuth2AuthenticationException(e.getError(), e.getError().toString())) + (ex) -> new OAuth2AuthenticationException(ex.getError(), ex.getError().toString())) .switchIfEmpty(Mono.defer( () -> Mono.error(new IllegalStateException("No provider found for " + token.getClass())))) .flatMap((authentication) -> onAuthenticationSuccess(authentication, webFilterExchange)) @@ -258,7 +258,6 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter { // before doing an exact comparison with the authorizationRequest.getRedirectUri() // parameters (if any) requestUriParameters.retainAll(redirectUriParameters); - if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) && Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) && Objects.equals(requestUri.getHost(), redirectUri.getHost()) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java index b1ae82caaf..f9f539d608 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java @@ -130,8 +130,8 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter { return this.authorizationRequestResolver.resolve(exchange) .switchIfEmpty(chain.filter(exchange).then(Mono.empty())) .onErrorResume(ClientAuthorizationRequiredException.class, - (e) -> this.requestCache.saveRequest(exchange) - .then(this.authorizationRequestResolver.resolve(exchange, e.getClientRegistrationId()))) + (ex) -> this.requestCache.saveRequest(exchange).then( + this.authorizationRequestResolver.resolve(exchange, ex.getClientRegistrationId()))) .flatMap((clientRegistration) -> sendRedirectForAuthorization(exchange, clientRegistration)); } @@ -143,7 +143,6 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter { saveAuthorizationRequest = this.authorizationRequestRepository .saveAuthorizationRequest(authorizationRequest, exchange); } - URI redirectUri = UriComponentsBuilder.fromUriString(authorizationRequest.getAuthorizationRequestUri()) .build(true).toUri(); return saveAuthorizationRequest diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationResponseUtils.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationResponseUtils.java index 1d35f6b131..1c3d3f9810 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationResponseUtils.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationResponseUtils.java @@ -66,16 +66,13 @@ final class OAuth2AuthorizationResponseUtils { String code = request.getFirst(OAuth2ParameterNames.CODE); String errorCode = request.getFirst(OAuth2ParameterNames.ERROR); String state = request.getFirst(OAuth2ParameterNames.STATE); - if (StringUtils.hasText(code)) { return OAuth2AuthorizationResponse.success(code).redirectUri(redirectUri).state(state).build(); } - else { - String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION); - String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI); - return OAuth2AuthorizationResponse.error(errorCode).redirectUri(redirectUri) - .errorDescription(errorDescription).errorUri(errorUri).state(state).build(); - } + String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION); + String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI); + return OAuth2AuthorizationResponse.error(errorCode).redirectUri(redirectUri).errorDescription(errorDescription) + .errorUri(errorUri).state(state).build(); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepository.java index fe9c8e0c77..1073e32fa7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/UnAuthenticatedServerOAuth2AuthorizedClientRepository.java @@ -53,7 +53,6 @@ public class UnAuthenticatedServerOAuth2AuthorizedClientRepository implements Se Assert.notNull(clientRegistrationId, "clientRegistrationId cannot be null"); Assert.isNull(serverWebExchange, "serverWebExchange must be null"); Assert.isTrue(isUnauthenticated(authentication), "The user " + authentication + " should not be authenticated"); - return Mono.fromSupplier(() -> (T) this.clientRegistrationIdToAuthorizedClient.get(clientRegistrationId)); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java index b987ef0b2f..0ed7126a85 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java @@ -121,14 +121,11 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository private Mono> saveStateToAuthorizationRequest(ServerWebExchange exchange) { Assert.notNull(exchange, "exchange cannot be null"); - return getSessionAttributes(exchange).doOnNext((sessionAttrs) -> { Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName); - if (stateToAuthzRequest == null) { stateToAuthzRequest = new HashMap(); } - // No matter stateToAuthzRequest was in session or not, we should always put // it into session again // in case of redis or hazelcast session. #6215