diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index ca7d1ec451..2b51e37864 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.jwt; import java.security.interfaces.RSAPublicKey; import java.time.Instant; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -40,6 +41,7 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor; import com.nimbusds.jwt.proc.JWTProcessor; import reactor.core.publisher.Mono; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; @@ -70,6 +72,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { private final JWKSelectorFactory jwkSelectorFactory; private OAuth2TokenValidator jwtValidator = JwtValidators.createDefault(); + private Converter, Map> claimSetConverter = MappedJwtClaimSetConverter + .withDefaults(Collections.emptyMap()); public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) { JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256); @@ -122,6 +126,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { this.jwtValidator = jwtValidator; } + /** + * Use the following {@link Converter} for manipulating the JWT's claim set + * + * @param claimSetConverter the {@link Converter} to use + */ + public void setClaimSetConverter(Converter, Map> claimSetConverter) { + Assert.notNull(claimSetConverter, "claimSetConverter cannot be null"); + this.claimSetConverter = claimSetConverter; + } + @Override public Mono decode(String token) throws JwtException { JWT jwt = parse(token); @@ -164,21 +178,12 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder { } private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) { - Instant expiresAt = null; - if (jwtClaimsSet.getExpirationTime() != null) { - expiresAt = jwtClaimsSet.getExpirationTime().toInstant(); - } - Instant issuedAt = null; - if (jwtClaimsSet.getIssueTime() != null) { - issuedAt = jwtClaimsSet.getIssueTime().toInstant(); - } else if (expiresAt != null) { - // Default to expiresAt - 1 second - issuedAt = Instant.from(expiresAt).minusSeconds(1); - } - Map headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject()); + Map claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims()); - return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims()); + Instant expiresAt = (Instant) claims.get(JwtClaimNames.EXP); + Instant issuedAt = (Instant) claims.get(JwtClaimNames.IAT); + return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, claims); } private Jwt validateJwt(Jwt jwt) { diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java index a22d013f5b..14df7c87e8 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java @@ -20,8 +20,10 @@ import java.net.UnknownHostException; import java.security.KeyFactory; import java.security.interfaces.RSAPublicKey; import java.security.spec.X509EncodedKeySpec; +import java.time.Instant; import java.util.Base64; -import java.util.Date; +import java.util.Collections; +import java.util.Map; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -29,6 +31,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; @@ -37,6 +40,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** @@ -115,7 +119,7 @@ public class NimbusReactiveJwtDecoderTests { Jwt jwt = this.decoder.decode(withIssuedAt).block(); - assertThat(jwt.getClaims().get(JwtClaimNames.IAT)).isEqualTo(new Date(1529942448000L)); + assertThat(jwt.getClaims().get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1529942448L)); } @Test @@ -177,9 +181,28 @@ public class NimbusReactiveJwtDecoderTests { .hasMessageContaining("mock-description"); } + @Test + public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() { + Converter, Map> claimSetConverter = mock(Converter.class); + this.decoder.setClaimSetConverter(claimSetConverter); + + when(claimSetConverter.convert(any(Map.class))).thenReturn(Collections.singletonMap("custom", "value")); + + Jwt jwt = this.decoder.decode(this.messageReadToken).block(); + assertThat(jwt.getClaims().size()).isEqualTo(1); + assertThat(jwt.getClaims().get("custom")).isEqualTo("value"); + verify(claimSetConverter).convert(any(Map.class)); + } + @Test public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() { assertThatCode(() -> this.decoder.setJwtValidator(null)) .isInstanceOf(IllegalArgumentException.class); } + + @Test + public void setClaimSetConverterWhenNullThrowsIllegalArgumentException() { + assertThatCode(() -> this.decoder.setClaimSetConverter(null)) + .isInstanceOf(IllegalArgumentException.class); + } }