diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java new file mode 100644 index 0000000000..7001ecd891 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -0,0 +1,408 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.reactive.function.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.util.Assert; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.ExchangeFunction; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Base64; +import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; +import static org.springframework.security.web.http.SecurityHeaders.bearerToken; + +/** + * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the + * token as a Bearer Token. It also provides mechanisms for looking up the {@link OAuth2AuthorizedClient}. This class is + * intended to be used in a servlet environment. + * + * Example usage: + * + *
+ * OAuth2AuthorizedClientExchangeFilterFunction oauth2 = new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService);
+ * WebClient webClient = WebClient.builder()
+ *    .apply(oauth2.oauth2Configuration())
+ *    .build();
+ * Mono response = webClient
+ *    .get()
+ *    .uri(uri)
+ *    .attributes(oauth2AuthorizedClient(authorizedClient))
+ *    // ...
+ *    .retrieve()
+ *    .bodyToMono(String.class);
+ * 
+ * + * An attempt to automatically refresh the token will be made if all of the following + * are true: + * + * + * + * @author Rob Winch + * @since 5.1 + */ +public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { + /** + * The request attribute name used to locate the {@link OAuth2AuthorizedClient}. + */ + private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName(); + private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID"); + private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName(); + private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); + private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); + + private Clock clock = Clock.systemUTC(); + + private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); + + private OAuth2AuthorizedClientRepository authorizedClientRepository; + + public ServletOAuth2AuthorizedClientExchangeFilterFunction() {} + + public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientRepository authorizedClientRepository) { + this.authorizedClientRepository = authorizedClientRepository; + } + + /** + * Configures the builder with {@link #defaultRequest()} and adds this as a {@link ExchangeFilterFunction} + * @return the {@link Consumer} to configure the builder + */ + public Consumer oauth2Configuration() { + return builder -> builder.defaultRequest(defaultRequest()).filter(this); + } + + /** + * Provides defaults for the {@link HttpServletRequest} and the {@link HttpServletResponse} using + * {@link RequestContextHolder}. It also provides defaults for the {@link Authentication} using + * {@link SecurityContextHolder}. It also can default the {@link OAuth2AuthorizedClient} using the + * {@link #clientRegistrationId(String)} or the {@link #authentication(Authentication)}. + * @return the {@link Consumer} to populate the attributes + */ + public Consumer> defaultRequest() { + return spec -> { + spec.attributes(attrs -> { + populateDefaultRequestResponse(attrs); + populateDefaultAuthentication(attrs); + populateDefaultOAuth2AuthorizedClient(attrs); + }); + }; + } + + /** + * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for + * providing the Bearer Token. + * + * @param authorizedClient the {@link OAuth2AuthorizedClient} to use. + * @return the {@link Consumer} to populate the attributes + */ + public static Consumer> oauth2AuthorizedClient(OAuth2AuthorizedClient authorizedClient) { + return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient); + } + + /** + * Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to + * be used to look up the {@link OAuth2AuthorizedClient}. + * + * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to + * be used to look up the {@link OAuth2AuthorizedClient}. + * @return the {@link Consumer} to populate the attributes + */ + public static Consumer> clientRegistrationId(String clientRegistrationId) { + return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId); + } + + /** + * Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} used to + * look up and save the {@link OAuth2AuthorizedClient}. The value is defaulted in + * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction#defaultRequest()} + * + * @param authentication the {@link Authentication} to use. + * @return the {@link Consumer} to populate the attributes + */ + public static Consumer> authentication(Authentication authentication) { + return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication); + } + + /** + * Modifies the {@link ClientRequest#attributes()} to include the {@link HttpServletRequest} used to + * look up and save the {@link OAuth2AuthorizedClient}. The value is defaulted in + * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction#defaultRequest()} + * + * @param request the {@link HttpServletRequest} to use. + * @return the {@link Consumer} to populate the attributes + */ + public static Consumer> httpServletRequest(HttpServletRequest request) { + return attributes -> attributes.put(HTTP_SERVLET_REQUEST_ATTR_NAME, request); + } + + /** + * Modifies the {@link ClientRequest#attributes()} to include the {@link HttpServletResponse} used to + * save the {@link OAuth2AuthorizedClient}. The value is defaulted in + * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction#defaultRequest()} + * + * @param response the {@link HttpServletResponse} to use. + * @return the {@link Consumer} to populate the attributes + */ + public static Consumer> httpServletResponse(HttpServletResponse response) { + return attributes -> attributes.put(HTTP_SERVLET_RESPONSE_ATTR_NAME, response); + } + + /** + * An access token will be considered expired by comparing its expiration to now + + * this skewed Duration. The default is 1 minute. + * @param accessTokenExpiresSkew the Duration to use. + */ + public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { + Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null"); + this.accessTokenExpiresSkew = accessTokenExpiresSkew; + } + + @Override + public Mono filter(ClientRequest request, ExchangeFunction next) { + Optional attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME) + .map(OAuth2AuthorizedClient.class::cast); + return Mono.justOrEmpty(attribute) + .flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient)) + .map(authorizedClient -> bearer(request, authorizedClient)) + .flatMap(next::exchange) + .switchIfEmpty(next.exchange(request)); + } + + private void populateDefaultRequestResponse(Map attrs) { + if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey( + HTTP_SERVLET_RESPONSE_ATTR_NAME)) { + return; + } + ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); + HttpServletRequest request = null; + HttpServletResponse response = null; + if (context != null) { + request = context.getRequest(); + response = context.getResponse(); + } + attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request); + attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response); + } + + private void populateDefaultAuthentication(Map attrs) { + if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) { + return; + } + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication); + } + + private void populateDefaultOAuth2AuthorizedClient(Map attrs) { + if (this.authorizedClientRepository == null || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) { + return; + } + + Authentication authentication = getAuthentication(attrs); + String clientRegistrationId = getClientRegistrationId(attrs); + if (clientRegistrationId == null && authentication instanceof OAuth2AuthenticationToken) { + clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId(); + } + if (clientRegistrationId != null) { + HttpServletRequest request = (HttpServletRequest) attrs.get( + HTTP_SERVLET_REQUEST_ATTR_NAME); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository + .loadAuthorizedClient(clientRegistrationId, authentication, + request); + oauth2AuthorizedClient(authorizedClient).accept(attrs); + } + } + + private Mono authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { + if (shouldRefresh(authorizedClient)) { + return refreshAuthorizedClient(request, next, authorizedClient); + } + return Mono.just(authorizedClient); + } + + private Mono refreshAuthorizedClient(ClientRequest request, ExchangeFunction next, + OAuth2AuthorizedClient authorizedClient) { + ClientRegistration clientRegistration = authorizedClient + .getClientRegistration(); + String tokenUri = clientRegistration + .getProviderDetails().getTokenUri(); + ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri)) + .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) + .headers(httpBasic(clientRegistration.getClientId(), clientRegistration.getClientSecret())) + .body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue())) + .build(); + return next.exchange(refreshRequest) + .flatMap(response -> response.body(oauth2AccessTokenResponse())) + .map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken())) + .map(result -> { + Authentication principal = (Authentication) request.attribute( + AUTHENTICATION_ATTR_NAME).orElse(new PrincipalNameAuthentication(authorizedClient.getPrincipalName())); + HttpServletRequest httpRequest = (HttpServletRequest) request.attributes().get( + HTTP_SERVLET_REQUEST_ATTR_NAME); + HttpServletResponse httpResponse = (HttpServletResponse) request.attributes().get( + HTTP_SERVLET_RESPONSE_ATTR_NAME); + this.authorizedClientRepository.saveAuthorizedClient(result, principal, httpRequest, httpResponse); + return result; + }) + .publishOn(Schedulers.elastic()); + } + + private static Consumer httpBasic(String username, String password) { + return httpHeaders -> { + String credentialsString = username + ":" + password; + byte[] credentialBytes = credentialsString.getBytes(StandardCharsets.ISO_8859_1); + byte[] encodedBytes = Base64.getEncoder().encode(credentialBytes); + String encodedCredentials = new String(encodedBytes, StandardCharsets.ISO_8859_1); + httpHeaders.set(HttpHeaders.AUTHORIZATION, "Basic " + encodedCredentials); + }; + } + + private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) { + if (this.authorizedClientRepository == null) { + return false; + } + OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken(); + if (refreshToken == null) { + return false; + } + Instant now = this.clock.instant(); + Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt(); + if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) { + return true; + } + return false; + } + + private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { + return ClientRequest.from(request) + .headers(bearerToken(authorizedClient.getAccessToken().getTokenValue())) + .build(); + } + + private static BodyInserters.FormInserter refreshTokenBody(String refreshToken) { + return BodyInserters + .fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue()) + .with("refresh_token", refreshToken); + } + + static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map attrs) { + return (OAuth2AuthorizedClient) attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME); + } + + static String getClientRegistrationId(Map attrs) { + return (String) attrs.get(CLIENT_REGISTRATION_ID_ATTR_NAME); + } + + static Authentication getAuthentication(Map attrs) { + return (Authentication) attrs.get(AUTHENTICATION_ATTR_NAME); + } + + static HttpServletRequest getRequest(Map attrs) { + return (HttpServletRequest) attrs.get(HTTP_SERVLET_REQUEST_ATTR_NAME); + } + + static HttpServletResponse getResponse(Map attrs) { + return (HttpServletResponse) attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME); + } + + private static class PrincipalNameAuthentication implements Authentication { + private final String username; + + private PrincipalNameAuthentication(String username) { + this.username = username; + } + + @Override + public Collection getAuthorities() { + throw unsupported(); + } + + @Override + public Object getCredentials() { + throw unsupported(); + } + + @Override + public Object getDetails() { + throw unsupported(); + } + + @Override + public Object getPrincipal() { + throw unsupported(); + } + + @Override + public boolean isAuthenticated() { + throw unsupported(); + } + + @Override + public void setAuthenticated(boolean isAuthenticated) + throws IllegalArgumentException { + throw unsupported(); + } + + @Override + public String getName() { + return this.username; + } + + private UnsupportedOperationException unsupported() { + return new UnsupportedOperationException("Not Supported"); + } + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java new file mode 100644 index 0000000000..a8e2b93dfc --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -0,0 +1,478 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.reactive.function.client; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.core.codec.ByteBufferEncoder; +import org.springframework.core.codec.CharSequenceEncoder; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.codec.EncoderHttpMessageWriter; +import org.springframework.http.codec.FormHttpMessageWriter; +import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.codec.ResourceHttpMessageWriter; +import org.springframework.http.codec.ServerSentEventHttpMessageWriter; +import org.springframework.http.codec.json.Jackson2JsonEncoder; +import org.springframework.http.codec.multipart.MultipartHttpMessageWriter; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.client.reactive.MockClientHttpRequest; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import org.springframework.web.reactive.function.BodyInserter; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import java.net.URI; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.http.HttpMethod.GET; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*; + +/** + * @author Rob Winch + * @since 5.1 + */ +@RunWith(MockitoJUnitRunner.class) +public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { + @Mock + private OAuth2AuthorizedClientRepository authorizedClientRepository; + @Mock + private WebClient.RequestHeadersSpec spec; + @Captor + private ArgumentCaptor>> attrs; + + /** + * Used for get the attributes from defaultRequest. + */ + private Map result = new HashMap<>(); + + private ServletOAuth2AuthorizedClientExchangeFilterFunction function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(); + + private MockExchangeFunction exchange = new MockExchangeFunction(); + + private Authentication authentication; + + private ClientRegistration github = ClientRegistration.withRegistrationId("github") + .redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .scope("read:user") + .authorizationUri("https://github.com/login/oauth/authorize") + .tokenUri("https://github.com/login/oauth/access_token") + .userInfoUri("https://api.github.com/user") + .userNameAttributeName("id") + .clientName("GitHub") + .clientId("clientId") + .clientSecret("clientSecret") + .build(); + + private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "token-0", + Instant.now(), + Instant.now().plus(Duration.ofDays(1))); + + @Before + public void setup() { + this.authentication = new TestingAuthenticationToken("test", "this"); + } + + @After + public void cleanup() { + SecurityContextHolder.clearContext(); + RequestContextHolder.resetRequestAttributes(); + } + + @Test + public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() { + Map attrs = getDefaultRequestAttributes(); + assertThat(getRequest(attrs)).isNull(); + assertThat(getResponse(attrs)).isNull(); + } + + @Test + public void defaultRequestRequestResponseWhenRequestContextThenRequestAndResponseSet() { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); + Map attrs = getDefaultRequestAttributes(); + assertThat(getRequest(attrs)).isEqualTo(request); + assertThat(getResponse(attrs)).isEqualTo(response); + } + + @Test + public void defaultRequestAuthenticationWhenSecurityContextEmptyThenAuthenticationNull() { + Map attrs = getDefaultRequestAttributes(); + assertThat(getAuthentication(attrs)).isNull(); + } + + @Test + public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet() { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + SecurityContextHolder.getContext().setAuthentication(this.authentication); + Map attrs = getDefaultRequestAttributes(); + assertThat(getAuthentication(attrs)).isEqualTo(this.authentication); + verifyZeroInteractions(this.authorizedClientRepository); + } + + @Test + public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride() { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken); + oauth2AuthorizedClient(authorizedClient).accept(this.result); + Map attrs = getDefaultRequestAttributes(); + assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient); + verifyZeroInteractions(this.authorizedClientRepository); + } + + @Test + public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + Map attrs = getDefaultRequestAttributes(); + assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); + verifyZeroInteractions(this.authorizedClientRepository); + } + + @Test + public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + Map attrs = getDefaultRequestAttributes(); + assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); + verifyZeroInteractions(this.authorizedClientRepository); + } + + @Test + public void defaultRequestOAuth2AuthorizedClientWhenRepositoryNullThenOAuth2AuthorizedClient() { + OAuth2User user = mock(OAuth2User.class); + List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); + authentication(token).accept(this.result); + + Map attrs = getDefaultRequestAttributes(); + + assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); + } + + @Test + public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + OAuth2User user = mock(OAuth2User.class); + List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken); + when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); + authentication(token).accept(this.result); + + Map attrs = getDefaultRequestAttributes(); + + assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient); + verify(this.authorizedClientRepository).loadAuthorizedClient(eq(token.getAuthorizedClientRegistrationId()), any(), any()); + } + + @Test + public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + OAuth2User user = mock(OAuth2User.class); + List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken); + when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); + authentication(token).accept(this.result); + clientRegistrationId("explicit").accept(this.result); + + Map attrs = getDefaultRequestAttributes(); + + assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient); + verify(this.authorizedClientRepository).loadAuthorizedClient(eq("explicit"), any(), any()); + } + + @Test + public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + OAuth2User user = mock(OAuth2User.class); + List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken); + when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); + clientRegistrationId("id").accept(this.result); + + Map attrs = getDefaultRequestAttributes(); + + assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient); + verify(this.authorizedClientRepository).loadAuthorizedClient(eq("id"), any(), any()); + } + + private Map getDefaultRequestAttributes() { + this.function.defaultRequest().accept(this.spec); + verify(this.spec).attributes(this.attrs.capture()); + + this.attrs.getValue().accept(this.result); + + return this.result; + } + + @Test + public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() { + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .build(); + + this.function.filter(request, this.exchange).block(); + + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); + } + + @Test + public void filterWhenAuthorizedClientThenAuthorizationHeader() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + this.function.filter(request, this.exchange).block(); + + assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue()); + } + + @Test + public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .header(HttpHeaders.AUTHORIZATION, "Existing") + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + this.function.filter(request, this.exchange).block(); + + HttpHeaders headers = this.exchange.getRequest().headers(); + assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); + } + + @Test + public void filterWhenRefreshRequiredThenRefresh() { + OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .refreshToken("refresh-1") + .build(); + when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); + Instant refreshTokenExpiresAt = Instant.now().plus(Duration.ofHours(1)); + + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + issuedAt, + accessTokenExpiresAt); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(authentication(this.authentication)) + .build(); + + this.function.filter(request, this.exchange).block(); + + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(2); + + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50SWQ6Y2xpZW50U2VjcmV0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://github.com/login/oauth/access_token"); + assertThat(request0.method()).isEqualTo(HttpMethod.POST); + assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); + + ClientRequest request1 = requests.get(1); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + + @Test + public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { + OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1") + .tokenType(OAuth2AccessToken.TokenType.BEARER) + .expiresIn(3600) + .refreshToken("refresh-1") + .build(); + when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); + Instant refreshTokenExpiresAt = Instant.now().plus(Duration.ofHours(1)); + + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + issuedAt, + accessTokenExpiresAt); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + this.function.filter(request, this.exchange) + .block(); + + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any(), any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(2); + + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50SWQ6Y2xpZW50U2VjcmV0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://github.com/login/oauth/access_token"); + assertThat(request0.method()).isEqualTo(HttpMethod.POST); + assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); + + ClientRequest request1 = requests.get(1); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + + @Test + public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + this.function.filter(request, this.exchange).block(); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); + } + + @Test + public void filterWhenNotExpiredThenShouldRefreshFalse() { + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + this.function.filter(request, this.exchange).block(); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); + } + + private static String getBody(ClientRequest request) { + final List> messageWriters = new ArrayList<>(); + messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder())); + messageWriters.add(new EncoderHttpMessageWriter<>(CharSequenceEncoder.textPlainOnly())); + messageWriters.add(new ResourceHttpMessageWriter()); + Jackson2JsonEncoder jsonEncoder = new Jackson2JsonEncoder(); + messageWriters.add(new EncoderHttpMessageWriter<>(jsonEncoder)); + messageWriters.add(new ServerSentEventHttpMessageWriter(jsonEncoder)); + messageWriters.add(new FormHttpMessageWriter()); + messageWriters.add(new EncoderHttpMessageWriter<>(CharSequenceEncoder.allMimeTypes())); + messageWriters.add(new MultipartHttpMessageWriter(messageWriters)); + + BodyInserter.Context context = new BodyInserter.Context() { + @Override + public List> messageWriters() { + return messageWriters; + } + + @Override + public Optional serverRequest() { + return Optional.empty(); + } + + @Override + public Map hints() { + return new HashMap<>(); + } + }; + + MockClientHttpRequest body = new MockClientHttpRequest(HttpMethod.GET, "/"); + request.body().insert(body, context).block(); + return body.getBodyAsString().block(); + } + +}