Add additional parameters to OAuth2UserRequest
Fixes gh-5368
This commit is contained in:
		
							parent
							
								
									950a314c9f
								
							
						
					
					
						commit
						8a0c6868cd
					
				| 
						 | 
				
			
			@ -30,6 +30,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
 | 
			
		|||
import org.springframework.util.Assert;
 | 
			
		||||
 | 
			
		||||
import java.util.Collection;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * An implementation of an {@link AuthenticationProvider} for OAuth 2.0 Login,
 | 
			
		||||
| 
						 | 
				
			
			@ -101,9 +102,10 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider
 | 
			
		|||
					authorizationCodeAuthentication.getAuthorizationExchange()));
 | 
			
		||||
 | 
			
		||||
		OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
 | 
			
		||||
		Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
 | 
			
		||||
 | 
			
		||||
		OAuth2User oauth2User = this.userService.loadUser(
 | 
			
		||||
			new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken));
 | 
			
		||||
		OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest(
 | 
			
		||||
				authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters));
 | 
			
		||||
 | 
			
		||||
		Collection<? extends GrantedAuthority> mappedAuthorities =
 | 
			
		||||
			this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,6 +16,7 @@
 | 
			
		|||
package org.springframework.security.oauth2.client.authentication;
 | 
			
		||||
 | 
			
		||||
import java.util.Collection;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
import org.springframework.security.authentication.ReactiveAuthenticationManager;
 | 
			
		||||
import org.springframework.security.core.Authentication;
 | 
			
		||||
