Move logic from AuthorizationCodeAuthenticationFilter to OAuth2UserAuthenticationProvider

This commit is contained in:
Joe Grandja 2017-10-11 17:23:59 -04:00
parent 18df9a869e
commit df474e04d8
3 changed files with 68 additions and 84 deletions

View File

@ -20,6 +20,9 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper; import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy;
import org.springframework.security.oauth2.client.user.DefaultOAuth2UserService; import org.springframework.security.oauth2.client.user.DefaultOAuth2UserService;
import org.springframework.security.oauth2.client.user.OAuth2UserService; import org.springframework.security.oauth2.client.user.OAuth2UserService;
import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User;
@ -55,6 +58,7 @@ import java.util.Collection;
* @see OidcUser * @see OidcUser
*/ */
public class OAuth2UserAuthenticationProvider implements AuthenticationProvider { public class OAuth2UserAuthenticationProvider implements AuthenticationProvider {
private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
private final OAuth2UserService userService; private final OAuth2UserService userService;
private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities); private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
@ -65,11 +69,18 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2UserAuthenticationToken oauth2UserAuthentication = (OAuth2UserAuthenticationToken) authentication; OAuth2UserAuthenticationToken userAuthentication = (OAuth2UserAuthenticationToken) authentication;
OAuth2ClientAuthenticationToken clientAuthentication = userAuthentication.getClientAuthentication();
OAuth2ClientAuthenticationToken oauth2ClientAuthentication = oauth2UserAuthentication.getClientAuthentication(); if (this.userAuthenticated() && this.userAuthenticatedSameProviderAs(clientAuthentication)) {
// Create a new user authentication (using same principal)
// but with a different client authentication association
return this.createUserAuthentication(
(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication(),
clientAuthentication);
}
OAuth2User oauth2User = this.userService.loadUser(oauth2ClientAuthentication); OAuth2User oauth2User = this.userService.loadUser(clientAuthentication);
Collection<? extends GrantedAuthority> mappedAuthorities = Collection<? extends GrantedAuthority> mappedAuthorities =
this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities()); this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
@ -77,12 +88,12 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
OAuth2UserAuthenticationToken authenticationResult; OAuth2UserAuthenticationToken authenticationResult;
if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) { if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) {
authenticationResult = new OidcUserAuthenticationToken( authenticationResult = new OidcUserAuthenticationToken(
(OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)oauth2ClientAuthentication); (OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)clientAuthentication);
} else { } else {
authenticationResult = new OAuth2UserAuthenticationToken( authenticationResult = new OAuth2UserAuthenticationToken(
oauth2User, mappedAuthorities, oauth2ClientAuthentication); oauth2User, mappedAuthorities, clientAuthentication);
} }
authenticationResult.setDetails(oauth2ClientAuthentication.getDetails()); authenticationResult.setDetails(clientAuthentication.getDetails());
return authenticationResult; return authenticationResult;
} }
@ -96,4 +107,52 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
public boolean supports(Class<?> authentication) { public boolean supports(Class<?> authentication) {
return OAuth2UserAuthenticationToken.class.isAssignableFrom(authentication); return OAuth2UserAuthenticationToken.class.isAssignableFrom(authentication);
} }
private boolean userAuthenticated() {
Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
return currentAuthentication != null &&
currentAuthentication instanceof OAuth2UserAuthenticationToken &&
currentAuthentication.isAuthenticated();
}
private boolean userAuthenticatedSameProviderAs(OAuth2ClientAuthenticationToken clientAuthentication) {
OAuth2UserAuthenticationToken currentUserAuthentication =
(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
String userProviderId = this.providerIdentifierStrategy.getIdentifier(
currentUserAuthentication.getClientAuthentication().getClientRegistration());
String clientProviderId = this.providerIdentifierStrategy.getIdentifier(
clientAuthentication.getClientRegistration());
return userProviderId.equals(clientProviderId);
}
private OAuth2UserAuthenticationToken createUserAuthentication(
OAuth2UserAuthenticationToken currentUserAuthentication,
OAuth2ClientAuthenticationToken newClientAuthentication) {
if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) {
return new OidcUserAuthenticationToken(
(OidcUser) currentUserAuthentication.getPrincipal(),
currentUserAuthentication.getAuthorities(),
newClientAuthentication);
} else {
return new OAuth2UserAuthenticationToken(
(OAuth2User)currentUserAuthentication.getPrincipal(),
currentUserAuthentication.getAuthorities(),
newClientAuthentication);
}
}
private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy<String> {
@Override
public String getIdentifier(ClientRegistration clientRegistration) {
StringBuilder builder = new StringBuilder();
builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]");
builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]");
builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]");
return builder.toString();
}
}
} }

View File

