Jwt client authentication converter detects new key
Closes gh-9814
This commit is contained in:
		
							parent
							
								
									700bda68b7
								
							
						
					
					
						commit
						6fbd038111
					
				| 
						 | 
					@ -80,7 +80,7 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	private final Function<ClientRegistration, JWK> jwkResolver;
 | 
						private final Function<ClientRegistration, JWK> jwkResolver;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	private final Map<String, NimbusJwsEncoder> jwsEncoders = new ConcurrentHashMap<>();
 | 
						private final Map<String, JwsEncoderHolder> jwsEncoders = new ConcurrentHashMap<>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	/**
 | 
						/**
 | 
				
			||||||
	 * Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the
 | 
						 * Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the
 | 
				
			||||||
| 
						 | 
					@ -140,12 +140,16 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
 | 
				
			||||||
		JoseHeader joseHeader = headersBuilder.build();
 | 
							JoseHeader joseHeader = headersBuilder.build();
 | 
				
			||||||
		JwtClaimsSet jwtClaimsSet = claimsBuilder.build();
 | 
							JwtClaimsSet jwtClaimsSet = claimsBuilder.build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		NimbusJwsEncoder jwsEncoder = this.jwsEncoders.computeIfAbsent(clientRegistration.getRegistrationId(),
 | 
							JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(),
 | 
				
			||||||
				(clientRegistrationId) -> {
 | 
									(clientRegistrationId, currentJwsEncoderHolder) -> {
 | 
				
			||||||
 | 
										if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals(jwk)) {
 | 
				
			||||||
 | 
											return currentJwsEncoderHolder;
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
					JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
 | 
										JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
 | 
				
			||||||
					return new NimbusJwsEncoder(jwkSource);
 | 
										return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk);
 | 
				
			||||||
				});
 | 
									});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							NimbusJwsEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder();
 | 
				
			||||||
		Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet);
 | 
							Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 | 
							MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 | 
				
			||||||
| 
						 | 
					@ -180,4 +184,25 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
 | 
				
			||||||
		return jwsAlgorithm;
 | 
							return jwsAlgorithm;
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						private static final class JwsEncoderHolder {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							private final NimbusJwsEncoder jwsEncoder;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							private final JWK jwk;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							private JwsEncoderHolder(NimbusJwsEncoder jwsEncoder, JWK jwk) {
 | 
				
			||||||
 | 
								this.jwsEncoder = jwsEncoder;
 | 
				
			||||||
 | 
								this.jwk = jwk;
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							private NimbusJwsEncoder getJwsEncoder() {
 | 
				
			||||||
 | 
								return this.jwsEncoder;
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							private JWK getJwk() {
 | 
				
			||||||
 | 
								return this.jwk;
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,7 +16,12 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
package org.springframework.security.oauth2.client.endpoint;
 | 
					package org.springframework.security.oauth2.client.endpoint;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.security.KeyPair;
 | 
				
			||||||
 | 
					import java.security.KeyPairGenerator;
 | 
				
			||||||
 | 
					import java.security.interfaces.RSAPrivateKey;
 | 
				
			||||||
 | 
					import java.security.interfaces.RSAPublicKey;
 | 
				
			||||||
import java.util.Collections;
 | 
					import java.util.Collections;
 | 
				
			||||||
 | 
					import java.util.UUID;
 | 
				
			||||||
import java.util.function.Function;
 | 
					import java.util.function.Function;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import com.nimbusds.jose.jwk.JWK;
 | 
					import com.nimbusds.jose.jwk.JWK;
 | 
				
			||||||
| 
						 | 
					@ -42,6 +47,7 @@ import static org.assertj.core.api.Assertions.assertThat;
 | 
				
			||||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 | 
					import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 | 
				
			||||||
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 | 
					import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 | 
				
			||||||
import static org.mockito.ArgumentMatchers.any;
 | 
					import static org.mockito.ArgumentMatchers.any;
 | 
				
			||||||
 | 
					import static org.mockito.ArgumentMatchers.eq;
 | 
				
			||||||
import static org.mockito.BDDMockito.given;
 | 
					import static org.mockito.BDDMockito.given;
 | 
				
			||||||
import static org.mockito.Mockito.mock;
 | 
					import static org.mockito.Mockito.mock;
 | 
				
			||||||
import static org.mockito.Mockito.verifyNoInteractions;
 | 
					import static org.mockito.Mockito.verifyNoInteractions;
 | 
				
			||||||
| 
						 | 
					@ -172,4 +178,54 @@ public class NimbusJwtClientAuthenticationParametersConverterTests {
 | 
				
			||||||
		assertThat(jws.getExpiresAt()).isNotNull();
 | 
							assertThat(jws.getExpiresAt()).isNotNull();
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// gh-9814
 | 
				
			||||||
 | 
						@Test
 | 
				
			||||||
 | 
						public void convertWhenClientKeyChangesThenNewKeyUsed() throws Exception {
 | 
				
			||||||
 | 
							// @formatter:off
 | 
				
			||||||
 | 
							ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
 | 
				
			||||||
 | 
									.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
 | 
				
			||||||
 | 
									.build();
 | 
				
			||||||
 | 
							// @formatter:on
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							RSAKey rsaJwk1 = TestJwks.DEFAULT_RSA_JWK;
 | 
				
			||||||
 | 
							given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
 | 
				
			||||||
 | 
									clientRegistration);
 | 
				
			||||||
 | 
							MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
 | 
				
			||||||
 | 
							NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build();
 | 
				
			||||||
 | 
							jwtDecoder.decode(encodedJws);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							RSAKey rsaJwk2 = generateRsaJwk();
 | 
				
			||||||
 | 
							given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							parameters = this.converter.convert(clientCredentialsGrantRequest);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
 | 
				
			||||||
 | 
							jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build();
 | 
				
			||||||
 | 
							jwtDecoder.decode(encodedJws);
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						private static RSAKey generateRsaJwk() {
 | 
				
			||||||
 | 
							KeyPair keyPair;
 | 
				
			||||||
 | 
							try {
 | 
				
			||||||
 | 
								KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
 | 
				
			||||||
 | 
								keyPairGenerator.initialize(2048);
 | 
				
			||||||
 | 
								keyPair = keyPairGenerator.generateKeyPair();
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							catch (Exception ex) {
 | 
				
			||||||
 | 
								throw new IllegalStateException(ex);
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
 | 
				
			||||||
 | 
							RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();
 | 
				
			||||||
 | 
							// @formatter:off
 | 
				
			||||||
 | 
							return new RSAKey.Builder(publicKey)
 | 
				
			||||||
 | 
									.privateKey(privateKey)
 | 
				
			||||||
 | 
									.keyID(UUID.randomUUID().toString())
 | 
				
			||||||
 | 
									.build();
 | 
				
			||||||
 | 
							// @formatter:on
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue