diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/OAuth2ClientDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/OAuth2ClientDslTests.kt index 4d709dbe80..a01ae7dc53 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/OAuth2ClientDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/OAuth2ClientDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,9 +29,9 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe import org.springframework.security.config.oauth2.client.CommonOAuth2Provider import org.springframework.security.config.test.SpringTestContext import org.springframework.security.config.test.SpringTestContextExtension -import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest +import org.springframework.security.oauth2.client.endpoint.RestClientAuthorizationCodeTokenResponseClient import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository @@ -128,7 +128,7 @@ class OAuth2ClientDslTests { val REQUEST_REPOSITORY: AuthorizationRequestRepository = HttpSessionOAuth2AuthorizationRequestRepository() val CLIENT: OAuth2AccessTokenResponseClient = - DefaultAuthorizationCodeTokenResponseClient() + RestClientAuthorizationCodeTokenResponseClient() val CLIENT_REPOSITORY: OAuth2AuthorizedClientRepository = HttpSessionOAuth2AuthorizedClientRepository() } diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt index 0786bab0aa..7983f614e5 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,13 +27,13 @@ import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Configuration import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity +import org.springframework.security.config.annotation.web.invoke import org.springframework.security.config.oauth2.client.CommonOAuth2Provider import org.springframework.security.config.test.SpringTestContext import org.springframework.security.config.test.SpringTestContextExtension -import org.springframework.security.config.annotation.web.invoke -import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest +import org.springframework.security.oauth2.client.endpoint.RestClientAuthorizationCodeTokenResponseClient import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository @@ -175,7 +175,7 @@ class AuthorizationCodeGrantDslTests { val REQUEST_REPOSITORY: AuthorizationRequestRepository = HttpSessionOAuth2AuthorizationRequestRepository() val CLIENT: OAuth2AccessTokenResponseClient = - DefaultAuthorizationCodeTokenResponseClient() + RestClientAuthorizationCodeTokenResponseClient() } @Bean diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/RedirectionEndpointDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/RedirectionEndpointDslTests.kt index 00fca143b6..6fff212a96 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/RedirectionEndpointDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/RedirectionEndpointDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,14 +25,14 @@ import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Configuration import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity +import org.springframework.security.config.annotation.web.invoke import org.springframework.security.config.oauth2.client.CommonOAuth2Provider import org.springframework.security.config.test.SpringTestContext import org.springframework.security.config.test.SpringTestContextExtension -import org.springframework.security.config.annotation.web.invoke import org.springframework.security.core.authority.SimpleGrantedAuthority -import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest +import org.springframework.security.oauth2.client.endpoint.RestClientAuthorizationCodeTokenResponseClient import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService @@ -110,7 +110,7 @@ class RedirectionEndpointDslTests { val REPOSITORY: AuthorizationRequestRepository = HttpSessionOAuth2AuthorizationRequestRepository() val CLIENT: OAuth2AccessTokenResponseClient = - DefaultAuthorizationCodeTokenResponseClient() + RestClientAuthorizationCodeTokenResponseClient() val USER_SERVICE: OAuth2UserService = DefaultOAuth2UserService() } diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/TokenEndpointDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/TokenEndpointDslTests.kt index 2f726e831a..59e08520bd 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/TokenEndpointDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/TokenEndpointDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,13 +26,13 @@ import org.springframework.context.annotation.Bean import org.springframework.context.annotation.Configuration import org.springframework.security.config.annotation.web.builders.HttpSecurity import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity +import org.springframework.security.config.annotation.web.invoke import org.springframework.security.config.oauth2.client.CommonOAuth2Provider import org.springframework.security.config.test.SpringTestContext import org.springframework.security.config.test.SpringTestContextExtension -import org.springframework.security.config.annotation.web.invoke -import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest +import org.springframework.security.oauth2.client.endpoint.RestClientAuthorizationCodeTokenResponseClient import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository @@ -101,7 +101,7 @@ class TokenEndpointDslTests { val REPOSITORY: AuthorizationRequestRepository = HttpSessionOAuth2AuthorizationRequestRepository() val CLIENT: OAuth2AccessTokenResponseClient = - DefaultAuthorizationCodeTokenResponseClient() + RestClientAuthorizationCodeTokenResponseClient() } @Bean diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java index 4253b3c3d3..16b996d8a4 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,35 +18,40 @@ package org.springframework.security.oauth2.client; import java.time.Duration; import java.time.Instant; +import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.http.HttpStatus; -import org.springframework.http.RequestEntity; -import org.springframework.http.ResponseEntity; +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.DefaultPasswordTokenResponseClient; -import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.RestClientClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.RestClientRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import org.springframework.web.client.RestOperations; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.test.web.client.ExpectedCount; +import org.springframework.test.web.client.MockRestServiceServer; +import org.springframework.web.client.RestClient; +import org.springframework.web.client.RestTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.springframework.test.web.client.ExpectedCount.once; +import static org.springframework.test.web.client.ExpectedCount.times; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; /** * Tests for {@link OAuth2AuthorizedClientProviderBuilder}. @@ -55,29 +60,30 @@ import static org.mockito.Mockito.verify; */ public class OAuth2AuthorizedClientProviderBuilderTests { - private RestOperations accessTokenClient; + private RestClientClientCredentialsTokenResponseClient clientCredentialsTokenResponseClient; - private DefaultClientCredentialsTokenResponseClient clientCredentialsTokenResponseClient; - - private DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient; + private RestClientRefreshTokenTokenResponseClient refreshTokenTokenResponseClient; private DefaultPasswordTokenResponseClient passwordTokenResponseClient; private Authentication principal; - @SuppressWarnings("unchecked") + private MockRestServiceServer server; + @BeforeEach public void setup() { - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); - this.accessTokenClient = mock(RestOperations.class); - given(this.accessTokenClient.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) - .willReturn(new ResponseEntity(accessTokenResponse, HttpStatus.OK)); - this.refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); - this.refreshTokenTokenResponseClient.setRestOperations(this.accessTokenClient); - this.clientCredentialsTokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); - this.clientCredentialsTokenResponseClient.setRestOperations(this.accessTokenClient); + // TODO: Use of RestTemplate in these tests can be removed when + // DefaultPasswordTokenResponseClient is removed. + RestTemplate accessTokenClient = new RestTemplate( + List.of(new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); + this.server = MockRestServiceServer.bindTo(accessTokenClient).build(); + RestClient restClient = RestClient.create(accessTokenClient); + this.refreshTokenTokenResponseClient = new RestClientRefreshTokenTokenResponseClient(); + this.refreshTokenTokenResponseClient.setRestClient(restClient); + this.clientCredentialsTokenResponseClient = new RestClientClientCredentialsTokenResponseClient(); + this.clientCredentialsTokenResponseClient.setRestClient(restClient); this.passwordTokenResponseClient = new DefaultPasswordTokenResponseClient(); - this.passwordTokenResponseClient.setRestOperations(this.accessTokenClient); + this.passwordTokenResponseClient.setRestOperations(accessTokenClient); this.principal = new TestingAuthenticationToken("principal", "password"); } @@ -104,6 +110,8 @@ public class OAuth2AuthorizedClientProviderBuilderTests { @Test public void buildWhenRefreshTokenProviderThenProviderReauthorizes() { + mockAccessTokenResponse(once()); + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() .refreshToken((configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) .build(); @@ -118,11 +126,13 @@ public class OAuth2AuthorizedClientProviderBuilderTests { // @formatter:on OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext); assertThat(reauthorizedClient).isNotNull(); - verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + this.server.verify(); } @Test public void buildWhenClientCredentialsProviderThenProviderAuthorizes() { + mockAccessTokenResponse(once()); + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() .clientCredentials( (configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) @@ -135,11 +145,13 @@ public class OAuth2AuthorizedClientProviderBuilderTests { // @formatter:on OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext); assertThat(authorizedClient).isNotNull(); - verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + this.server.verify(); } @Test public void buildWhenPasswordProviderThenProviderAuthorizes() { + mockAccessTokenResponse(once()); + // @formatter:off OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() .password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) @@ -153,11 +165,13 @@ public class OAuth2AuthorizedClientProviderBuilderTests { // @formatter:on OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext); assertThat(authorizedClient).isNotNull(); - verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + this.server.verify(); } @Test public void buildWhenAllProvidersThenProvidersAuthorize() { + mockAccessTokenResponse(times(3)); + OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .refreshToken((configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) @@ -184,8 +198,6 @@ public class OAuth2AuthorizedClientProviderBuilderTests { .build(); OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(refreshTokenContext); assertThat(reauthorizedClient).isNotNull(); - verify(this.accessTokenClient, times(1)).exchange(any(RequestEntity.class), - eq(OAuth2AccessTokenResponse.class)); // client_credentials // @formatter:off OAuth2AuthorizationContext clientCredentialsContext = OAuth2AuthorizationContext @@ -195,8 +207,6 @@ public class OAuth2AuthorizedClientProviderBuilderTests { // @formatter:on authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext); assertThat(authorizedClient).isNotNull(); - verify(this.accessTokenClient, times(2)).exchange(any(RequestEntity.class), - eq(OAuth2AccessTokenResponse.class)); // password // @formatter:off OAuth2AuthorizationContext passwordContext = OAuth2AuthorizationContext @@ -208,8 +218,7 @@ public class OAuth2AuthorizedClientProviderBuilderTests { // @formatter:on authorizedClient = authorizedClientProvider.authorize(passwordContext); assertThat(authorizedClient).isNotNull(); - verify(this.accessTokenClient, times(3)).exchange(any(RequestEntity.class), - eq(OAuth2AccessTokenResponse.class)); + this.server.verify(); } @Test @@ -234,4 +243,10 @@ public class OAuth2AuthorizedClientProviderBuilderTests { return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); } + private void mockAccessTokenResponse(ExpectedCount expectedCount) { + this.server.expect(expectedCount, requestTo("https://example.com/login/oauth/access_token")) + .andRespond(withSuccess().header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .body(new ClassPathResource("access-token-response.json"))); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index c963a11897..f455ab236d 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,11 +43,11 @@ import reactor.util.context.Context; import org.springframework.core.codec.ByteBufferEncoder; import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.core.io.ClassPathResource; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; -import org.springframework.http.RequestEntity; -import org.springframework.http.ResponseEntity; +import org.springframework.http.MediaType; import org.springframework.http.codec.EncoderHttpMessageWriter; import org.springframework.http.codec.FormHttpMessageWriter; import org.springframework.http.codec.HttpMessageWriter; @@ -55,6 +55,7 @@ import org.springframework.http.codec.ResourceHttpMessageWriter; import org.springframework.http.codec.ServerSentEventHttpMessageWriter; import org.springframework.http.codec.json.Jackson2JsonEncoder; import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.mock.http.client.reactive.MockClientHttpRequest; import org.springframework.mock.web.MockHttpServletRequest; @@ -75,12 +76,12 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.endpoint.RestClientRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; @@ -96,11 +97,13 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.test.web.client.MockRestServiceServer; import org.springframework.util.StringUtils; -import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestClient; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.reactive.function.BodyInserter; @@ -121,6 +124,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess; /** * @author Rob Winch @@ -357,11 +362,18 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { .expiresIn(3600) // .refreshToken(xxx) // No refreshToken in response .build(); - RestOperations refreshTokenClient = mock(RestOperations.class); - given(refreshTokenClient.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) - .willReturn(new ResponseEntity(response, HttpStatus.OK)); - DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); - refreshTokenTokenResponseClient.setRestOperations(refreshTokenClient); + RestClient.Builder builder = RestClient.builder().messageConverters((messageConverters) -> { + messageConverters.clear(); + messageConverters.add(new FormHttpMessageConverter()); + messageConverters.add(new OAuth2AccessTokenResponseHttpMessageConverter()); + }); + MockRestServiceServer server = MockRestServiceServer.bindTo(builder).build(); + RestClient refreshTokenClient = builder.build(); + server.expect(requestTo("https://example.com/login/oauth/access_token")) + .andRespond(withSuccess().header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .body(new ClassPathResource("access-token-response-1.json"))); + RestClientRefreshTokenTokenResponseClient refreshTokenTokenResponseClient = new RestClientRefreshTokenTokenResponseClient(); + refreshTokenTokenResponseClient.setRestClient(refreshTokenClient); RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); authorizedClientProvider.setAccessTokenResponseClient(refreshTokenTokenResponseClient); DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( @@ -384,11 +396,12 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { .httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); - verify(refreshTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + server.verify(); verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any()); OAuth2AuthorizedClient newAuthorizedClient = this.authorizedClientCaptor.getValue(); - assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken()); + assertThat(newAuthorizedClient.getAccessToken().getTokenValue()) + .isEqualTo(response.getAccessToken().getTokenValue()); assertThat(newAuthorizedClient.getRefreshToken().getTokenValue()).isEqualTo(refreshToken.getTokenValue()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); diff --git a/oauth2/oauth2-client/src/test/resources/access-token-response-1.json b/oauth2/oauth2-client/src/test/resources/access-token-response-1.json new file mode 100644 index 0000000000..c9fe6abac0 --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/access-token-response-1.json @@ -0,0 +1,5 @@ +{ + "access_token": "token-1", + "token_type": "Bearer", + "expires_in": 3600 +} \ No newline at end of file diff --git a/oauth2/oauth2-client/src/test/resources/access-token-response.json b/oauth2/oauth2-client/src/test/resources/access-token-response.json new file mode 100644 index 0000000000..78c6dbd77b --- /dev/null +++ b/oauth2/oauth2-client/src/test/resources/access-token-response.json @@ -0,0 +1,5 @@ +{ + "access_token": "token", + "token_type": "Bearer", + "expires_in": 3600 +} \ No newline at end of file