PKCE cannot be true and AuthorizationGrantType != AUTHORIZATION_CODE

PKCE is only valid for AuthorizationGrantType.AUTHORIZATION_CODE so the
code should validate this.

Issue gh-16382
This commit is contained in:
Rob Winch 2025-01-17 10:59:48 -06:00
parent ab629cc1ca
commit f9498d3885
No known key found for this signature in database
2 changed files with 68 additions and 0 deletions

View File

@ -711,6 +711,12 @@ public final class ClientRegistration implements Serializable {
"AuthorizationGrantType: %s does not match the pre-defined constant %s and won't match a valid OAuth2AuthorizedClientProvider",
this.authorizationGrantType, authorizationGrantType));
}
if (!AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType)
&& this.clientSettings.isRequireProofKey()) {
throw new IllegalStateException(
"clientSettings.isRequireProofKey=true is only valid with authorizationGrantType=AUTHORIZATION_CODE. Got authorizationGrantType="
+ this.authorizationGrantType);
}
}
}

View File

@ -16,14 +16,20 @@
package org.springframework.security.oauth2.client.registration;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -31,6 +37,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
/**
* Tests for {@link ClientRegistration}.
@ -776,4 +783,59 @@ public class ClientRegistrationTests {
assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isFalse();
}
// gh-16382
@Test
void buildWhenNewAuthorizationCodeAndPkceThenBuilds() {
ClientSettings pkceEnabled = ClientSettings.builder().requireProofKey(true).build();
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSettings(pkceEnabled)
.authorizationGrantType(new AuthorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()))
.redirectUri(REDIRECT_URI)
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.build();
// proof key should be false for passivity
assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isTrue();
}
@ParameterizedTest
@MethodSource("invalidPkceGrantTypes")
void buildWhenInvalidGrantTypeForPkceThenException(AuthorizationGrantType invalidGrantType) {
ClientSettings pkceEnabled = ClientSettings.builder().requireProofKey(true).build();
ClientRegistration.Builder builder = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSettings(pkceEnabled)
.authorizationGrantType(invalidGrantType)
.redirectUri(REDIRECT_URI)
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI);
assertThatIllegalStateException().describedAs(
"clientSettings.isRequireProofKey=true is only valid with authorizationGrantType=AUTHORIZATION_CODE. Got authorizationGrantType={}",
invalidGrantType)
.isThrownBy(builder::build);
}
static List<AuthorizationGrantType> invalidPkceGrantTypes() {
return Arrays.stream(AuthorizationGrantType.class.getFields())
.filter((field) -> Modifier.isFinal(field.getModifiers())
&& field.getType() == AuthorizationGrantType.class)
.map((field) -> getStaticValue(field, AuthorizationGrantType.class))
.filter((grantType) -> grantType != AuthorizationGrantType.AUTHORIZATION_CODE)
// ensure works with .equals
.map((grantType) -> new AuthorizationGrantType(grantType.getValue()))
.collect(Collectors.toList());
}
private static <T> T getStaticValue(Field field, Class<T> clazz) {
try {
return (T) field.get(null);
}
catch (IllegalAccessException ex) {
throw new RuntimeException(ex);
}
}
}