diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java index 96df899fda..dff7d9b95a 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java @@ -15,13 +15,6 @@ */ package org.springframework.security.oauth2.jwt; -import java.net.MalformedURLException; -import java.net.URL; -import java.text.ParseException; -import java.time.Instant; -import java.util.LinkedHashMap; -import java.util.Map; - import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.jwk.source.JWKSource; @@ -29,7 +22,7 @@ import com.nimbusds.jose.jwk.source.RemoteJWKSet; import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; -import com.nimbusds.jose.util.DefaultResourceRetriever; +import com.nimbusds.jose.util.Resource; import com.nimbusds.jose.util.ResourceRetriever; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; @@ -37,12 +30,27 @@ import com.nimbusds.jwt.JWTParser; import com.nimbusds.jwt.SignedJWT; import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; import com.nimbusds.jwt.proc.DefaultJWTProcessor; - +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.util.Assert; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.text.ParseException; +import java.time.Instant; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; /** - * An implementation of a {@link JwtDecoder} that "decodes" a + * An implementation of a {@link JwtDecoder} that "decodes" a * JSON Web Token (JWT) and additionally verifies it's digital signature if the JWT is a * JSON Web Signature (JWS). The public key used for verification is obtained from the * JSON Web Key (JWK) Set {@code URL} supplied via the constructor. @@ -63,9 +71,9 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder { private static final String DECODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to decode the Jwt: %s"; - private final URL jwkSetUrl; private final JWSAlgorithm jwsAlgorithm; private final ConfigurableJWTProcessor jwtProcessor; + private final RestOperationsResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(); /** * Constructs a {@code NimbusJwtDecoderJwkSupport} using the provided parameters. @@ -85,18 +93,15 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder { public NimbusJwtDecoderJwkSupport(String jwkSetUrl, String jwsAlgorithm) { Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty"); Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty"); + JWKSource jwkSource; try { - this.jwkSetUrl = new URL(jwkSetUrl); + jwkSource = new RemoteJWKSet(new URL(jwkSetUrl), this.jwkSetRetriever); } catch (MalformedURLException ex) { - throw new IllegalArgumentException("Invalid JWK Set URL " + jwkSetUrl + " : " + ex.getMessage(), ex); + throw new IllegalArgumentException("Invalid JWK Set URL \"" + jwkSetUrl + "\" : " + ex.getMessage(), ex); } this.jwsAlgorithm = JWSAlgorithm.parse(jwsAlgorithm); - - ResourceRetriever jwkSetRetriever = new DefaultResourceRetriever(30000, 30000); - JWKSource jwkSource = new RemoteJWKSet(this.jwkSetUrl, jwkSetRetriever); JWSKeySelector jwsKeySelector = new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource); - this.jwtProcessor = new DefaultJWTProcessor<>(); this.jwtProcessor.setJWSKeySelector(jwsKeySelector); } @@ -104,10 +109,9 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder { @Override public Jwt decode(String token) throws JwtException { JWT jwt = this.parse(token); - if ( jwt instanceof SignedJWT ) { + if (jwt instanceof SignedJWT) { return this.createJwt(token, jwt); } - throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); } @@ -158,4 +162,39 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder { return jwt; } + + /** + * Sets the {@link RestOperations} used when requesting the JSON Web Key (JWK) Set. + * + * @since 5.1 + * @param restOperations the {@link RestOperations} used when requesting the JSON Web Key (JWK) Set + */ + public final void setRestOperations(RestOperations restOperations) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.jwkSetRetriever.restOperations = restOperations; + } + + private static class RestOperationsResourceRetriever implements ResourceRetriever { + private RestOperations restOperations = new RestTemplate(); + + @Override + public Resource retrieveResource(URL url) throws IOException { + HttpHeaders headers = new HttpHeaders(); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8)); + + ResponseEntity response; + try { + RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI()); + response = this.restOperations.exchange(request, String.class); + } catch (Exception ex) { + throw new IOException(ex); + } + + if (response.getStatusCodeValue() != 200) { + throw new IOException(response.toString()); + } + + return new Resource(response.getBody(), "UTF-8"); + } + } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java index 675445f634..ecd648f513 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,23 +24,22 @@ import com.nimbusds.jwt.SignedJWT; import com.nimbusds.jwt.proc.DefaultJWTProcessor; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; +import org.assertj.core.api.Assertions; import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; - +import org.springframework.http.RequestEntity; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; +import org.springframework.web.client.RestTemplate; import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.mock; -import static org.powermock.api.mockito.PowerMockito.mockStatic; -import static org.powermock.api.mockito.PowerMockito.when; -import static org.powermock.api.mockito.PowerMockito.whenNew; +import static org.mockito.Mockito.verify; +import static org.powermock.api.mockito.PowerMockito.*; /** * Tests for {@link NimbusJwtDecoderJwkSupport}. @@ -62,6 +61,8 @@ public class NimbusJwtDecoderJwkSupportTests { private static final String MALFORMED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJuYmYiOnt9LCJleHAiOjQ2ODQyMjUwODd9.guoQvujdWvd3xw7FYQEn4D6-gzM_WqFvXdmvAUNSLbxG7fv2_LLCNujPdrBHJoYPbOwS1BGNxIKQWS1tylvqzmr1RohQ-RZ2iAM1HYQzboUlkoMkcd8ENM__ELqho8aNYBfqwkNdUOyBFoy7Syu_w2SoJADw2RTjnesKO6CVVa05bW118pDS4xWxqC4s7fnBjmZoTn4uQ-Kt9YSQZQk8YQxkJSiyanozzgyfgXULA6mPu1pTNU3FVFaK1i1av_xtH_zAPgb647ZeaNe4nahgqC5h8nhOlm8W2dndXbwAt29nd2ZWBsru_QwZz83XSKLhTPFz-mPBByZZDsyBbIHf9A"; private static final String UNSIGNED_JWT = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9."; + private NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM); + @Test public void constructorWhenJwkSetUrlIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport(null)) @@ -80,10 +81,15 @@ public class NimbusJwtDecoderJwkSupportTests { .isInstanceOf(IllegalArgumentException.class); } + @Test + public void setRestOperationsWhenNullThenThrowIllegalArgumentException() { + Assertions.assertThatThrownBy(() -> this.jwtDecoder.setRestOperations(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void decodeWhenJwtInvalidThenThrowJwtException() { - NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM); - assertThatThrownBy(() -> jwtDecoder.decode("invalid")) + assertThatThrownBy(() -> this.jwtDecoder.decode("invalid")) .isInstanceOf(JwtException.class); } @@ -103,16 +109,14 @@ public class NimbusJwtDecoderJwkSupportTests { JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().audience("resource1").build(); when(jwtProcessor.process(any(JWT.class), eq(null))).thenReturn(jwtClaimsSet); - NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM); + NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL); assertThatCode(() -> jwtDecoder.decode("encoded-jwt")).doesNotThrowAnyException(); } // gh-5457 @Test - public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() throws Exception { - NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM); - - assertThatCode(() -> jwtDecoder.decode(UNSIGNED_JWT)) + public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() { + assertThatCode(() -> this.jwtDecoder.decode(UNSIGNED_JWT)) .isInstanceOf(JwtException.class) .hasMessageContaining("Unsupported algorithm of none"); } @@ -122,12 +126,11 @@ public class NimbusJwtDecoderJwkSupportTests { try ( MockWebServer server = new MockWebServer() ) { server.enqueue(new MockResponse().setBody(JWK_SET)); String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); - - NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); - - assertThatCode(() -> decoder.decode(MALFORMED_JWT)) + NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); + assertThatCode(() -> jwtDecoder.decode(MALFORMED_JWT)) .isInstanceOf(JwtException.class) .hasMessage("An error occurred while attempting to decode the Jwt: Malformed payload"); + server.shutdown(); } } @@ -136,28 +139,39 @@ public class NimbusJwtDecoderJwkSupportTests { try ( MockWebServer server = new MockWebServer() ) { server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET)); String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); - - NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); - - assertThatCode(() -> decoder.decode(SIGNED_JWT)) + NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); + assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) .isInstanceOf(JwtException.class) .hasMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set"); + server.shutdown(); } } @Test - public void decodeWhenJwkEndpointIsUnresponsiveThenRetrunsJwtException() throws Exception { + public void decodeWhenJwkEndpointIsUnresponsiveThenReturnsJwtException() throws Exception { try ( MockWebServer server = new MockWebServer() ) { server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET)); String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); - - NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); - - server.shutdown(); - - assertThatCode(() -> decoder.decode(SIGNED_JWT)) + NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); + assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)) .isInstanceOf(JwtException.class) .hasMessageContaining("An error occurred while attempting to decode the Jwt"); + server.shutdown(); + } + } + + // gh-5603 + @Test + public void decodeWhenCustomRestOperationsSetThenUsed() throws Exception { + try ( MockWebServer server = new MockWebServer() ) { + server.enqueue(new MockResponse().setBody(JWK_SET)); + String jwkSetUrl = server.url("/.well-known/jwks.json").toString(); + NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl); + RestTemplate restTemplate = spy(new RestTemplate()); + jwtDecoder.setRestOperations(restTemplate); + assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)).doesNotThrowAnyException(); + verify(restTemplate).exchange(any(RequestEntity.class), eq(String.class)); + server.shutdown(); } } }