Support custom token validators for OAuth2

See gh-35874
This commit is contained in:
Roman Golovin 2023-06-13 20:35:46 +03:00 committed by Andy Wilkinson
parent ce8253ea95
commit 7500dab321
4 changed files with 146 additions and 30 deletions

View File

@ -61,6 +61,7 @@ import org.springframework.util.CollectionUtils;
* @author HaiTao Zhang
* @author Anastasiia Losieva
* @author Mushtaq Ahmed
* @author Roman Golovin
*/
@Configuration(proxyBeanMethods = false)
class ReactiveOAuth2ResourceServerJwkConfiguration {
@ -71,8 +72,12 @@ class ReactiveOAuth2ResourceServerJwkConfiguration {
private final OAuth2ResourceServerProperties.Jwt properties;
JwtConfiguration(OAuth2ResourceServerProperties properties) {
private final List<OAuth2TokenValidator<Jwt>> customOAuth2TokenValidators;
JwtConfiguration(OAuth2ResourceServerProperties properties,
List<OAuth2TokenValidator<Jwt>> customOAuth2TokenValidators) {
this.properties = properties.getJwt();
this.customOAuth2TokenValidators = customOAuth2TokenValidators;
}
@Bean
@ -97,14 +102,17 @@ class ReactiveOAuth2ResourceServerJwkConfiguration {
}
private OAuth2TokenValidator<Jwt> getValidators(OAuth2TokenValidator<Jwt> defaultValidator) {
List<String> audiences = this.properties.getAudiences();
if (CollectionUtils.isEmpty(audiences)) {
return defaultValidator;
}
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
validators.add(defaultValidator);
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
(aud) -> aud != null && !Collections.disjoint(aud, audiences)));
validators.addAll(this.customOAuth2TokenValidators);
List<String> audiences = this.properties.getAudiences();
if (!CollectionUtils.isEmpty(audiences)) {
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
(aud) -> aud != null && !Collections.disjoint(aud, audiences)));
}
if (validators.size() == 1) {
return validators.get(0);
}
return new DelegatingOAuth2TokenValidator<>(validators);
}

View File

@ -62,6 +62,7 @@ import static org.springframework.security.config.Customizer.withDefaults;
* @author Artsiom Yudovin
* @author HaiTao Zhang
* @author Mushtaq Ahmed
* @author Roman Golovin
*/
@Configuration(proxyBeanMethods = false)
class OAuth2ResourceServerJwtConfiguration {
@ -72,8 +73,12 @@ class OAuth2ResourceServerJwtConfiguration {
private final OAuth2ResourceServerProperties.Jwt properties;
JwtDecoderConfiguration(OAuth2ResourceServerProperties properties) {
private final List<OAuth2TokenValidator<Jwt>> customOAuth2TokenValidators;
JwtDecoderConfiguration(OAuth2ResourceServerProperties properties,
List<OAuth2TokenValidator<Jwt>> customOAuth2TokenValidators) {
this.properties = properties.getJwt();
this.customOAuth2TokenValidators = customOAuth2TokenValidators;
}
@Bean
@ -97,14 +102,17 @@ class OAuth2ResourceServerJwtConfiguration {
}
private OAuth2TokenValidator<Jwt> getValidators(OAuth2TokenValidator<Jwt> defaultValidator) {
List<String> audiences = this.properties.getAudiences();
if (CollectionUtils.isEmpty(audiences)) {
return defaultValidator;
}
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
validators.add(defaultValidator);
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
(aud) -> aud != null && !Collections.disjoint(aud, audiences)));
validators.addAll(this.customOAuth2TokenValidators);
List<String> audiences = this.properties.getAudiences();
if (!CollectionUtils.isEmpty(audiences)) {
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
(aud) -> aud != null && !Collections.disjoint(aud, audiences)));
}
if (validators.size() == 1) {
return validators.get(0);
}
return new DelegatingOAuth2TokenValidator<>(validators);
}

View File