@ -18,23 +18,17 @@ package org.springframework.security.oauth2.client.web;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationProvider; import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken; import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException;
import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken; import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter; import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken;
import org.springframework.security.oauth2.oidc.client.authentication.OidcUserAuthenticationToken;
import org.springframework.security.oauth2.oidc.core.user.OidcUser;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -81,7 +75,6 @@ import java.io.IOException;
public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticationProcessingFilter { public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code"; public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code";
private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found"; private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
private AuthorizationResponseMatcher authorizationResponseMatcher; private AuthorizationResponseMatcher authorizationResponseMatcher;
private ClientRegistrationRepository clientRegistrationRepository; private ClientRegistrationRepository clientRegistrationRepository;
private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository(); private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
@ -135,20 +128,8 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
OAuth2ClientAuthenticationToken oauth2ClientAuthentication = OAuth2ClientAuthenticationToken oauth2ClientAuthentication =
(OAuth2ClientAuthenticationToken)this.getAuthenticationManager().authenticate(authorizationCodeAuthentication); (OAuth2ClientAuthenticationToken)this.getAuthenticationManager().authenticate(authorizationCodeAuthentication);
OAuth2UserAuthenticationToken oauth2UserAuthentication; return this.getAuthenticationManager().authenticate(
if (this.authenticated() && this.authenticatedSameProviderAs(oauth2ClientAuthentication)) { new OAuth2UserAuthenticationToken(oauth2ClientAuthentication));
// Create a new user authentication (using same principal)
// but with a different client authentication association
oauth2UserAuthentication = (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
oauth2UserAuthentication = this.createUserAuthentication(oauth2UserAuthentication, oauth2ClientAuthentication);
} else {
// Authenticate the user... the user needs to be authenticated
// before we can associate the client authentication to the user
oauth2UserAuthentication = (OAuth2UserAuthenticationToken)this.getAuthenticationManager().authenticate(
this.createUserAuthentication(oauth2ClientAuthentication));
}
return oauth2UserAuthentication;
} }
public final RequestMatcher getAuthorizationResponseMatcher() { public final RequestMatcher getAuthorizationResponseMatcher() {
@ -171,50 +152,6 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
this.authorizationRequestRepository = authorizationRequestRepository; this.authorizationRequestRepository = authorizationRequestRepository;
} }
private boolean authenticated() {
Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
return currentAuthentication != null &&
currentAuthentication instanceof OAuth2UserAuthenticationToken &&
currentAuthentication.isAuthenticated();
}
private boolean authenticatedSameProviderAs(OAuth2ClientAuthenticationToken oauth2ClientAuthentication) {
OAuth2UserAuthenticationToken userAuthentication =
(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
String userProviderId = this.providerIdentifierStrategy.getIdentifier(
userAuthentication.getClientAuthentication().getClientRegistration());
String clientProviderId = this.providerIdentifierStrategy.getIdentifier(
oauth2ClientAuthentication.getClientRegistration());
return userProviderId.equals(clientProviderId);
}
private OAuth2UserAuthenticationToken createUserAuthentication(OAuth2ClientAuthenticationToken clientAuthentication) {
if (OidcClientAuthenticationToken.class.isAssignableFrom(clientAuthentication.getClass())) {
return new OidcUserAuthenticationToken((OidcClientAuthenticationToken)clientAuthentication);
} else {
return new OAuth2UserAuthenticationToken(clientAuthentication);
}
}
private OAuth2UserAuthenticationToken createUserAuthentication(
OAuth2UserAuthenticationToken currentUserAuthentication,
OAuth2ClientAuthenticationToken newClientAuthentication) {
if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) {
return new OidcUserAuthenticationToken(
(OidcUser) currentUserAuthentication.getPrincipal(),
currentUserAuthentication.getAuthorities(),
newClientAuthentication);
} else {
return new OAuth2UserAuthenticationToken(
(OAuth2User)currentUserAuthentication.getPrincipal(),
currentUserAuthentication.getAuthorities(),
newClientAuthentication);
}
}
private static class AuthorizationResponseMatcher implements RequestMatcher { private static class AuthorizationResponseMatcher implements RequestMatcher {
private final String baseUri; private final String baseUri;
@ -266,16 +203,4 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
} }
} }
} }
private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy<String> {
@Override
public String getIdentifier(ClientRegistration clientRegistration) {
StringBuilder builder = new StringBuilder();
builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]");
builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]");
builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]");
return builder.toString();
}
}
} }

View File

@ -128,7 +128,7 @@ public class AuthorizationCodeAuthenticationFilterTests {
ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class); ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class), Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
authenticationArgCaptor.capture()); authenticationArgCaptor.capture());
Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(userAuthentication); Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(clientAuthentication);
} }
@Test @Test