Polish gh-17507

This commit is contained in:
Joe Grandja 2025-10-02 16:37:46 -04:00
parent 8c65dc93f2
commit 4dfef1483d
4 changed files with 22 additions and 106 deletions

View File

@ -651,10 +651,6 @@ public final class ClientRegistration implements Serializable {
clientRegistration.clientName = StringUtils.hasText(this.clientName) ? this.clientName clientRegistration.clientName = StringUtils.hasText(this.clientName) ? this.clientName
: this.registrationId; : this.registrationId;
clientRegistration.clientSettings = this.clientSettings; clientRegistration.clientSettings = this.clientSettings;
if (clientRegistration.clientSettings.requireProofKey) {
clientRegistration.clientSettings.requireProofKey = AuthorizationGrantType.AUTHORIZATION_CODE
.equals(this.authorizationGrantType);
}
return clientRegistration; return clientRegistration;
} }
@ -706,6 +702,13 @@ public final class ClientRegistration implements Serializable {
this.authorizationGrantType, authorizationGrantType)); this.authorizationGrantType, authorizationGrantType));
} }
} }
if (!AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType)
&& this.clientSettings.isRequireProofKey()) {
this.clientSettings = ClientSettings.builder().requireProofKey(false).build();
logger.warn(LogMessage.format(
"clientSettings.isRequireProofKey=true is only valid with authorizationGrantType=%s. Got authorizationGrantType=%s. Resetting to clientSettings.isRequireProofKey=false",
AuthorizationGrantType.AUTHORIZATION_CODE, this.authorizationGrantType));
}
} }
private void validateScopes() { private void validateScopes() {

View File

@ -35,7 +35,8 @@ import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import static org.assertj.core.api.Assertions.*; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
/** /**
* Tests for {@link ClientRegistration}. * Tests for {@link ClientRegistration}.
@ -680,7 +681,6 @@ public class ClientRegistrationTests {
// should not be null // should not be null
assertThat(clientRegistration.getClientSettings()).isNotNull(); assertThat(clientRegistration.getClientSettings()).isNotNull();
// proof key should be false for passivity
assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isTrue(); assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isTrue();
} }
@ -719,37 +719,9 @@ public class ClientRegistrationTests {
assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isFalse(); assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isFalse();
} }
@Test
void buildWhenNewAuthorizationCodeAndPrivateClientThenPkceEnabledAndExceptionThrown() {
List<ClientAuthenticationMethod> clientAuthenticationMethods = Arrays
.stream(ClientAuthenticationMethod.class.getFields())
.filter((field) -> Modifier.isFinal(field.getModifiers())
&& field.getType() == ClientAuthenticationMethod.class)
.map((field) -> getStaticValue(field, ClientAuthenticationMethod.class))
.filter((authenticationMethod) -> authenticationMethod != ClientAuthenticationMethod.NONE)
.map((authenticationMethod) -> new ClientAuthenticationMethod(authenticationMethod.getValue()))
.toList();
for (ClientAuthenticationMethod clientAuthenticationMethod : clientAuthenticationMethods) {
ClientRegistration.ClientSettings pkceEnabled = ClientRegistration.ClientSettings.builder()
.requireProofKey(true)
.build();
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSettings(pkceEnabled)
.authorizationGrantType(
new AuthorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()))
.clientAuthenticationMethod(clientAuthenticationMethod)
.redirectUri(REDIRECT_URI)
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.build();
assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isTrue();
}
}
@ParameterizedTest @ParameterizedTest
@MethodSource("invalidPkceGrantTypes") @MethodSource("invalidPkceGrantTypes")
void buildWhenInvalidGrantTypeForPkceThenException(AuthorizationGrantType invalidGrantType) { void buildWhenInvalidGrantTypeForPkceThenPkceDisabled(AuthorizationGrantType invalidGrantType) {
ClientRegistration.ClientSettings pkceEnabled = ClientRegistration.ClientSettings.builder() ClientRegistration.ClientSettings pkceEnabled = ClientRegistration.ClientSettings.builder()
.requireProofKey(true) .requireProofKey(true)
.build(); .build();

View File

@ -60,8 +60,6 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
private ClientRegistration pkceClientRegistration; private ClientRegistration pkceClientRegistration;
private ClientRegistration nonProofKeyPublicClientRegistration;
private ClientRegistration fineRedirectUriTemplateRegistration; private ClientRegistration fineRedirectUriTemplateRegistration;
private ClientRegistration publicClientRegistration; private ClientRegistration publicClientRegistration;
@ -80,11 +78,6 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
this.registration2 = TestClientRegistrations.clientRegistration2().build(); this.registration2 = TestClientRegistrations.clientRegistration2().build();
this.pkceClientRegistration = pkceClientRegistration().build(); this.pkceClientRegistration = pkceClientRegistration().build();
this.nonProofKeyPublicClientRegistration = TestClientRegistrations.clientRegistration()
.registrationId("invalid-public-client-registration-id")
.clientAuthenticationMethod(ClientAuthenticationMethod.NONE)
.clientSettings(ClientRegistration.ClientSettings.builder().requireProofKey(false).build())
.build();
this.fineRedirectUriTemplateRegistration = fineRedirectUriTemplateClientRegistration().build(); this.fineRedirectUriTemplateRegistration = fineRedirectUriTemplateClientRegistration().build();
// @formatter:off // @formatter:off
this.publicClientRegistration = TestClientRegistrations.clientRegistration() this.publicClientRegistration = TestClientRegistrations.clientRegistration()
@ -100,7 +93,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
// @formatter:on // @formatter:on
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1,
this.registration2, this.pkceClientRegistration, this.fineRedirectUriTemplateRegistration, this.registration2, this.pkceClientRegistration, this.fineRedirectUriTemplateRegistration,
this.publicClientRegistration, this.oidcRegistration, this.nonProofKeyPublicClientRegistration); this.publicClientRegistration, this.oidcRegistration);
this.resolver = new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, this.resolver = new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository,
this.authorizationRequestBaseUri); this.authorizationRequestBaseUri);
} }
@ -396,33 +389,6 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
// gh-6548 // gh-6548
@Test @Test
public void resolveWhenAuthorizationRequestApplyPkceToConfidentialClientsThenApplied() { public void resolveWhenAuthorizationRequestApplyPkceToConfidentialClientsThenApplied() {
this.resolver.setAuthorizationRequestCustomizer(OAuth2AuthorizationRequestCustomizers.withPkce());
ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = get(requestUri).build();
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertPkceApplied(authorizationRequest, clientRegistration);
clientRegistration = this.registration2;
requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
request = get(requestUri).build();
authorizationRequest = this.resolver.resolve(request);
assertPkceApplied(authorizationRequest, clientRegistration);
}
// gh-6548
@Test
public void resolveWhenAuthorizationRequestApplyPkceToSpecificConfidentialClientThenApplied() {
this.resolver.setAuthorizationRequestCustomizer((builder) -> {
builder.attributes((attrs) -> {
String registrationId = (String) attrs.get(OAuth2ParameterNames.REGISTRATION_ID);
if (this.registration1.getRegistrationId().equals(registrationId)) {
OAuth2AuthorizationRequestCustomizers.withPkce().accept(builder);
}
});
});
ClientRegistration clientRegistration = this.registration1; ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = get(requestUri).build(); MockHttpServletRequest request = get(requestUri).build();
@ -549,6 +515,17 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
+ "&code_challenge=([a-zA-Z0-9\\-\\.\\_\\~]){43}&code_challenge_method=S256&appid=client-id"); + "&code_challenge=([a-zA-Z0-9\\-\\.\\_\\~]){43}&code_challenge_method=S256&appid=client-id");
} }
@Test
public void resolveWhenAuthorizationRequestNoProvideAuthorizationRequestBaseUri() {
OAuth2AuthorizationRequestResolver resolver = new DefaultOAuth2AuthorizationRequestResolver(
this.clientRegistrationRepository);
String requestUri = this.authorizationRequestBaseUri + "/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = get(requestUri).build();
OAuth2AuthorizationRequest authorizationRequest = resolver.resolve(request);
assertThat(authorizationRequest.getRedirectUri())
.isEqualTo("http://localhost/login/oauth2/code/" + this.registration2.getRegistrationId());
}
@Test @Test
public void resolveWhenAuthorizationRequestProvideCodeChallengeMethod() { public void resolveWhenAuthorizationRequestProvideCodeChallengeMethod() {
ClientRegistration clientRegistration = this.pkceClientRegistration; ClientRegistration clientRegistration = this.pkceClientRegistration;

View File

@ -29,7 +29,6 @@ import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestCustomizers;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@ -59,18 +58,11 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
private DefaultServerOAuth2AuthorizationRequestResolver resolver; private DefaultServerOAuth2AuthorizationRequestResolver resolver;
private ClientRegistration nonProofKeyPublicClientRegistration;
private ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); private ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
@BeforeEach @BeforeEach
public void setup() { public void setup() {
this.resolver = new DefaultServerOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository); this.resolver = new DefaultServerOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository);
this.nonProofKeyPublicClientRegistration = TestClientRegistrations.clientRegistration()
.registrationId("invalid-public-client-registration-id")
.clientAuthenticationMethod(ClientAuthenticationMethod.NONE)
.clientSettings(ClientRegistration.ClientSettings.builder().requireProofKey(false).build())
.build();
} }
@Test @Test
@ -143,8 +135,6 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
given(this.clientRegistrationRepository.findByRegistrationId(eq(registration2.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(registration2.getRegistrationId())))
.willReturn(Mono.just(registration2)); .willReturn(Mono.just(registration2));
this.resolver.setAuthorizationRequestCustomizer(OAuth2AuthorizationRequestCustomizers.withPkce());
OAuth2AuthorizationRequest request = resolve("/oauth2/authorization/" + registration1.getRegistrationId()); OAuth2AuthorizationRequest request = resolve("/oauth2/authorization/" + registration1.getRegistrationId());
assertPkceApplied(request, registration1); assertPkceApplied(request, registration1);
@ -152,32 +142,6 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
assertPkceApplied(request, registration2); assertPkceApplied(request, registration2);
} }
// gh-6548
@Test
public void resolveWhenAuthorizationRequestApplyPkceToSpecificConfidentialClientThenApplied() {
ClientRegistration registration1 = TestClientRegistrations.clientRegistration().build();
given(this.clientRegistrationRepository.findByRegistrationId(eq(registration1.getRegistrationId())))
.willReturn(Mono.just(registration1));
given(this.clientRegistrationRepository
.findByRegistrationId(eq(this.nonProofKeyPublicClientRegistration.getRegistrationId())))
.willReturn(Mono.just(this.nonProofKeyPublicClientRegistration));
this.resolver.setAuthorizationRequestCustomizer((builder) -> {
builder.attributes((attrs) -> {
String registrationId = (String) attrs.get(OAuth2ParameterNames.REGISTRATION_ID);
if (registration1.getRegistrationId().equals(registrationId)) {
OAuth2AuthorizationRequestCustomizers.withPkce().accept(builder);
}
});
});
OAuth2AuthorizationRequest request = resolve("/oauth2/authorization/" + registration1.getRegistrationId());
assertPkceApplied(request, registration1);
request = resolve("/oauth2/authorization/" + this.nonProofKeyPublicClientRegistration.getRegistrationId());
assertPkceApplied(request, this.nonProofKeyPublicClientRegistration);
}
@Test @Test
void resolveWhenRequireProofKeyTrueThenPkceEnabled() { void resolveWhenRequireProofKeyTrueThenPkceEnabled() {
ClientRegistration.ClientSettings pkceEnabled = ClientRegistration.ClientSettings.builder() ClientRegistration.ClientSettings pkceEnabled = ClientRegistration.ClientSettings.builder()