@ -24,7 +24,9 @@ import java.time.Instant;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import com.fasterxml.jackson.core.JsonProcessingException;
@ -87,6 +89,7 @@ import static org.springframework.security.config.Customizer.withDefaults;
* @author HaiTao Zhang
* @author Anastasiia Losieva
* @author Mushtaq Ahmed
* @author Roman Golovin
*/
class ReactiveOAuth2ResourceServerAutoConfigurationTests {
@ -502,36 +505,73 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
validate(issuerUri, reactiveJwtDecoder);
validate(issuerUri, reactiveJwtDecoder, null);
});
}
@SuppressWarnings("unchecked")
private void validate(String issuerUri, ReactiveJwtDecoder jwtDecoder) throws MalformedURLException {
@Test
void autoConfigurationShouldConfigureAudienceAndCustomValidatorsIfPropertyProvidedAndIssuerUri() throws Exception {
this.server = new MockWebServer();
this.server.start();
String path = "test";
String issuer = this.server.url(path).toString();
String cleanIssuerPath = cleanIssuerPath(issuer);
setupMockResponse(cleanIssuerPath);
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
this.contextRunner.withPropertyValues(
"spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
"spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri,
"spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com")
.withUserConfiguration(CustomTokenValidatorsConfig.class)
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
assertThat(context).hasBean("customJwtClaimValidator");
OAuth2TokenValidator<Jwt> customValidator = (OAuth2TokenValidator<Jwt>) context
.getBean("customJwtClaimValidator");
validate(issuerUri, reactiveJwtDecoder, customValidator);
});
}
@SuppressWarnings("unchecked")
private void validate(String issuerUri, ReactiveJwtDecoder jwtDecoder, OAuth2TokenValidator<Jwt> customValidator)
throws MalformedURLException {
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(jwtDecoder, "jwtValidator");
Jwt.Builder builder = jwt().claim("aud", Collections.singletonList("https://test-audience.com"));
if (issuerUri != null) {
builder.claim("iss", new URL(issuerUri));
}
if (customValidator != null) {
builder.claim("custom_claim", "custom_claim_value");
}
Jwt jwt = builder.build();
assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse();
Collection<OAuth2TokenValidator<Jwt>> delegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(jwtValidator, "tokenValidators");
validateDelegates(issuerUri, delegates);
validateDelegates(issuerUri, delegates, customValidator);
}
@SuppressWarnings("unchecked")
private void validateDelegates(String issuerUri, Collection<OAuth2TokenValidator<Jwt>> delegates) {
private void validateDelegates(String issuerUri, Collection<OAuth2TokenValidator<Jwt>> delegates,
OAuth2TokenValidator<Jwt> customValidator) {
assertThat(delegates).hasAtLeastOneElementOfType(JwtClaimValidator.class);
OAuth2TokenValidator<Jwt> delegatingValidator = delegates.stream()
.filter((v) -> v instanceof DelegatingOAuth2TokenValidator)
.findFirst()
.get();
Collection<OAuth2TokenValidator<Jwt>> nestedDelegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(delegatingValidator, "tokenValidators");
if (issuerUri != null) {
assertThat(nestedDelegates).hasAtLeastOneElementOfType(JwtIssuerValidator.class);
assertThat(delegatingValidator).extracting("tokenValidators")
.asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class))
.hasAtLeastOneElementOfType(JwtIssuerValidator.class);
}
List<OAuth2TokenValidator<Jwt>> claimValidators = delegates.stream()
.filter((d) -> d instanceof JwtClaimValidator<?>)
.collect(Collectors.toList());
assertThat(claimValidators).anyMatch((v) -> "aud".equals(ReflectionTestUtils.getField(v, "claim")));
if (customValidator != null) {
assertThat(claimValidators)
.anyMatch((v) -> "custom_claim".equals(ReflectionTestUtils.getField(v, "claim")));
}
}
@ -552,7 +592,7 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
Mono<ReactiveJwtDecoder> jwtDecoderSupplier = (Mono<ReactiveJwtDecoder>) ReflectionTestUtils
.getField(supplierJwtDecoderBean, "jwtDecoderMono");
ReactiveJwtDecoder jwtDecoder = jwtDecoderSupplier.block();
validate(issuerUri, jwtDecoder);
validate(issuerUri, jwtDecoder, null);
});
}
@ -570,7 +610,7 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
.run((context) -> {
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class);
validate(null, jwtDecoder);
validate(null, jwtDecoder, null);
});
}
@ -740,4 +780,14 @@ class ReactiveOAuth2ResourceServerAutoConfigurationTests {
}
@Configuration(proxyBeanMethods = false)
static class CustomTokenValidatorsConfig {
@Bean
JwtClaimValidator<String> customJwtClaimValidator() {
return new JwtClaimValidator<>("custom_claim", "custom_claim_value"::equals);
}
}
}

View File

