From 525f40490ceb4e268b8a3b929e7a059c5ae8ee8b Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 7 Jan 2022 13:23:02 -0500 Subject: [PATCH] Allow Jwt assertion to be resolved Closes gh-9812 --- .../oauth2/client/authorization-grants.adoc | 6 ++++ .../oauth2/client/authorization-grants.adoc | 6 ++++ ...tBearerOAuth2AuthorizedClientProvider.java | 27 ++++++++++++-- ...eactiveOAuth2AuthorizedClientProvider.java | 36 +++++++++++++++---- ...erOAuth2AuthorizedClientProviderTests.java | 36 +++++++++++++++++-- ...veOAuth2AuthorizedClientProviderTests.java | 35 ++++++++++++++++-- 6 files changed, 131 insertions(+), 15 deletions(-) diff --git a/docs/modules/ROOT/pages/reactive/oauth2/client/authorization-grants.adoc b/docs/modules/ROOT/pages/reactive/oauth2/client/authorization-grants.adoc index 11fe4d541b..3300572114 100644 --- a/docs/modules/ROOT/pages/reactive/oauth2/client/authorization-grants.adoc +++ b/docs/modules/ROOT/pages/reactive/oauth2/client/authorization-grants.adoc @@ -1098,3 +1098,9 @@ class OAuth2ResourceServerController { } ---- ==== + +[NOTE] +`JwtBearerReactiveOAuth2AuthorizedClientProvider` resolves the `Jwt` assertion via `OAuth2AuthorizationContext.getPrincipal().getPrincipal()` by default, hence the use of `JwtAuthenticationToken` in the preceding example. + +[TIP] +If you need to resolve the `Jwt` assertion from a different source, you can provide `JwtBearerReactiveOAuth2AuthorizedClientProvider.setJwtAssertionResolver()` with a custom `Function>`. diff --git a/docs/modules/ROOT/pages/servlet/oauth2/client/authorization-grants.adoc b/docs/modules/ROOT/pages/servlet/oauth2/client/authorization-grants.adoc index 6f42e92e04..d9317ec075 100644 --- a/docs/modules/ROOT/pages/servlet/oauth2/client/authorization-grants.adoc +++ b/docs/modules/ROOT/pages/servlet/oauth2/client/authorization-grants.adoc @@ -1352,3 +1352,9 @@ class OAuth2ResourceServerController { } ---- ==== + +[NOTE] +`JwtBearerOAuth2AuthorizedClientProvider` resolves the `Jwt` assertion via `OAuth2AuthorizationContext.getPrincipal().getPrincipal()` by default, hence the use of `JwtAuthenticationToken` in the preceding example. + +[TIP] +If you need to resolve the `Jwt` assertion from a different source, you can provide `JwtBearerOAuth2AuthorizedClientProvider.setJwtAssertionResolver()` with a custom `Function`. diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProvider.java index 87accc63a3..857f38af0b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.function.Function; import org.springframework.lang.Nullable; import org.springframework.security.oauth2.client.endpoint.DefaultJwtBearerTokenResponseClient; @@ -45,6 +46,8 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultJwtBearerTokenResponseClient(); + private Function jwtAssertionResolver = this::resolveJwtAssertion; + private Duration clockSkew = Duration.ofSeconds(60); private Clock clock = Clock.systemUTC(); @@ -75,10 +78,10 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth // need for re-authorization return null; } - if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) { + Jwt jwt = this.jwtAssertionResolver.apply(context); + if (jwt == null) { return null; } - Jwt jwt = (Jwt) context.getPrincipal().getPrincipal(); // As per spec, in section 4.1 Using Assertions as Authorization Grants // https://tools.ietf.org/html/rfc7521#section-4.1 // @@ -97,6 +100,13 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth tokenResponse.getAccessToken()); } + private Jwt resolveJwtAssertion(OAuth2AuthorizationContext context) { + if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) { + return null; + } + return (Jwt) context.getPrincipal().getPrincipal(); + } + private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration, JwtBearerGrantRequest jwtBearerGrantRequest) { try { @@ -123,6 +133,17 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth this.accessTokenResponseClient = accessTokenResponseClient; } + /** + * Sets the resolver used for resolving the {@link Jwt} assertion. + * @param jwtAssertionResolver the resolver used for resolving the {@link Jwt} + * assertion + * @since 5.7 + */ + public void setJwtAssertionResolver(Function jwtAssertionResolver) { + Assert.notNull(jwtAssertionResolver, "jwtAssertionResolver cannot be null"); + this.jwtAssertionResolver = jwtAssertionResolver; + } + /** * Sets the maximum acceptable clock skew, which is used when checking the * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProvider.java index eb60c3c4bb..a15da34b3c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.function.Function; import reactor.core.publisher.Mono; @@ -45,6 +46,8 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = new WebClientReactiveJwtBearerTokenResponseClient(); + private Function> jwtAssertionResolver = this::resolveJwtAssertion; + private Duration clockSkew = Duration.ofSeconds(60); private Clock clock = Clock.systemUTC(); @@ -74,10 +77,7 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re // need for re-authorization return Mono.empty(); } - if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) { - return Mono.empty(); - } - Jwt jwt = (Jwt) context.getPrincipal().getPrincipal(); + // As per spec, in section 4.1 Using Assertions as Authorization Grants // https://tools.ietf.org/html/rfc7521#section-4.1 // @@ -90,13 +90,26 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re // issued with a reasonably short lifetime. Clients can refresh an // expired access token by requesting a new one using the same // assertion, if it is still valid, or with a new assertion. - return Mono.just(new JwtBearerGrantRequest(clientRegistration, jwt)) + + // @formatter:off + return this.jwtAssertionResolver.apply(context) + .map((jwt) -> new JwtBearerGrantRequest(clientRegistration, jwt)) .flatMap(this.accessTokenResponseClient::getTokenResponse) .onErrorMap(OAuth2AuthorizationException.class, (ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex)) .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken())); + // @formatter:on + } + + private Mono resolveJwtAssertion(OAuth2AuthorizationContext context) { + // @formatter:off + return Mono.just(context) + .map((ctx) -> ctx.getPrincipal().getPrincipal()) + .filter((principal) -> principal instanceof Jwt) + .cast(Jwt.class); + // @formatter:on } private boolean hasTokenExpired(OAuth2Token token) { @@ -115,6 +128,17 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re this.accessTokenResponseClient = accessTokenResponseClient; } + /** + * Sets the resolver used for resolving the {@link Jwt} assertion. + * @param jwtAssertionResolver the resolver used for resolving the {@link Jwt} + * assertion + * @since 5.7 + */ + public void setJwtAssertionResolver(Function> jwtAssertionResolver) { + Assert.notNull(jwtAssertionResolver, "jwtAssertionResolver cannot be null"); + this.jwtAssertionResolver = jwtAssertionResolver; + } + /** * Sets the maximum acceptable clock skew, which is used when checking the * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProviderTests.java index 0ea5e25552..49d8dc416e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client; import java.time.Duration; import java.time.Instant; +import java.util.function.Function; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -42,6 +43,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link JwtBearerOAuth2AuthorizedClientProvider}. @@ -87,6 +89,13 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests { .withMessage("accessTokenResponseClient cannot be null"); } + @Test + public void setJwtAssertionResolverWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setJwtAssertionResolver(null)) + .withMessage("jwtAssertionResolver cannot be null"); + } + @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { // @formatter:off @@ -198,7 +207,7 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests { } @Test - public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() { + public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtDoesNotResolveThenUnableToAuthorize() { // @formatter:off OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext .withClientRegistration(this.clientRegistration) @@ -209,7 +218,7 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests { } @Test - public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() { + public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtResolvesThenAuthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); // @formatter:off @@ -224,4 +233,25 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests { assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); } + @Test + public void authorizeWhenCustomJwtAssertionResolverSetThenUsed() { + Function jwtAssertionResolver = mock(Function.class); + given(jwtAssertionResolver.apply(any())).willReturn(this.jwtAssertion); + this.authorizedClientProvider.setJwtAssertionResolver(jwtAssertionResolver); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); + // @formatter:off + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password"); + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + verify(jwtAssertionResolver).apply(any()); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProviderTests.java index 33279c6f94..2ec6e2f4a0 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.function.Function; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -93,6 +94,13 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests { .withMessage("accessTokenResponseClient cannot be null"); } + @Test + public void setJwtAssertionResolverWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setJwtAssertionResolver(null)) + .withMessage("jwtAssertionResolver cannot be null"); + } + @Test public void setClockSkewWhenNullThenThrowIllegalArgumentException() { // @formatter:off @@ -222,7 +230,7 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests { } @Test - public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() { + public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtDoesNotResolveThenUnableToAuthorize() { // @formatter:off OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext .withClientRegistration(this.clientRegistration) @@ -251,7 +259,7 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests { } @Test - public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() { + public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtResolvesThenAuthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); // @formatter:off @@ -266,4 +274,25 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests { assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); } + @Test + public void authorizeWhenCustomJwtAssertionResolverSetThenUsed() { + Function> jwtAssertionResolver = mock(Function.class); + given(jwtAssertionResolver.apply(any())).willReturn(Mono.just(this.jwtAssertion)); + this.authorizedClientProvider.setJwtAssertionResolver(jwtAssertionResolver); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password"); + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + verify(jwtAssertionResolver).apply(any()); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + } + }