Look up ReactiveOAuth2AccessTokenResponseClient as a bean

Closes gh-11097
This commit is contained in:
Steve Riesenberg 2024-09-23 11:06:12 -05:00
parent 2763bbed33
commit cd7f6e09b0
No known key found for this signature in database
GPG Key ID: 3D0169B18AB8F0A9
2 changed files with 118 additions and 2 deletions

View File

@ -4813,11 +4813,22 @@ public class ServerHttpSecurity {
private ReactiveAuthenticationManager getAuthenticationManager() {
if (this.authenticationManager == null) {
this.authenticationManager = new OAuth2AuthorizationCodeReactiveAuthenticationManager(
new WebClientReactiveAuthorizationCodeTokenResponseClient());
getAuthorizationCodeTokenResponseClient());
}
return this.authenticationManager;
}
private ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> getAuthorizationCodeTokenResponseClient() {
ResolvableType resolvableType = ResolvableType.forClassWithGenerics(
ReactiveOAuth2AccessTokenResponseClient.class, OAuth2AuthorizationCodeGrantRequest.class);
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient = getBeanOrNull(
resolvableType);
if (accessTokenResponseClient == null) {
accessTokenResponseClient = new WebClientReactiveAuthorizationCodeTokenResponseClient();
}
return accessTokenResponseClient;
}
/**
* Configures the {@link ReactiveClientRegistrationRepository}. Default is to look
* the value up as a Bean.

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,9 +17,11 @@
package org.springframework.security.config.web.server;
import java.net.URI;
import java.util.Set;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import reactor.core.publisher.Mono;
import org.springframework.beans.factory.annotation.Autowired;
@ -31,9 +33,12 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
@ -41,8 +46,10 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
@ -59,7 +66,9 @@ import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.reactive.config.EnableWebFlux;
import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
@ -215,6 +224,62 @@ public class OAuth2ClientSpecTests {
verify(requestCache).getRedirectUri(any());
}
@Test
@SuppressWarnings("unchecked")
public void oauth2ClientWhenCustomAccessTokenResponseClientThenUsed() {
this.spring.register(OAuth2ClientBeanConfig.class, AuthorizedClientController.class).autowire();
ReactiveClientRegistrationRepository clientRegistrationRepository = this.spring.getContext()
.getBean(ReactiveClientRegistrationRepository.class);
given(clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(this.registration));
ServerOAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext()
.getBean(ServerOAuth2AuthorizedClientRepository.class);
given(authorizedClientRepository.saveAuthorizedClient(any(OAuth2AuthorizedClient.class),
any(Authentication.class), any(ServerWebExchange.class)))
.willReturn(Mono.empty());
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = this.spring
.getContext()
.getBean(ServerAuthorizationRequestRepository.class);
OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.redirectUri("/authorize/oauth2/code/registration-id")
.build();
given(authorizationRequestRepository.loadAuthorizationRequest(any(ServerWebExchange.class)))
.willReturn(Mono.just(authorizationRequest));
given(authorizationRequestRepository.removeAuthorizationRequest(any(ServerWebExchange.class)))
.willReturn(Mono.just(authorizationRequest));
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient = this.spring
.getContext()
.getBean(ReactiveOAuth2AccessTokenResponseClient.class);
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("token")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.scopes(Set.of())
.expiresIn(300)
.build();
given(accessTokenResponseClient.getTokenResponse(any(OAuth2AuthorizationCodeGrantRequest.class)))
.willReturn(Mono.just(accessTokenResponse));
// @formatter:off
this.client.get()
.uri((uriBuilder) -> uriBuilder
.path("/authorize/oauth2/code/registration-id")
.queryParam(OAuth2ParameterNames.CODE, "code")
.queryParam(OAuth2ParameterNames.STATE, "state")
.build()
)
.exchange()
.expectStatus().is3xxRedirection();
// @formatter:on
ArgumentCaptor<OAuth2AuthorizationCodeGrantRequest> grantRequestArgumentCaptor = ArgumentCaptor
.forClass(OAuth2AuthorizationCodeGrantRequest.class);
verify(accessTokenResponseClient).getTokenResponse(grantRequestArgumentCaptor.capture());
OAuth2AuthorizationCodeGrantRequest grantRequest = grantRequestArgumentCaptor.getValue();
assertThat(grantRequest.getClientRegistration()).isEqualTo(this.registration);
assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(grantRequest.getAuthorizationExchange().getAuthorizationRequest()).isEqualTo(authorizationRequest);
assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getCode()).isEqualTo("code");
assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getState()).isEqualTo("state");
assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getRedirectUri())
.startsWith("/authorize/oauth2/code/registration-id");
}
@Configuration
@EnableWebFlux
@EnableWebFluxSecurity
@ -324,4 +389,44 @@ public class OAuth2ClientSpecTests {
}
@Configuration
@EnableWebFlux
@EnableWebFluxSecurity
static class OAuth2ClientBeanConfig {
@Bean
SecurityWebFilterChain securityWebFilterChain(ServerHttpSecurity http) {
// @formatter:off
http
.oauth2Client((oauth2Client) -> oauth2Client
.authorizationRequestRepository(authorizationRequestRepository())
);
// @formatter:on
return http.build();
}
@Bean
@SuppressWarnings("unchecked")
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository() {
return mock(ServerAuthorizationRequestRepository.class);
}
@Bean
@SuppressWarnings("unchecked")
ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> authorizationCodeAccessTokenResponseClient() {
return mock(ReactiveOAuth2AccessTokenResponseClient.class);
}
@Bean
ReactiveClientRegistrationRepository clientRegistrationRepository() {
return mock(ReactiveClientRegistrationRepository.class);
}
@Bean
ServerOAuth2AuthorizedClientRepository authorizedClientRepository() {
return mock(ServerOAuth2AuthorizedClientRepository.class);
}
}
}