diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index d97532298d..9644f6620c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -62,8 +62,8 @@ public class OidcUserService implements OAuth2UserService, Map> DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter(createDefaultClaimTypeConverters()); - private final Set userInfoScopes = new HashSet<>( - Arrays.asList(OidcScopes.PROFILE, OidcScopes.EMAIL, OidcScopes.ADDRESS, OidcScopes.PHONE)); + private Set accessibleScopes = new HashSet<>(Arrays.asList( + OidcScopes.PROFILE, OidcScopes.EMAIL, OidcScopes.ADDRESS, OidcScopes.PHONE)); private OAuth2UserService oauth2UserService = new DefaultOAuth2UserService(); private Function, Map>> claimTypeConverterFactory = clientRegistration -> DEFAULT_CLAIM_TYPE_CONVERTER; @@ -160,8 +160,9 @@ public class OidcUserService implements OAuth2UserService accessibleScopes) { + Assert.notNull(accessibleScopes, "accessibleScopes cannot be null"); + this.accessibleScopes = accessibleScopes; + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index b562099f6d..e034076e40 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -41,11 +41,9 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; import java.time.Instant; -import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; -import java.util.LinkedHashSet; import java.util.Map; -import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Function; @@ -116,6 +114,17 @@ public class OidcUserServiceTests { .isInstanceOf(IllegalArgumentException.class); } + @Test + public void setAccessibleScopesWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.userService.setAccessibleScopes(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setAccessibleScopesWhenEmptyThenSet() { + this.userService.setAccessibleScopes(Collections.emptySet()); + } + @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { this.exception.expect(IllegalArgumentException.class); @@ -130,20 +139,91 @@ public class OidcUserServiceTests { } @Test - public void loadUserWhenAuthorizedScopesDoesNotContainUserInfoScopesThenUserInfoEndpointNotRequested() { + public void loadUserWhenNonStandardScopesAuthorizedThenUserInfoEndpointNotRequested() { ClientRegistration clientRegistration = this.clientRegistrationBuilder .userInfoUri("https://provider.com/user").build(); - - Set authorizedScopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2")); - OAuth2AccessToken accessToken = new OAuth2AccessToken( - OAuth2AccessToken.TokenType.BEARER, "access-token", - Instant.MIN, Instant.MAX, authorizedScopes); + this.accessToken = scopes("scope1", "scope2"); OidcUser user = this.userService.loadUser( - new OidcUserRequest(clientRegistration, accessToken, this.idToken)); + new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); assertThat(user.getUserInfo()).isNull(); } + // gh-6886 + @Test + public void loadUserWhenNonStandardScopesAuthorizedAndAccessibleScopesMatchThenUserInfoEndpointRequested() { + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(userInfoResponse)); + + String userInfoUri = this.server.url("/user").toString(); + + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); + + this.accessToken = scopes("scope1", "scope2"); + this.userService.setAccessibleScopes(Collections.singleton("scope2")); + + OidcUser user = this.userService.loadUser( + new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + assertThat(user.getUserInfo()).isNotNull(); + } + + // gh-6886 + @Test + public void loadUserWhenNonStandardScopesAuthorizedAndAccessibleScopesEmptyThenUserInfoEndpointRequested() { + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(userInfoResponse)); + + String userInfoUri = this.server.url("/user").toString(); + + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); + + this.accessToken = scopes("scope1", "scope2"); + this.userService.setAccessibleScopes(Collections.emptySet()); + + OidcUser user = this.userService.loadUser( + new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + assertThat(user.getUserInfo()).isNotNull(); + } + + // gh-6886 + @Test + public void loadUserWhenStandardScopesAuthorizedThenUserInfoEndpointRequested() { + String userInfoResponse = "{\n" + + " \"sub\": \"subject1\",\n" + + " \"name\": \"first last\",\n" + + " \"given_name\": \"first\",\n" + + " \"family_name\": \"last\",\n" + + " \"preferred_username\": \"user1\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(userInfoResponse)); + + String userInfoUri = this.server.url("/user").toString(); + + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .userInfoUri(userInfoUri).build(); + + OidcUser user = this.userService.loadUser( + new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + assertThat(user.getUserInfo()).isNotNull(); + } + @Test public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { String userInfoResponse = "{\n" +