@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
@ -80,6 +81,7 @@ import static org.mockito.Mockito.mock;
* @author Artsiom Yudovin
* @author HaiTao Zhang
* @author Mushtaq Ahmed
* @author Roman Golovin
*/
class OAuth2ResourceServerAutoConfigurationTests {
@ -515,7 +517,7 @@ class OAuth2ResourceServerAutoConfigurationTests {
.run((context) -> {
assertThat(context).hasSingleBean(JwtDecoder.class);
JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
validate(issuerUri, jwtDecoder);
validate(issuerUri, jwtDecoder, null);
});
}
@ -536,26 +538,56 @@ class OAuth2ResourceServerAutoConfigurationTests {
Supplier<JwtDecoder> jwtDecoderSupplier = (Supplier<JwtDecoder>) ReflectionTestUtils
.getField(supplierJwtDecoderBean, "delegate");
JwtDecoder jwtDecoder = jwtDecoderSupplier.get();
validate(issuerUri, jwtDecoder);
validate(issuerUri, jwtDecoder, null);
});
}
@SuppressWarnings("unchecked")
private void validate(String issuerUri, JwtDecoder jwtDecoder) throws MalformedURLException {
@Test
void autoConfigurationShouldConfigureAudienceAndCustomValidatorsIfPropertyProvidedAndIssuerUri() throws Exception {
this.server = new MockWebServer();
this.server.start();
String path = "test";
String issuer = this.server.url(path).toString();
String cleanIssuerPath = cleanIssuerPath(issuer);
setupMockResponse(cleanIssuerPath);
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
this.contextRunner.withPropertyValues("spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri,
"spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com")
.withUserConfiguration(CustomTokenValidatorsConfig.class)
.run((context) -> {
SupplierJwtDecoder supplierJwtDecoderBean = context.getBean(SupplierJwtDecoder.class);
Supplier<JwtDecoder> jwtDecoderSupplier = (Supplier<JwtDecoder>) ReflectionTestUtils
.getField(supplierJwtDecoderBean, "delegate");
JwtDecoder jwtDecoder = jwtDecoderSupplier.get();
assertThat(context).hasBean("customJwtClaimValidator");
OAuth2TokenValidator<Jwt> customValidator = (OAuth2TokenValidator<Jwt>) context
.getBean("customJwtClaimValidator");
validate(issuerUri, jwtDecoder, customValidator);
});
}
@SuppressWarnings("unchecked")
private void validate(String issuerUri, JwtDecoder jwtDecoder, OAuth2TokenValidator<Jwt> customValidator)
throws MalformedURLException {
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
.getField(jwtDecoder, "jwtValidator");
Jwt.Builder builder = jwt().claim("aud", Collections.singletonList("https://test-audience.com"));
if (issuerUri != null) {
builder.claim("iss", new URL(issuerUri));
}
if (customValidator != null) {
builder.claim("custom_claim", "custom_claim_value");
}
Jwt jwt = builder.build();
assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse();
Collection<OAuth2TokenValidator<Jwt>> delegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
.getField(jwtValidator, "tokenValidators");
validateDelegates(issuerUri, delegates);
validateDelegates(issuerUri, delegates, customValidator);
}
private void validateDelegates(String issuerUri, Collection<OAuth2TokenValidator<Jwt>> delegates) {
private void validateDelegates(String issuerUri, Collection<OAuth2TokenValidator<Jwt>> delegates,
OAuth2TokenValidator<Jwt> customValidator) {
assertThat(delegates).hasAtLeastOneElementOfType(JwtClaimValidator.class);
OAuth2TokenValidator<Jwt> delegatingValidator = delegates.stream()
.filter((v) -> v instanceof DelegatingOAuth2TokenValidator)
@ -566,6 +598,14 @@ class OAuth2ResourceServerAutoConfigurationTests {
.asInstanceOf(InstanceOfAssertFactories.collection(OAuth2TokenValidator.class))
.hasAtLeastOneElementOfType(JwtIssuerValidator.class);
}
List<OAuth2TokenValidator<Jwt>> claimValidators = delegates.stream()
.filter((d) -> d instanceof JwtClaimValidator<?>)
.collect(Collectors.toList());
assertThat(claimValidators).anyMatch((v) -> "aud".equals(ReflectionTestUtils.getField(v, "claim")));
if (customValidator != null) {
assertThat(claimValidators)
.anyMatch((v) -> "custom_claim".equals(ReflectionTestUtils.getField(v, "claim")));
}
}
@Test
@ -582,7 +622,7 @@ class OAuth2ResourceServerAutoConfigurationTests {
.run((context) -> {
assertThat(context).hasSingleBean(JwtDecoder.class);
JwtDecoder jwtDecoder = context.getBean(JwtDecoder.class);
validate(null, jwtDecoder);
validate(null, jwtDecoder, null);
});
}
@ -745,4 +785,14 @@ class OAuth2ResourceServerAutoConfigurationTests {
}
@Configuration(proxyBeanMethods = false)
static class CustomTokenValidatorsConfig {
@Bean
JwtClaimValidator<String> customJwtClaimValidator() {
return new JwtClaimValidator<>("custom_claim", "custom_claim_value"::equals);
}
}
}