| 
						 | 
				
			
			@ -109,7 +110,9 @@ public class OAuth2LoginReactiveAuthenticationManager implements
 | 
			
		|||
 | 
			
		||||
	private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
 | 
			
		||||
		OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
 | 
			
		||||
		OAuth2UserRequest userRequest = new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken);
 | 
			
		||||
		Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
 | 
			
		||||
		OAuth2UserRequest userRequest = new OAuth2UserRequest(
 | 
			
		||||
				authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters);
 | 
			
		||||
		return this.userService.loadUser(userRequest)
 | 
			
		||||
				.flatMap(oauth2User -> {
 | 
			
		||||
					Collection<? extends GrantedAuthority> mappedAuthorities =
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -139,19 +139,18 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
 | 
			
		|||
 | 
			
		||||
		ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
 | 
			
		||||
 | 
			
		||||
		if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) {
 | 
			
		||||
		Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
 | 
			
		||||
		if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) {
 | 
			
		||||
			OAuth2Error invalidIdTokenError = new OAuth2Error(
 | 
			
		||||
				INVALID_ID_TOKEN_ERROR_CODE,
 | 
			
		||||
				"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
 | 
			
		||||
				null);
 | 
			
		||||
			throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse);
 | 
			
		||||
 | 
			
		||||
		OidcUser oidcUser = this.userService.loadUser(
 | 
			
		||||
			new OidcUserRequest(clientRegistration, accessTokenResponse.getAccessToken(), idToken));
 | 
			
		||||
 | 
			
		||||
		OidcUser oidcUser = this.userService.loadUser(new OidcUserRequest(
 | 
			
		||||
				clientRegistration, accessTokenResponse.getAccessToken(), idToken, additionalParameters));
 | 
			
		||||
		Collection<? extends GrantedAuthority> mappedAuthorities =
 | 
			
		||||
			this.authoritiesMapper.mapAuthorities(oidcUser.getAuthorities());
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -159,10 +159,10 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 | 
			
		|||
 | 
			
		||||
	private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
 | 
			
		||||
		OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
 | 
			
		||||
 | 
			
		||||
		ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
 | 
			
		||||
		Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
 | 
			
		||||
 | 
			
		||||
		if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) {
 | 
			
		||||
		if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) {
 | 
			
		||||
			OAuth2Error invalidIdTokenError = new OAuth2Error(
 | 
			
		||||
					INVALID_ID_TOKEN_ERROR_CODE,
 | 
			
		||||
					"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
 | 
			
		||||
| 
						 | 
				
			
			@ -171,7 +171,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 | 
			
		|||
		}
 | 
			
		||||
 | 
			
		||||
		return createOidcToken(clientRegistration, accessTokenResponse)
 | 
			
		||||
				.map(idToken ->  new OidcUserRequest(clientRegistration, accessToken, idToken))
 | 
			
		||||
				.map(idToken ->  new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters))
 | 
			
		||||
				.flatMap(this.userService::loadUser)
 | 
			
		||||
				.flatMap(oauth2User -> {
 | 
			
		||||
					Collection<? extends GrantedAuthority> mappedAuthorities =
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2002-2017 the original author or authors.
 | 
			
		||||
 * Copyright 2002-2018 the original author or authors.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
| 
						 | 
				
			
			@ -21,6 +21,9 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 | 
			
		|||
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 | 
			
		||||
import org.springframework.util.Assert;
 | 
			
		||||
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Represents a request the {@link OidcUserService} uses
 | 
			
		||||
 * when initiating a request to the UserInfo Endpoint.
 | 
			
		||||
| 
						 | 
				
			
			@ -45,7 +48,22 @@ public class OidcUserRequest extends OAuth2UserRequest {
 | 
			
		|||
	public OidcUserRequest(ClientRegistration clientRegistration,
 | 
			
		||||
							OAuth2AccessToken accessToken, OidcIdToken idToken) {
 | 
			
		||||
 | 
			
		||||
		super(clientRegistration, accessToken);
 | 
			
		||||
		this(clientRegistration, accessToken, idToken, Collections.emptyMap());
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
	 * Constructs an {@code OidcUserRequest} using the provided parameters.
 | 
			
		||||
	 *
 | 
			
		||||
	 * @since 5.1
 | 
			
		||||
	 * @param clientRegistration the client registration
 | 
			
		||||
	 * @param accessToken the access token credential
 | 
			
		||||
	 * @param idToken the ID Token
 | 
			
		||||
	 * @param additionalParameters the additional parameters, may be empty
 | 
			
		||||
	 */
 | 
			
		||||
	public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
 | 
			
		||||
							OidcIdToken idToken, Map<String, Object> additionalParameters) {
 | 
			
		||||
 | 
			
		||||
		super(clientRegistration, accessToken, additionalParameters);
 | 
			
		||||
		Assert.notNull(idToken, "idToken cannot be null");
 | 
			
		||||
		this.idToken = idToken;
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2002-2017 the original author or authors.
 | 
			
		||||
 * Copyright 2002-2018 the original author or authors.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
| 
						 | 
				
			
			@ -18,6 +18,11 @@ package org.springframework.security.oauth2.client.userinfo;
 | 
			
		|||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
 | 
			
		||||
import org.springframework.util.Assert;
 | 
			
		||||
import org.springframework.util.CollectionUtils;
 | 
			
		||||
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.LinkedHashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Represents a request the {@link OAuth2UserService} uses
 | 
			
		||||
| 
						 | 
				
			
			@ -32,6 +37,7 @@ import org.springframework.util.Assert;
 | 
			
		|||
public class OAuth2UserRequest {
 | 
			
		||||
	private final ClientRegistration clientRegistration;
 | 
			
		||||
	private final OAuth2AccessToken accessToken;
 | 
			
		||||
	private final Map<String, Object> additionalParameters;
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
	 * Constructs an {@code OAuth2UserRequest} using the provided parameters.
 | 
			
		||||
| 
						 | 
				
			
			@ -40,10 +46,26 @@ public class OAuth2UserRequest {
 | 
			
		|||
	 * @param accessToken the access token
 | 
			
		||||
	 */
 | 
			
		||||
	public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken) {
 | 
			
		||||
		this(clientRegistration, accessToken, Collections.emptyMap());
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
	 * Constructs an {@code OAuth2UserRequest} using the provided parameters.
 | 
			
		||||
	 *
 | 
			
		||||
	 * @since 5.1
 | 
			
		||||
	 * @param clientRegistration the client registration
 | 
			
		||||
	 * @param accessToken the access token
 | 
			
		||||
	 * @param additionalParameters the additional parameters, may be empty
 | 
			
		||||
	 */
 | 
			
		||||
	public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
 | 
			
		||||
								Map<String, Object> additionalParameters) {
 | 
			
		||||
		Assert.notNull(clientRegistration, "clientRegistration cannot be null");
 | 
			
		||||
		Assert.notNull(accessToken, "accessToken cannot be null");
 | 
			
		||||
		this.clientRegistration = clientRegistration;
 | 
			
		||||
		this.accessToken = accessToken;
 | 
			
		||||
		this.additionalParameters = Collections.unmodifiableMap(
 | 
			
		||||
				CollectionUtils.isEmpty(additionalParameters) ?
 | 
			
		||||
				Collections.emptyMap() : new LinkedHashMap<>(additionalParameters));
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
| 
						 | 
				
			
			@ -63,4 +85,14 @@ public class OAuth2UserRequest {
 | 
			
		|||
	public OAuth2AccessToken getAccessToken() {
 | 
			
		||||
		return this.accessToken;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
	 * Returns the additional parameters that may be used in the request.
 | 
			
		||||
	 *
 | 
			
		||||
	 * @since 5.1
 | 
			
		||||
	 * @return a {@code Map} of the additional parameters, may be empty.
 | 
			
		||||
	 */
 | 
			
		||||
	public Map<String, Object> getAdditionalParameters() {
 | 
			
		||||
		return this.additionalParameters;
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,6 +20,7 @@ import org.junit.Rule;
 | 
			
		|||
import org.junit.Test;
 | 
			
		||||
import org.junit.rules.ExpectedException;
 | 
			
		||||
import org.junit.runner.RunWith;
 | 
			
		||||
import org.mockito.ArgumentCaptor;
 | 
			
		||||
import org.mockito.stubbing.Answer;
 | 
			
		||||
import org.powermock.core.classloader.annotations.PrepareForTest;
 | 
			
		||||
import org.powermock.modules.junit4.PowerMockRunner;
 | 
			
		||||
| 
						 | 
				
			
			@ -35,17 +36,20 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 | 
			
		|||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2Error;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 | 
			
		||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 | 
			
		||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 | 
			
		||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 | 
			
		||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 | 
			
		||||
import org.springframework.security.oauth2.core.user.OAuth2User;
 | 
			
		||||
 | 
			
		||||
import java.time.Instant;
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.LinkedHashSet;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import java.util.Set;
 | 
			
		||||
 | 
			
		||||
import static org.assertj.core.api.Assertions.assertThat;
 | 
			
		||||
import static org.hamcrest.CoreMatchers.containsString;
 | 
			
		||||
| 
						 | 
				
			
			@ -164,11 +168,7 @@ public class OAuth2LoginAuthenticationProviderTests {
 | 
			
		|||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void authenticateWhenLoginSuccessThenReturnAuthentication() {
 | 
			
		||||
		OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
 | 
			
		||||
		OAuth2RefreshToken refreshToken = mock(OAuth2RefreshToken.class);
 | 
			
		||||
		OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
 | 
			
		||||
		when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
 | 
			
		||||
		when(accessTokenResponse.getRefreshToken()).thenReturn(refreshToken);
 | 
			
		||||
		OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
 | 
			
		||||
		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
 | 
			
		||||
 | 
			
		||||
		OAuth2User principal = mock(OAuth2User.class);
 | 
			
		||||
| 
						 | 
				
			
			@ -187,15 +187,13 @@ public class OAuth2LoginAuthenticationProviderTests {
 | 
			
		|||
		assertThat(authentication.getAuthorities()).isEqualTo(authorities);
 | 
			
		||||
		assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
 | 
			
		||||
		assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
 | 
			
		||||
		assertThat(authentication.getAccessToken()).isEqualTo(accessToken);
 | 
			
		||||
		assertThat(authentication.getRefreshToken()).isEqualTo(refreshToken);
 | 
			
		||||
		assertThat(authentication.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
 | 
			
		||||
		assertThat(authentication.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
 | 
			
		||||
		OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
 | 
			
		||||
		OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
 | 
			
		||||
		when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
 | 
			
		||||
		OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
 | 
			
		||||
		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
 | 
			
		||||
 | 
			
		||||
		OAuth2User principal = mock(OAuth2User.class);
 | 
			
		||||
| 
						 | 
				
			
			@ -216,4 +214,42 @@ public class OAuth2LoginAuthenticationProviderTests {
 | 
			
		|||
 | 
			
		||||
		assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// gh-5368
 | 
			
		||||
	@Test
 | 
			
		||||
	public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
 | 
			
		||||
		OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
 | 
			
		||||
		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
 | 
			
		||||
 | 
			
		||||
		OAuth2User principal = mock(OAuth2User.class);
 | 
			
		||||
		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
 | 
			
		||||
		when(principal.getAuthorities()).thenAnswer(
 | 
			
		||||
				(Answer<List<GrantedAuthority>>) invocation -> authorities);
 | 
			
		||||
		ArgumentCaptor<OAuth2UserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class);
 | 
			
		||||
		when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal);
 | 
			
		||||
 | 
			
		||||
		this.authenticationProvider.authenticate(
 | 
			
		||||
				new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
 | 
			
		||||
 | 
			
		||||
		assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf(
 | 
			
		||||
				accessTokenResponse.getAdditionalParameters());
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private OAuth2AccessTokenResponse accessTokenSuccessResponse() {
 | 
			
		||||
		Instant expiresAt = Instant.now().plusSeconds(5);
 | 
			
		||||
		Set<String> scopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
 | 
			
		||||
		Map<String, Object> additionalParameters = new HashMap<>();
 | 
			
		||||
		additionalParameters.put("param1", "value1");
 | 
			
		||||
		additionalParameters.put("param2", "value2");
 | 
			
		||||
 | 
			
		||||
		return OAuth2AccessTokenResponse
 | 
			
		||||
				.withToken("access-token-1234")
 | 
			
		||||
				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 | 
			
		||||
				.expiresIn(expiresAt.getEpochSecond())
 | 
			
		||||
				.scopes(scopes)
 | 
			
		||||
				.refreshToken("refresh-token-1234")
 | 
			
		||||
				.additionalParameters(additionalParameters)
 | 
			
		||||
				.build();
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -23,11 +23,14 @@ import static org.mockito.ArgumentMatchers.any;
 | 
			
		|||
import static org.mockito.Mockito.when;
 | 
			
		||||
 | 
			
		||||
import java.util.Collections;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
import org.junit.Before;
 | 
			
		||||
import org.junit.Ignore;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.junit.runner.RunWith;
 | 
			
		||||
import org.mockito.ArgumentCaptor;
 | 
			
		||||
import org.mockito.Mock;
 | 
			
		||||
import org.mockito.junit.MockitoJUnitRunner;
 | 
			
		||||
import org.springframework.security.authentication.TestingAuthenticationToken;
 | 
			
		||||
| 
						 | 
				
			
			@ -164,7 +167,7 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void authenticationWhenOAuth2UserNotFoundThenSuccess() {
 | 
			
		||||
	public void authenticationWhenOAuth2UserFoundThenSuccess() {
 | 
			
		||||
		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
 | 
			
		||||
				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 | 
			
		||||
				.build();
 | 
			
		||||
| 
						 | 
				
			
			@ -179,6 +182,27 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
 | 
			
		|||
		assertThat(result.isAuthenticated()).isTrue();
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// gh-5368
 | 
			
		||||
	@Test
 | 
			
		||||
	public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
 | 
			
		||||
		Map<String, Object> additionalParameters = new HashMap<>();
 | 
			
		||||
		additionalParameters.put("param1", "value1");
 | 
			
		||||
		additionalParameters.put("param2", "value2");
 | 
			
		||||
		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
 | 
			
		||||
				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 | 
			
		||||
				.additionalParameters(additionalParameters)
 | 
			
		||||
				.build();
 | 
			
		||||
		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
 | 
			
		||||
		DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user");
 | 
			
		||||
		ArgumentCaptor<OAuth2UserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class);
 | 
			
		||||
		when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));
 | 
			
		||||
 | 
			
		||||
		this.manager.authenticate(loginToken()).block();
 | 
			
		||||
 | 
			
		||||
		assertThat(userRequestArgCaptor.getValue().getAdditionalParameters())
 | 
			
		||||
				.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private OAuth2LoginAuthenticationToken loginToken() {
 | 
			
		||||
		ClientRegistration clientRegistration = this.registration.build();
 | 
			
		||||
		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,6 +20,7 @@ import org.junit.Rule;
 | 
			
		|||
import org.junit.Test;
 | 
			
		||||
import org.junit.rules.ExpectedException;
 | 
			
		||||
import org.junit.runner.RunWith;
 | 
			
		||||
import org.mockito.ArgumentCaptor;
 | 
			
		||||
import org.mockito.stubbing.Answer;
 | 
			
		||||
import org.powermock.api.mockito.PowerMockito;
 | 
			
		||||
import org.powermock.core.classloader.annotations.PrepareForTest;
 | 
			
		||||
| 
						 | 
				
			
			@ -37,7 +38,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 | 
			
		|||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2Error;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 | 
			
		||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 | 
			
		||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 | 
			
		||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 | 
			
		||||
| 
						 | 
				
			
			@ -55,6 +55,7 @@ import java.util.HashMap;
 | 
			
		|||
import java.util.LinkedHashSet;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import java.util.Set;
 | 
			
		||||
 | 
			
		||||
import static org.assertj.core.api.Assertions.assertThat;
 | 
			
		||||
import static org.hamcrest.CoreMatchers.containsString;
 | 
			
		||||
| 
						 | 
				
			
			@ -78,8 +79,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 | 
			
		|||
	private OAuth2AuthorizationExchange authorizationExchange;
 | 
			
		||||
	private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
 | 
			
		||||
	private OAuth2AccessTokenResponse accessTokenResponse;
 | 
			
		||||
	private OAuth2AccessToken accessToken;
 | 
			
		||||
	private OAuth2RefreshToken refreshToken;
 | 
			
		||||
	private OAuth2UserService<OidcUserRequest, OidcUser> userService;
 | 
			
		||||
	private OidcAuthorizationCodeAuthenticationProvider authenticationProvider;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -95,9 +94,7 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 | 
			
		|||
		this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
 | 
			
		||||
		this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
 | 
			
		||||
		this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
 | 
			
		||||
		this.accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
 | 
			
		||||
		this.accessToken = mock(OAuth2AccessToken.class);
 | 
			
		||||
		this.refreshToken = mock(OAuth2RefreshToken.class);
 | 
			
		||||
		this.accessTokenResponse = this.accessTokenSuccessResponse();
 | 
			
		||||
		this.userService = mock(OAuth2UserService.class);
 | 
			
		||||
		this.authenticationProvider = PowerMockito.spy(
 | 
			
		||||
			new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService));
 | 
			
		||||
| 
						 | 
				
			
			@ -111,11 +108,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 | 
			
		|||
		when(this.authorizationResponse.getState()).thenReturn("12345");
 | 
			
		||||
		when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
 | 
			
		||||
		when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
 | 
			
		||||
		when(this.accessTokenResponse.getAccessToken()).thenReturn(this.accessToken);
 | 
			
		||||
		when(this.accessTokenResponse.getRefreshToken()).thenReturn(this.refreshToken);
 | 
			
		||||
		Map<String, Object> additionalParameters = new HashMap<>();
 | 
			
		||||
		additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
 | 
			
		||||
		when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(additionalParameters);
 | 
			
		||||
		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -194,7 +186,11 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 | 
			
		|||
		this.exception.expect(OAuth2AuthenticationException.class);
 | 
			
		||||
		this.exception.expectMessage(containsString("invalid_id_token"));
 | 
			
		||||
 | 
			
		||||
		when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(Collections.emptyMap());
 | 
			
		||||
		OAuth2AccessTokenResponse accessTokenResponse =
 | 
			
		||||
				OAuth2AccessTokenResponse.withResponse(this.accessTokenSuccessResponse())
 | 
			
		||||
						.additionalParameters(Collections.emptyMap())
 | 
			
		||||
						.build();
 | 
			
		||||
		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
 | 
			
		||||
 | 
			
		||||
		this.authenticationProvider.authenticate(
 | 
			
		||||
			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
 | 
			
		||||
| 
						 | 
				
			
			@ -368,8 +364,8 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 | 
			
		|||
		assertThat(authentication.getAuthorities()).isEqualTo(authorities);
 | 
			
		||||
		assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
 | 
			
		||||
		assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
 | 
			
		||||
		assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
 | 
			
		||||
		assertThat(authentication.getRefreshToken()).isEqualTo(this.refreshToken);
 | 
			
		||||
		assertThat(authentication.getAccessToken()).isEqualTo(this.accessTokenResponse.getAccessToken());
 | 
			
		||||
		assertThat(authentication.getRefreshToken()).isEqualTo(this.accessTokenResponse.getRefreshToken());
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test
 | 
			
		||||
| 
						 | 
				
			
			@ -400,6 +396,30 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 | 
			
		|||
		assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// gh-5368
 | 
			
		||||
	@Test
 | 
			
		||||
	public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() throws Exception {
 | 
			
		||||
		Map<String, Object> claims = new HashMap<>();
 | 
			
		||||
		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
 | 
			
		||||
		claims.put(IdTokenClaimNames.SUB, "subject1");
 | 
			
		||||
		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
 | 
			
		||||
		claims.put(IdTokenClaimNames.AZP, "client1");
 | 
			
		||||
		this.setUpIdToken(claims);
 | 
			
		||||
 | 
			
		||||
		OidcUser principal = mock(OidcUser.class);
 | 
			
		||||
		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
 | 
			
		||||
		when(principal.getAuthorities()).thenAnswer(
 | 
			
		||||
				(Answer<List<GrantedAuthority>>) invocation -> authorities);
 | 
			
		||||
		ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
 | 
			
		||||
		when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal);
 | 
			
		||||
 | 
			
		||||
		this.authenticationProvider.authenticate(new OAuth2LoginAuthenticationToken(
 | 
			
		||||
				this.clientRegistration, this.authorizationExchange));
 | 
			
		||||
 | 
			
		||||
		assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf(
 | 
			
		||||
				this.accessTokenResponse.getAdditionalParameters());
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private void setUpIdToken(Map<String, Object> claims) throws Exception {
 | 
			
		||||
		Instant issuedAt = Instant.now();
 | 
			
		||||
		Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
 | 
			
		||||
| 
						 | 
				
			
			@ -416,4 +436,23 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 | 
			
		|||
		when(jwtDecoder.decode(anyString())).thenReturn(idToken);
 | 
			
		||||
		PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class));
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private OAuth2AccessTokenResponse accessTokenSuccessResponse() {
 | 
			
		||||
		Instant expiresAt = Instant.now().plusSeconds(5);
 | 
			
		||||
		Set<String> scopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email"));
 | 
			
		||||
		Map<String, Object> additionalParameters = new HashMap<>();
 | 
			
		||||
		additionalParameters.put("param1", "value1");
 | 
			
		||||
		additionalParameters.put("param2", "value2");
 | 
			
		||||
		additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
 | 
			
		||||
 | 
			
		||||
		return OAuth2AccessTokenResponse
 | 
			
		||||
				.withToken("access-token-1234")
 | 
			
		||||
				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 | 
			
		||||
				.expiresIn(expiresAt.getEpochSecond())
 | 
			
		||||
				.scopes(scopes)
 | 
			
		||||
				.refreshToken("refresh-token-1234")
 | 
			
		||||
				.additionalParameters(additionalParameters)
 | 
			
		||||
				.build();
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.oidc.authentication;
 | 
			
		|||
import org.junit.Before;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.junit.runner.RunWith;
 | 
			
		||||
import org.mockito.ArgumentCaptor;
 | 
			
		||||
import org.mockito.Mock;
 | 
			
		||||
import org.mockito.junit.MockitoJUnitRunner;
 | 
			
		||||
import org.springframework.security.authentication.TestingAuthenticationToken;
 | 
			
		||||
| 
						 | 
				
			
			@ -217,6 +218,39 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
 | 
			
		|||
		assertThat(result.isAuthenticated()).isTrue();
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// gh-5368
 | 
			
		||||
	@Test
 | 
			
		||||
	public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
 | 
			
		||||
		Map<String, Object> additionalParameters = new HashMap<>();
 | 
			
		||||
		additionalParameters.put(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue());
 | 
			
		||||
		additionalParameters.put("param1", "value1");
 | 
			
		||||
		additionalParameters.put("param2", "value2");
 | 
			
		||||
		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
 | 
			
		||||
				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 | 
			
		||||
				.additionalParameters(additionalParameters)
 | 
			
		||||
				.build();
 | 
			
		||||
 | 
			
		||||
		Map<String, Object> claims = new HashMap<>();
 | 
			
		||||
		claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
 | 
			
		||||
		claims.put(IdTokenClaimNames.SUB, "rob");
 | 
			
		||||
		claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId"));
 | 
			
		||||
		Instant issuedAt = Instant.now();
 | 
			
		||||
		Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
 | 
			
		||||
		Jwt idToken = new Jwt("id-token", issuedAt, expiresAt, claims, claims);
 | 
			
		||||
 | 
			
		||||
		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
 | 
			
		||||
		DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken);
 | 
			
		||||
		ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
 | 
			
		||||
		when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));
 | 
			
		||||
		when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken));
 | 
			
		||||
		this.manager.setDecoderFactory(c -> this.jwtDecoder);
 | 
			
		||||
 | 
			
		||||
		this.manager.authenticate(loginToken()).block();
 | 
			
		||||
 | 
			
		||||
		assertThat(userRequestArgCaptor.getValue().getAdditionalParameters())
 | 
			
		||||
				.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private OAuth2LoginAuthenticationToken loginToken() {
 | 
			
		||||
		ClientRegistration clientRegistration = this.registration.build();
 | 
			
		||||
		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2002-2017 the original author or authors.
 | 
			
		||||
 * Copyright 2002-2018 the original author or authors.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
| 
						 | 
				
			
			@ -17,57 +17,87 @@ package org.springframework.security.oauth2.client.oidc.userinfo;
 | 
			
		|||
 | 
			
		||||
import org.junit.Before;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.junit.runner.RunWith;
 | 
			
		||||
import org.powermock.core.classloader.annotations.PrepareForTest;
 | 
			
		||||
import org.powermock.modules.junit4.PowerMockRunner;
 | 
			
		||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
 | 
			
		||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
 | 
			
		||||
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
 | 
			
		||||
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
 | 
			
		||||
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 | 
			
		||||
 | 
			
		||||
import java.time.Instant;
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.LinkedHashSet;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
import static org.assertj.core.api.Assertions.assertThat;
 | 
			
		||||
import static org.mockito.Mockito.mock;
 | 
			
		||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Tests for {@link OidcUserRequest}.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Joe Grandja
 | 
			
		||||
 */
 | 
			
		||||
@RunWith(PowerMockRunner.class)
 | 
			
		||||
@PrepareForTest(ClientRegistration.class)
 | 
			
		||||
public class OidcUserRequestTests {
 | 
			
		||||
	private ClientRegistration clientRegistration;
 | 
			
		||||
	private OAuth2AccessToken accessToken;
 | 
			
		||||
	private OidcIdToken idToken;
 | 
			
		||||
	private Map<String, Object> additionalParameters;
 | 
			
		||||
 | 
			
		||||
	@Before
 | 
			
		||||
	public void setUp() {
 | 
			
		||||
		this.clientRegistration = mock(ClientRegistration.class);
 | 
			
		||||
		this.accessToken = mock(OAuth2AccessToken.class);
 | 
			
		||||
		this.idToken = mock(OidcIdToken.class);
 | 
			
		||||
		this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
 | 
			
		||||
				.clientId("client-1")
 | 
			
		||||
				.clientSecret("secret")
 | 
			
		||||
				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
 | 
			
		||||
				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
 | 
			
		||||
				.redirectUriTemplate("https://client.com")
 | 
			
		||||
				.scope(new LinkedHashSet<>(Arrays.asList("openid", "profile")))
 | 
			
		||||
				.authorizationUri("https://provider.com/oauth2/authorization")
 | 
			
		||||
				.tokenUri("https://provider.com/oauth2/token")
 | 
			
		||||
				.jwkSetUri("https://provider.com/keys")
 | 
			
		||||
				.clientName("Client 1")
 | 
			
		||||
				.build();
 | 
			
		||||
		this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
 | 
			
		||||
				"access-token-1234", Instant.now(), Instant.now().plusSeconds(60),
 | 
			
		||||
				new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
 | 
			
		||||
		Map<String, Object> claims = new HashMap<>();
 | 
			
		||||
		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
 | 
			
		||||
		claims.put(IdTokenClaimNames.SUB, "subject1");
 | 
			
		||||
		claims.put(IdTokenClaimNames.AZP, "client-1");
 | 
			
		||||
		this.idToken = new OidcIdToken("id-token-1234", Instant.now(),
 | 
			
		||||
				Instant.now().plusSeconds(3600), claims);
 | 
			
		||||
		this.additionalParameters = new HashMap<>();
 | 
			
		||||
		this.additionalParameters.put("param1", "value1");
 | 
			
		||||
		this.additionalParameters.put("param2", "value2");
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test(expected = IllegalArgumentException.class)
 | 
			
		||||
	@Test
 | 
			
		||||
	public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
 | 
			
		||||
		new OidcUserRequest(null, this.accessToken, this.idToken);
 | 
			
		||||
		assertThatThrownBy(() -> new OidcUserRequest(null, this.accessToken, this.idToken))
 | 
			
		||||
				.isInstanceOf(IllegalArgumentException.class);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test(expected = IllegalArgumentException.class)
 | 
			
		||||
	@Test
 | 
			
		||||
	public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
 | 
			
		||||
		new OidcUserRequest(this.clientRegistration, null, this.idToken);
 | 
			
		||||
		assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, null, this.idToken))
 | 
			
		||||
				.isInstanceOf(IllegalArgumentException.class);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test(expected = IllegalArgumentException.class)
 | 
			
		||||
	@Test
 | 
			
		||||
	public void constructorWhenIdTokenIsNullThenThrowIllegalArgumentException() {
 | 
			
		||||
		new OidcUserRequest(this.clientRegistration, this.accessToken, null);
 | 
			
		||||
		assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, this.accessToken, null))
 | 
			
		||||
				.isInstanceOf(IllegalArgumentException.class);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void constructorWhenAllParametersProvidedAndValidThenCreated() {
 | 
			
		||||
		OidcUserRequest userRequest = new OidcUserRequest(
 | 
			
		||||
			this.clientRegistration, this.accessToken, this.idToken);
 | 
			
		||||
			this.clientRegistration, this.accessToken, this.idToken, this.additionalParameters);
 | 
			
		||||
 | 
			
		||||
		assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
 | 
			
		||||
		assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
 | 
			
		||||
		assertThat(userRequest.getIdToken()).isEqualTo(this.idToken);
 | 
			
		||||
		assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters);
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2002-2017 the original author or authors.
 | 
			
		||||
 * Copyright 2002-2018 the original author or authors.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
| 
						 | 
				
			
			@ -17,47 +17,70 @@ package org.springframework.security.oauth2.client.userinfo;
 | 
			
		|||
 | 
			
		||||
import org.junit.Before;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.junit.runner.RunWith;
 | 
			
		||||
import org.powermock.core.classloader.annotations.PrepareForTest;
 | 
			
		||||
import org.powermock.modules.junit4.PowerMockRunner;
 | 
			
		||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
 | 
			
		||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
 | 
			
		||||
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 | 
			
		||||
import org.springframework.security.oauth2.core.OAuth2AccessToken;
 | 
			
		||||
 | 
			
		||||
import java.time.Instant;
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.LinkedHashSet;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
import static org.assertj.core.api.Assertions.assertThat;
 | 
			
		||||
import static org.mockito.Mockito.mock;
 | 
			
		||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Tests for {@link OAuth2UserRequest}.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Joe Grandja
 | 
			
		||||
 */
 | 
			
		||||
@RunWith(PowerMockRunner.class)
 | 
			
		||||
@PrepareForTest(ClientRegistration.class)
 | 
			
		||||
public class OAuth2UserRequestTests {
 | 
			
		||||
	private ClientRegistration clientRegistration;
 | 
			
		||||
	private OAuth2AccessToken accessToken;
 | 
			
		||||
	private Map<String, Object> additionalParameters;
 | 
			
		||||
 | 
			
		||||
	@Before
 | 
			
		||||
	public void setUp() {
 | 
			
		||||
		this.clientRegistration = mock(ClientRegistration.class);
 | 
			
		||||
		this.accessToken = mock(OAuth2AccessToken.class);
 | 
			
		||||
		this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
 | 
			
		||||
				.clientId("client-1")
 | 
			
		||||
				.clientSecret("secret")
 | 
			
		||||
				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
 | 
			
		||||
				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
 | 
			
		||||
				.redirectUriTemplate("https://client.com")
 | 
			
		||||
				.scope(new LinkedHashSet<>(Arrays.asList("scope1", "scope2")))
 | 
			
		||||
				.authorizationUri("https://provider.com/oauth2/authorization")
 | 
			
		||||
				.tokenUri("https://provider.com/oauth2/token")
 | 
			
		||||
				.clientName("Client 1")
 | 
			
		||||
				.build();
 | 
			
		||||
		this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
 | 
			
		||||
				"access-token-1234", Instant.now(), Instant.now().plusSeconds(60),
 | 
			
		||||
				new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
 | 
			
		||||
		this.additionalParameters = new HashMap<>();
 | 
			
		||||
		this.additionalParameters.put("param1", "value1");
 | 
			
		||||
		this.additionalParameters.put("param2", "value2");
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test(expected = IllegalArgumentException.class)
 | 
			
		||||
	@Test
 | 
			
		||||
	public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
 | 
			
		||||
		new OAuth2UserRequest(null, this.accessToken);
 | 
			
		||||
		assertThatThrownBy(() -> new OAuth2UserRequest(null, this.accessToken))
 | 
			
		||||
				.isInstanceOf(IllegalArgumentException.class);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test(expected = IllegalArgumentException.class)
 | 
			
		||||
	@Test
 | 
			
		||||
	public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
 | 
			
		||||
		new OAuth2UserRequest(this.clientRegistration, null);
 | 
			
		||||
		assertThatThrownBy(() -> new OAuth2UserRequest(this.clientRegistration, null))
 | 
			
		||||
				.isInstanceOf(IllegalArgumentException.class);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void constructorWhenAllParametersProvidedAndValidThenCreated() {
 | 
			
		||||
		OAuth2UserRequest userRequest = new OAuth2UserRequest(this.clientRegistration, this.accessToken);
 | 
			
		||||
		OAuth2UserRequest userRequest = new OAuth2UserRequest(
 | 
			
		||||
				this.clientRegistration, this.accessToken, this.additionalParameters);
 | 
			
		||||
 | 
			
		||||
		assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
 | 
			
		||||
		assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
 | 
			
		||||
		assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters);
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue