From 4ca9e15595a3d1c9480e64af2199bafb0525f519 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 21 Jun 2019 10:50:38 -0400 Subject: [PATCH] Fix blocking in ServletOAuth2AuthorizedClientExchangeFilterFunction Fixes gh-6589 --- gradle/dependency-management.gradle | 1 + .../spring-security-oauth2-client.gradle | 1 + ...uthorizedClientExchangeFilterFunction.java | 73 +++-- ...zedClientExchangeFilterFunctionITests.java | 268 ++++++++++++++++++ ...izedClientExchangeFilterFunctionTests.java | 161 ----------- 5 files changed, 317 insertions(+), 187 deletions(-) create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java diff --git a/gradle/dependency-management.gradle b/gradle/dependency-management.gradle index 00bfaaf30e..3f2b203c10 100644 --- a/gradle/dependency-management.gradle +++ b/gradle/dependency-management.gradle @@ -70,6 +70,7 @@ dependencyManagement { dependency 'commons-lang:commons-lang:2.6' dependency 'commons-logging:commons-logging:1.2' dependency 'dom4j:dom4j:1.6.1' + dependency 'io.projectreactor.tools:blockhound:1.0.0.M4' dependency 'javax.activation:activation:1.1.1' dependency 'javax.annotation:jsr250-api:1.0' dependency 'javax.inject:javax.inject:1' diff --git a/oauth2/oauth2-client/spring-security-oauth2-client.gradle b/oauth2/oauth2-client/spring-security-oauth2-client.gradle index 76a3006591..624125ee28 100644 --- a/oauth2/oauth2-client/spring-security-oauth2-client.gradle +++ b/oauth2/oauth2-client/spring-security-oauth2-client.gradle @@ -17,6 +17,7 @@ dependencies { testCompile 'com.fasterxml.jackson.core:jackson-databind' testCompile 'io.projectreactor.netty:reactor-netty' testCompile 'io.projectreactor:reactor-test' + testCompile 'io.projectreactor.tools:blockhound' provided 'javax.servlet:javax.servlet-api' } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 9919bf859f..1cd234c043 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -50,6 +50,7 @@ import reactor.core.CoreSubscriber; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; +import reactor.core.scheduler.Schedulers; import reactor.util.context.Context; import javax.servlet.http.HttpServletRequest; @@ -258,7 +259,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction spec.attributes(attrs -> { populateDefaultRequestResponse(attrs); populateDefaultAuthentication(attrs); - populateDefaultOAuth2AuthorizedClient(attrs); }); }; } @@ -349,20 +349,33 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - return Mono.just(request) - .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) - .switchIfEmpty(mergeRequestAttributesFromContext(request)) + return mergeRequestAttributesIfNecessary(request) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) .flatMap(req -> authorizedClient(getOAuth2AuthorizedClient(req.attributes()), req)) + .switchIfEmpty(Mono.defer(() -> + mergeRequestAttributesIfNecessary(request) + .filter(req -> resolveClientRegistrationId(req) != null) + .flatMap(req -> authorizeClient(resolveClientRegistrationId(req), req)) + )) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) - .switchIfEmpty(next.exchange(request)); + .switchIfEmpty(Mono.defer(() -> next.exchange(request))); + } + + private Mono mergeRequestAttributesIfNecessary(ClientRequest request) { + if (!request.attribute(HTTP_SERVLET_REQUEST_ATTR_NAME).isPresent() || + !request.attribute(HTTP_SERVLET_RESPONSE_ATTR_NAME).isPresent() || + !request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) { + return mergeRequestAttributesFromContext(request); + } else { + return Mono.just(request); + } } private Mono mergeRequestAttributesFromContext(ClientRequest request) { - return Mono.just(ClientRequest.from(request)) - .flatMap(builder -> Mono.subscriberContext() - .map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx)))) + ClientRequest.Builder builder = ClientRequest.from(request); + return Mono.subscriberContext() + .map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))) .map(ClientRequest.Builder::build); } @@ -376,7 +389,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) { attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME)); } - populateDefaultOAuth2AuthorizedClient(attrs); } private void populateDefaultRequestResponse(Map attrs) { @@ -403,32 +415,38 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication); } - private void populateDefaultOAuth2AuthorizedClient(Map attrs) { - if (this.authorizedClientManager == null || - attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) { - return; - } - - Authentication authentication = getAuthentication(attrs); + private String resolveClientRegistrationId(ClientRequest request) { + Map attrs = request.attributes(); String clientRegistrationId = getClientRegistrationId(attrs); if (clientRegistrationId == null) { clientRegistrationId = this.defaultClientRegistrationId; } + Authentication authentication = getAuthentication(attrs); if (clientRegistrationId == null && this.defaultOAuth2AuthorizedClient && authentication instanceof OAuth2AuthenticationToken) { clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId(); } - if (clientRegistrationId != null) { - HttpServletRequest request = getRequest(attrs); - if (authentication == null) { - authentication = ANONYMOUS_AUTHENTICATION; - } - OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( - clientRegistrationId, authentication, request, getResponse(attrs)); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); - oauth2AuthorizedClient(authorizedClient).accept(attrs); + return clientRegistrationId; + } + + private Mono authorizeClient(String clientRegistrationId, ClientRequest request) { + if (this.authorizedClientManager == null) { + return Mono.empty(); } + Map attrs = request.attributes(); + Authentication authentication = getAuthentication(attrs); + if (authentication == null) { + authentication = ANONYMOUS_AUTHENTICATION; + } + HttpServletRequest servletRequest = getRequest(attrs); + HttpServletResponse servletResponse = getResponse(attrs); + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + clientRegistrationId, authentication, servletRequest, servletResponse); + + // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.elastic()) + // since it performs a blocking I/O operation using RestTemplate internally + return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(authorizeRequest)).subscribeOn(Schedulers.elastic()); } private Mono authorizedClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) { @@ -444,7 +462,10 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction HttpServletResponse servletResponse = getResponse(attrs); OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( authorizedClient, authentication, servletRequest, servletResponse); - return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)); + + // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.elastic()) + // since it performs a blocking I/O operation using RestTemplate internally + return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)).subscribeOn(Schedulers.elastic()); } private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java new file mode 100644 index 0000000000..fb8cf3e52d --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java @@ -0,0 +1,268 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.reactive.function.client; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.blockhound.BlockHound; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; + +/** + * @author Joe Grandja + */ +public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private ServletOAuth2AuthorizedClientExchangeFilterFunction authorizedClientFilter; + private MockWebServer server; + private String serverUrl; + private WebClient webClient; + private Authentication authentication; + private MockHttpServletRequest request; + private MockHttpServletResponse response; + + @BeforeClass + public static void setUpBlockingChecks() { + // IMPORTANT: + // Before enabling BlockHound, we need to white-list `java.lang.Class.getPackage()`. + // When the JVM loads `java.lang.Package.getSystemPackage()`, it attempts to + // `java.lang.Package.loadManifest()` which is blocking I/O and triggers BlockHound to error. + // NOTE: This is an issue with JDK 8. It's been tested on JDK 10 and works fine w/o this white-list. + BlockHound.builder() + .allowBlockingCallsInside(Class.class.getName(), "getPackage") + .install(); + } + + @Before + public void setUp() throws Exception { + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + final OAuth2AuthorizedClientRepository delegate = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository( + new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository)); + this.authorizedClientRepository = spy(new OAuth2AuthorizedClientRepository() { + @Override + public T loadAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request) { + return delegate.loadAuthorizedClient(clientRegistrationId, principal, request); + } + + @Override + public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, HttpServletRequest request, HttpServletResponse response) { + delegate.saveAuthorizedClient(authorizedClient, principal, request, response); + } + + @Override + public void removeAuthorizedClient(String clientRegistrationId, Authentication principal, HttpServletRequest request, HttpServletResponse response) { + delegate.removeAuthorizedClient(clientRegistrationId, principal, request, response); + } + }); + this.authorizedClientFilter = new ServletOAuth2AuthorizedClientExchangeFilterFunction( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientFilter.afterPropertiesSet(); + this.server = new MockWebServer(); + this.server.start(); + this.serverUrl = this.server.url("/").toString(); + this.webClient = WebClient.builder() + .apply(this.authorizedClientFilter.oauth2Configuration()) + .build(); + this.authentication = new TestingAuthenticationToken("principal", "password"); + SecurityContextHolder.getContext().setAuthentication(this.authentication); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.request, this.response)); + } + + @After + public void cleanup() throws Exception { + this.authorizedClientFilter.destroy(); + this.server.shutdown(); + SecurityContextHolder.clearContext(); + RequestContextHolder.resetRequestAttributes(); + } + + @Test + public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() { + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration); + + this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .block(); + + assertThat(this.server.getRequestCount()).isEqualTo(2); + + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient( + authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response)); + assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration); + } + + @Test + public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() { + String accessTokenResponse = "{\n" + + " \"access_token\": \"refreshed-access-token\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration); + + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofHours(1)); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "expired-access-token", issuedAt, expiresAt, new HashSet<>(Arrays.asList("read", "write"))); + OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, this.authentication.getName(), accessToken, refreshToken); + doReturn(authorizedClient).when(this.authorizedClientRepository).loadAuthorizedClient( + eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.request)); + + this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .block(); + + assertThat(this.server.getRequestCount()).isEqualTo(2); + + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient( + authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response)); + OAuth2AuthorizedClient refreshedAuthorizedClient = authorizedClientCaptor.getValue(); + assertThat(refreshedAuthorizedClient.getClientRegistration()).isSameAs(clientRegistration); + assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token"); + } + + @Test + public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() { + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + + // Client 1 + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration1 = TestClientRegistrations.clientCredentials() + .registrationId("client-1").tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration1.getRegistrationId()))).thenReturn(clientRegistration1); + + // Client 2 + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration2 = TestClientRegistrations.clientCredentials() + .registrationId("client-2").tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))).thenReturn(clientRegistration2); + + this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration1.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .flatMap(response -> this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration2.getRegistrationId())) + .retrieve() + .bodyToMono(String.class)) + .block(); + + assertThat(this.server.getRequestCount()).isEqualTo(4); + + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository, times(2)).saveAuthorizedClient( + authorizedClientCaptor.capture(), eq(this.authentication), eq(this.request), eq(this.response)); + assertThat(authorizedClientCaptor.getAllValues().get(0).getClientRegistration()).isSameAs(clientRegistration1); + assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 86df1cec94..ebc15cd8fc 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -84,7 +84,6 @@ import java.util.Optional; import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -212,166 +211,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { verifyZeroInteractions(this.authorizedClientRepository); } - @Test - public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken); - oauth2AuthorizedClient(authorizedClient).accept(this.result); - Map attrs = getDefaultRequestAttributes(); - assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient); - verifyZeroInteractions(this.authorizedClientRepository); - } - - @Test - public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() { - Map attrs = getDefaultRequestAttributes(); - assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); - verifyZeroInteractions(this.authorizedClientRepository); - } - - @Test - public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() { - Map attrs = getDefaultRequestAttributes(); - assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); - verifyZeroInteractions(this.authorizedClientRepository); - } - - @Test - public void defaultRequestOAuth2AuthorizedClientWhenRepositoryNullThenOAuth2AuthorizedClient() { - OAuth2User user = mock(OAuth2User.class); - List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); - authentication(token).accept(this.result); - - Map attrs = getDefaultRequestAttributes(); - - assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); - } - - @Test - public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { - this.function.setDefaultOAuth2AuthorizedClient(true); - OAuth2User user = mock(OAuth2User.class); - List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken); - when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); - authentication(token).accept(this.result); - httpServletRequest(new MockHttpServletRequest()).accept(this.result); - httpServletResponse(new MockHttpServletResponse()).accept(this.result); - - Map attrs = getDefaultRequestAttributes(); - - assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient); - verify(this.authorizedClientRepository).loadAuthorizedClient(eq(token.getAuthorizedClientRegistrationId()), any(), any()); - } - - @Test - public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { - OAuth2User user = mock(OAuth2User.class); - List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); - authentication(token).accept(this.result); - - Map attrs = getDefaultRequestAttributes(); - - assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); - } - - @Test - public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() { - OAuth2User user = mock(OAuth2User.class); - List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); - OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken); - when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); - authentication(token).accept(this.result); - clientRegistrationId("explicit").accept(this.result); - httpServletRequest(new MockHttpServletRequest()).accept(this.result); - httpServletResponse(new MockHttpServletResponse()).accept(this.result); - - Map attrs = getDefaultRequestAttributes(); - - assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient); - verify(this.authorizedClientRepository).loadAuthorizedClient(eq("explicit"), any(), any()); - } - - @Test - public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() { - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, - "principalName", this.accessToken); - when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); - clientRegistrationId("id").accept(this.result); - httpServletRequest(new MockHttpServletRequest()).accept(this.result); - httpServletResponse(new MockHttpServletResponse()).accept(this.result); - - Map attrs = getDefaultRequestAttributes(); - - assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient); - verify(this.authorizedClientRepository).loadAuthorizedClient(eq("id"), any(), any()); - } - - @Test - public void defaultRequestWhenClientCredentialsThenAuthorizedClient() { - this.registration = TestClientRegistrations.clientCredentials().build(); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses - .accessTokenResponse().build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); - - MockHttpServletRequest request = new MockHttpServletRequest(); - MockHttpServletResponse response = new MockHttpServletResponse(); - RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); - SecurityContextHolder.getContext().setAuthentication(this.authentication); - - Map attrs = getDefaultRequestAttributes(); - OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); - - assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration); - assertThat(authorizedClient.getPrincipalName()).isEqualTo("test"); - assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); - } - - @Test - public void defaultRequestWhenDefaultClientRegistrationIdThenAuthorizedClient() { - this.registration = TestClientRegistrations.clientCredentials().build(); - this.function.setDefaultClientRegistrationId(this.registration.getRegistrationId()); - when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses - .accessTokenResponse().build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - - MockHttpServletRequest request = new MockHttpServletRequest(); - MockHttpServletResponse response = new MockHttpServletResponse(); - RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); - SecurityContextHolder.getContext().setAuthentication(this.authentication); - - Map attrs = getDefaultRequestAttributes(); - OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); - - assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration); - assertThat(authorizedClient.getPrincipalName()).isEqualTo("test"); - assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); - } - - @Test - public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException() { - this.registration = TestClientRegistrations.clientCredentials().build(); - clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); - - assertThatCode(() -> getDefaultRequestAttributes()) - .isInstanceOf(IllegalArgumentException.class); - } - private Map getDefaultRequestAttributes() { this.function.defaultRequest().accept(this.spec); verify(this.spec).attributes(this.attrs.capture());