From 28537fa3b6c9df0fed577e0d1ef41d8c838a60b4 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Tue, 4 Sep 2018 15:14:27 -0500 Subject: [PATCH] WebClientReactiveClientCredentialsTokenResponseClient Fixes: gh-5607 --- ...eClientCredentialsTokenResponseClient.java | 107 +++++++++++++++ ...ntCredentialsTokenResponseClientTests.java | 126 ++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java new file mode 100644 index 0000000000..f71926e76a --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java @@ -0,0 +1,107 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import java.util.Set; +import java.util.function.Consumer; + +import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; + +/** + * An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges" + * an authorization code credential for an access token credential + * at the Authorization Server's Token Endpoint. + * + * @author Rob Winch + * @since 5.1 + * @see OAuth2AccessTokenResponseClient + * @see OAuth2AuthorizationCodeGrantRequest + * @see OAuth2AccessTokenResponse + * @see Nimbus OAuth 2.0 SDK + * @see Section 4.1.3 Access Token Request (Authorization Code Grant) + * @see Section 4.1.4 Access Token Response (Authorization Code Grant) + */ +public class WebClientReactiveClientCredentialsTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient { + private WebClient webClient = WebClient.builder() + .build(); + + @Override + public Mono getTokenResponse(OAuth2ClientCredentialsGrantRequest authorizationGrantRequest) + throws OAuth2AuthenticationException { + + return Mono.defer(() -> { + ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); + + String tokenUri = clientRegistration.getProviderDetails().getTokenUri(); + BodyInserters.FormInserter body = body(authorizationGrantRequest); + + return this.webClient.post() + .uri(tokenUri) + .accept(MediaType.APPLICATION_JSON) + .headers(headers(clientRegistration)) + .body(body) + .exchange() + .flatMap(response -> response.body(oauth2AccessTokenResponse())) + .map(response -> { + if (response.getAccessToken().getScopes().isEmpty()) { + response = OAuth2AccessTokenResponse.withResponse(response) + .scopes(authorizationGrantRequest.getClientRegistration().getScopes()) + .build(); + } + return response; + }); + }); + } + + private Consumer headers(ClientRegistration clientRegistration) { + return headers -> { + headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); + headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); + if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { + headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); + } + }; + } + + private static BodyInserters.FormInserter body(OAuth2ClientCredentialsGrantRequest authorizationGrantRequest) { + ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); + BodyInserters.FormInserter body = BodyInserters + .fromFormData(OAuth2ParameterNames.GRANT_TYPE, authorizationGrantRequest.getGrantType().getValue()); + Set scopes = clientRegistration.getScopes(); + if (!CollectionUtils.isEmpty(scopes)) { + String scope = StringUtils.collectionToDelimitedString(scopes, " "); + body.with(OAuth2ParameterNames.SCOPE, scope); + } + if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { + body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + return body; + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java new file mode 100644 index 0000000000..510df757c5 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; + +import static org.assertj.core.api.Assertions.*; + +/** + * @author Rob Winch + */ +public class WebClientReactiveClientCredentialsTokenResponseClientTests { + + private MockWebServer server; + + private WebClientReactiveClientCredentialsTokenResponseClient client = new WebClientReactiveClientCredentialsTokenResponseClient(); + + private ClientRegistration.Builder clientRegistration; + + @Before + public void setup() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + + this.clientRegistration = TestClientRegistrations + .clientCredentials() + .tokenUri(this.server.url("/oauth2/token").uri().toASCIIString()); + } + + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + + @Test + public void getTokenResponseWhenHeaderThenSuccess() throws Exception { + enqueueJson("{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" + + " \"scope\":\"create\"\n" + + "}"); + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(this.clientRegistration + .build()); + + OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); + RecordedRequest actualRequest = this.server.takeRequest(); + String body = actualRequest.getUtf8Body(); + + assertThat(response.getAccessToken()).isNotNull(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); + assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser"); + } + + @Test + public void getTokenResponseWhenPostThenSuccess() throws Exception { + ClientRegistration registration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.POST) + .build(); + enqueueJson("{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" + + " \"scope\":\"create\"\n" + + "}"); + + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); + + OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); + String body = this.server.takeRequest().getUtf8Body(); + + assertThat(response.getAccessToken()).isNotNull(); + assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser&client_id=client-id&client_secret=client-secret"); + } + + @Test + public void getTokenResponseWhenNoScopeThenClientRegistrationScopesDefaulted() { + ClientRegistration registration = this.clientRegistration.build(); + enqueueJson("{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"); + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); + + OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); + + assertThat(response.getAccessToken().getScopes()).isEqualTo(registration.getScopes()); + } + + + private void enqueueJson(String body) { + MockResponse response = new MockResponse() + .setBody(body) + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + this.server.enqueue(response); + } +}