diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationProvider.java index fa1c4627e7..3330281962 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationProvider.java @@ -20,6 +20,9 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy; import org.springframework.security.oauth2.client.user.DefaultOAuth2UserService; import org.springframework.security.oauth2.client.user.OAuth2UserService; import org.springframework.security.oauth2.core.user.OAuth2User; @@ -55,6 +58,7 @@ import java.util.Collection; * @see OidcUser */ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider { + private final ClientRegistrationIdentifierStrategy providerIdentifierStrategy = new ProviderIdentifierStrategy(); private final OAuth2UserService userService; private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); @@ -65,11 +69,18 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - OAuth2UserAuthenticationToken oauth2UserAuthentication = (OAuth2UserAuthenticationToken) authentication; + OAuth2UserAuthenticationToken userAuthentication = (OAuth2UserAuthenticationToken) authentication; + OAuth2ClientAuthenticationToken clientAuthentication = userAuthentication.getClientAuthentication(); - OAuth2ClientAuthenticationToken oauth2ClientAuthentication = oauth2UserAuthentication.getClientAuthentication(); + if (this.userAuthenticated() && this.userAuthenticatedSameProviderAs(clientAuthentication)) { + // Create a new user authentication (using same principal) + // but with a different client authentication association + return this.createUserAuthentication( + (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication(), + clientAuthentication); + } - OAuth2User oauth2User = this.userService.loadUser(oauth2ClientAuthentication); + OAuth2User oauth2User = this.userService.loadUser(clientAuthentication); Collection mappedAuthorities = this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); @@ -77,12 +88,12 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider OAuth2UserAuthenticationToken authenticationResult; if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) { authenticationResult = new OidcUserAuthenticationToken( - (OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)oauth2ClientAuthentication); + (OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)clientAuthentication); } else { authenticationResult = new OAuth2UserAuthenticationToken( - oauth2User, mappedAuthorities, oauth2ClientAuthentication); + oauth2User, mappedAuthorities, clientAuthentication); } - authenticationResult.setDetails(oauth2ClientAuthentication.getDetails()); + authenticationResult.setDetails(clientAuthentication.getDetails()); return authenticationResult; } @@ -96,4 +107,52 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider public boolean supports(Class authentication) { return OAuth2UserAuthenticationToken.class.isAssignableFrom(authentication); } + + private boolean userAuthenticated() { + Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication(); + return currentAuthentication != null && + currentAuthentication instanceof OAuth2UserAuthenticationToken && + currentAuthentication.isAuthenticated(); + } + + private boolean userAuthenticatedSameProviderAs(OAuth2ClientAuthenticationToken clientAuthentication) { + OAuth2UserAuthenticationToken currentUserAuthentication = + (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication(); + + String userProviderId = this.providerIdentifierStrategy.getIdentifier( + currentUserAuthentication.getClientAuthentication().getClientRegistration()); + String clientProviderId = this.providerIdentifierStrategy.getIdentifier( + clientAuthentication.getClientRegistration()); + + return userProviderId.equals(clientProviderId); + } + + private OAuth2UserAuthenticationToken createUserAuthentication( + OAuth2UserAuthenticationToken currentUserAuthentication, + OAuth2ClientAuthenticationToken newClientAuthentication) { + + if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) { + return new OidcUserAuthenticationToken( + (OidcUser) currentUserAuthentication.getPrincipal(), + currentUserAuthentication.getAuthorities(), + newClientAuthentication); + } else { + return new OAuth2UserAuthenticationToken( + (OAuth2User)currentUserAuthentication.getPrincipal(), + currentUserAuthentication.getAuthorities(), + newClientAuthentication); + } + } + + private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy { + + @Override + public String getIdentifier(ClientRegistration clientRegistration) { + StringBuilder builder = new StringBuilder(); + builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]"); + builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]"); + builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]"); + return builder.toString(); + } + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilter.java index 46a7fe1d10..b3aff79014 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilter.java @@ -18,23 +18,17 @@ package org.springframework.security.oauth2.client.web; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationProvider; import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException; import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter; -import org.springframework.security.oauth2.core.user.OAuth2User; -import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken; -import org.springframework.security.oauth2.oidc.client.authentication.OidcUserAuthenticationToken; -import org.springframework.security.oauth2.oidc.core.user.OidcUser; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -81,7 +75,6 @@ import java.io.IOException; public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticationProcessingFilter { public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code"; private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found"; - private final ClientRegistrationIdentifierStrategy providerIdentifierStrategy = new ProviderIdentifierStrategy(); private AuthorizationResponseMatcher authorizationResponseMatcher; private ClientRegistrationRepository clientRegistrationRepository; private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository(); @@ -135,20 +128,8 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio OAuth2ClientAuthenticationToken oauth2ClientAuthentication = (OAuth2ClientAuthenticationToken)this.getAuthenticationManager().authenticate(authorizationCodeAuthentication); - OAuth2UserAuthenticationToken oauth2UserAuthentication; - if (this.authenticated() && this.authenticatedSameProviderAs(oauth2ClientAuthentication)) { - // Create a new user authentication (using same principal) - // but with a different client authentication association - oauth2UserAuthentication = (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication(); - oauth2UserAuthentication = this.createUserAuthentication(oauth2UserAuthentication, oauth2ClientAuthentication); - } else { - // Authenticate the user... the user needs to be authenticated - // before we can associate the client authentication to the user - oauth2UserAuthentication = (OAuth2UserAuthenticationToken)this.getAuthenticationManager().authenticate( - this.createUserAuthentication(oauth2ClientAuthentication)); - } - - return oauth2UserAuthentication; + return this.getAuthenticationManager().authenticate( + new OAuth2UserAuthenticationToken(oauth2ClientAuthentication)); } public final RequestMatcher getAuthorizationResponseMatcher() { @@ -171,50 +152,6 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio this.authorizationRequestRepository = authorizationRequestRepository; } - private boolean authenticated() { - Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication(); - return currentAuthentication != null && - currentAuthentication instanceof OAuth2UserAuthenticationToken && - currentAuthentication.isAuthenticated(); - } - - private boolean authenticatedSameProviderAs(OAuth2ClientAuthenticationToken oauth2ClientAuthentication) { - OAuth2UserAuthenticationToken userAuthentication = - (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication(); - - String userProviderId = this.providerIdentifierStrategy.getIdentifier( - userAuthentication.getClientAuthentication().getClientRegistration()); - String clientProviderId = this.providerIdentifierStrategy.getIdentifier( - oauth2ClientAuthentication.getClientRegistration()); - - return userProviderId.equals(clientProviderId); - } - - private OAuth2UserAuthenticationToken createUserAuthentication(OAuth2ClientAuthenticationToken clientAuthentication) { - if (OidcClientAuthenticationToken.class.isAssignableFrom(clientAuthentication.getClass())) { - return new OidcUserAuthenticationToken((OidcClientAuthenticationToken)clientAuthentication); - } else { - return new OAuth2UserAuthenticationToken(clientAuthentication); - } - } - - private OAuth2UserAuthenticationToken createUserAuthentication( - OAuth2UserAuthenticationToken currentUserAuthentication, - OAuth2ClientAuthenticationToken newClientAuthentication) { - - if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) { - return new OidcUserAuthenticationToken( - (OidcUser) currentUserAuthentication.getPrincipal(), - currentUserAuthentication.getAuthorities(), - newClientAuthentication); - } else { - return new OAuth2UserAuthenticationToken( - (OAuth2User)currentUserAuthentication.getPrincipal(), - currentUserAuthentication.getAuthorities(), - newClientAuthentication); - } - } - private static class AuthorizationResponseMatcher implements RequestMatcher { private final String baseUri; @@ -266,16 +203,4 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio } } } - - private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy { - - @Override - public String getIdentifier(ClientRegistration clientRegistration) { - StringBuilder builder = new StringBuilder(); - builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]"); - builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]"); - builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]"); - return builder.toString(); - } - } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilterTests.java index 783010cde1..24da2a645f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilterTests.java @@ -128,7 +128,7 @@ public class AuthorizationCodeAuthenticationFilterTests { ArgumentCaptor authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class); Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class), authenticationArgCaptor.capture()); - Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(userAuthentication); + Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(clientAuthentication); } @Test