Reactive Jwt Validation
This allows a user to customize the Jwt validation steps that NimbusReactiveJwtDecoder will take for each Jwt. Fixes: gh-5650
This commit is contained in:
parent
53652584b2
commit
01443e35b4
|
@ -40,6 +40,8 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor;
|
||||||
import com.nimbusds.jwt.proc.JWTProcessor;
|
import com.nimbusds.jwt.proc.JWTProcessor;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
|
||||||
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
|
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
|
|
||||||
|
@ -67,6 +69,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
||||||
|
|
||||||
private final JWKSelectorFactory jwkSelectorFactory;
|
private final JWKSelectorFactory jwkSelectorFactory;
|
||||||
|
|
||||||
|
private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
|
||||||
|
|
||||||
public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
|
public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
|
||||||
JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
|
JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
|
||||||
|
|
||||||
|
@ -77,6 +81,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
||||||
new JWSVerificationKeySelector<>(algorithm, jwkSource);
|
new JWSVerificationKeySelector<>(algorithm, jwkSource);
|
||||||
DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>();
|
DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>();
|
||||||
jwtProcessor.setJWSKeySelector(jwsKeySelector);
|
jwtProcessor.setJWSKeySelector(jwsKeySelector);
|
||||||
|
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
|
||||||
|
|
||||||
this.jwtProcessor = jwtProcessor;
|
this.jwtProcessor = jwtProcessor;
|
||||||
this.reactiveJwkSource = new ReactiveJWKSourceAdapter(jwkSource);
|
this.reactiveJwkSource = new ReactiveJWKSourceAdapter(jwkSource);
|
||||||
|
@ -98,6 +103,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
||||||
|
|
||||||
DefaultJWTProcessor<JWKContext> jwtProcessor = new DefaultJWTProcessor<>();
|
DefaultJWTProcessor<JWKContext> jwtProcessor = new DefaultJWTProcessor<>();
|
||||||
jwtProcessor.setJWSKeySelector(jwsKeySelector);
|
jwtProcessor.setJWSKeySelector(jwsKeySelector);
|
||||||
|
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
|
||||||
this.jwtProcessor = jwtProcessor;
|
this.jwtProcessor = jwtProcessor;
|
||||||
|
|
||||||
this.reactiveJwkSource = new ReactiveRemoteJWKSource(jwkSetUrl);
|
this.reactiveJwkSource = new ReactiveRemoteJWKSource(jwkSetUrl);
|
||||||
|
@ -106,6 +112,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use the provided {@link OAuth2TokenValidator} to validate incoming {@link Jwt}s.
|
||||||
|
*
|
||||||
|
* @param jwtValidator the {@link OAuth2TokenValidator} to use
|
||||||
|
*/
|
||||||
|
public void setJwtValidator(OAuth2TokenValidator<Jwt> jwtValidator) {
|
||||||
|
Assert.notNull(jwtValidator, "jwtValidator cannot be null");
|
||||||
|
this.jwtValidator = jwtValidator;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Mono<Jwt> decode(String token) throws JwtException {
|
public Mono<Jwt> decode(String token) throws JwtException {
|
||||||
JWT jwt = parse(token);
|
JWT jwt = parse(token);
|
||||||
|
@ -131,7 +147,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
||||||
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
|
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
|
||||||
.map(jwkList -> createClaimsSet(parsedToken, jwkList))
|
.map(jwkList -> createClaimsSet(parsedToken, jwkList))
|
||||||
.map(set -> createJwt(parsedToken, set))
|
.map(set -> createJwt(parsedToken, set))
|
||||||
.onErrorMap(e -> !(e instanceof IllegalStateException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
|
.map(this::validateJwt)
|
||||||
|
.onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
|
||||||
} catch (RuntimeException ex) {
|
} catch (RuntimeException ex) {
|
||||||
throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex);
|
throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex);
|
||||||
}
|
}
|
||||||
|
@ -164,6 +181,17 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
||||||
return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
|
return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Jwt validateJwt(Jwt jwt) {
|
||||||
|
OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt);
|
||||||
|
|
||||||
|
if ( result.hasErrors() ) {
|
||||||
|
String message = result.getErrors().iterator().next().getDescription();
|
||||||
|
throw new JwtValidationException(message, result.getErrors());
|
||||||
|
}
|
||||||
|
|
||||||
|
return jwt;
|
||||||
|
}
|
||||||
|
|
||||||
private static RSAKey rsaKey(RSAPublicKey publicKey) {
|
private static RSAKey rsaKey(RSAPublicKey publicKey) {
|
||||||
return new RSAKey.Builder(publicKey)
|
return new RSAKey.Builder(publicKey)
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -16,12 +16,6 @@
|
||||||
|
|
||||||
package org.springframework.security.oauth2.jwt;
|
package org.springframework.security.oauth2.jwt;
|
||||||
|
|
||||||
import okhttp3.mockwebserver.MockResponse;
|
|
||||||
import okhttp3.mockwebserver.MockWebServer;
|
|
||||||
import org.junit.After;
|
|
||||||
import org.junit.Before;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
import java.net.UnknownHostException;
|
import java.net.UnknownHostException;
|
||||||
import java.security.KeyFactory;
|
import java.security.KeyFactory;
|
||||||
import java.security.interfaces.RSAPublicKey;
|
import java.security.interfaces.RSAPublicKey;
|
||||||
|
@ -29,8 +23,21 @@ import java.security.spec.X509EncodedKeySpec;
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
|
|
||||||
|
import okhttp3.mockwebserver.MockResponse;
|
||||||
|
import okhttp3.mockwebserver.MockWebServer;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
|
||||||
|
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author Rob Winch
|
* @author Rob Winch
|
||||||
|
@ -114,7 +121,7 @@ public class NimbusReactiveJwtDecoderTests {
|
||||||
@Test
|
@Test
|
||||||
public void decodeWhenExpiredThenFail() {
|
public void decodeWhenExpiredThenFail() {
|
||||||
assertThatCode(() -> this.decoder.decode(this.expired).block())
|
assertThatCode(() -> this.decoder.decode(this.expired).block())
|
||||||
.isInstanceOf(JwtException.class);
|
.isInstanceOf(JwtValidationException.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -155,4 +162,24 @@ public class NimbusReactiveJwtDecoderTests {
|
||||||
.isInstanceOf(JwtException.class)
|
.isInstanceOf(JwtException.class)
|
||||||
.hasMessage("Unsupported algorithm of none");
|
.hasMessage("Unsupported algorithm of none");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void decodeWhenUsingCustomValidatorThenValidatorIsInvoked() {
|
||||||
|
OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class);
|
||||||
|
this.decoder.setJwtValidator(jwtValidator);
|
||||||
|
|
||||||
|
OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri");
|
||||||
|
OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(error);
|
||||||
|
when(jwtValidator.validate(any(Jwt.class))).thenReturn(result);
|
||||||
|
|
||||||
|
assertThatCode(() -> this.decoder.decode(messageReadToken).block())
|
||||||
|
.isInstanceOf(JwtException.class)
|
||||||
|
.hasMessageContaining("mock-description");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() {
|
||||||
|
assertThatCode(() -> this.decoder.setJwtValidator(null))
|
||||||
|
.isInstanceOf(IllegalArgumentException.class);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue