ServletOAuth2AuthorizedClientExchangeFilterFunction supports chaining
Fixes gh-6483
This commit is contained in:
		
							parent
							
								
									0c2a7e03f7
								
							
						
					
					
						commit
						0c27f64338
					
				| 
						 | 
				
			
			@ -16,6 +16,9 @@
 | 
			
		|||
 | 
			
		||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
 | 
			
		||||
 | 
			
		||||
import org.reactivestreams.Subscription;
 | 
			
		||||
import org.springframework.beans.factory.DisposableBean;
 | 
			
		||||
import org.springframework.beans.factory.InitializingBean;
 | 
			
		||||
import org.springframework.http.HttpHeaders;
 | 
			
		||||
import org.springframework.http.HttpMethod;
 | 
			
		||||
import org.springframework.http.MediaType;
 | 
			
		||||
| 
						 | 
				
			
			@ -44,8 +47,12 @@ import org.springframework.web.reactive.function.client.ClientResponse;
 | 
			
		|||
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
 | 
			
		||||
import org.springframework.web.reactive.function.client.ExchangeFunction;
 | 
			
		||||
import org.springframework.web.reactive.function.client.WebClient;
 | 
			
		||||
import reactor.core.CoreSubscriber;
 | 
			
		||||
import reactor.core.publisher.Hooks;
 | 
			
		||||
import reactor.core.publisher.Mono;
 | 
			
		||||
import reactor.core.publisher.Operators;
 | 
			
		||||
import reactor.core.scheduler.Schedulers;
 | 
			
		||||
import reactor.util.context.Context;
 | 
			
		||||
 | 
			
		||||
import javax.servlet.http.HttpServletRequest;
 | 
			
		||||
import javax.servlet.http.HttpServletResponse;
 | 
			
		||||
| 
						 | 
				
			
			@ -98,7 +105,9 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
 | 
			
		|||
 * @author Rob Winch
 | 
			
		||||
 * @since 5.1
 | 
			
		||||
 */
 | 
			
		||||
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
 | 
			
		||||
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
 | 
			
		||||
		implements ExchangeFilterFunction, InitializingBean, DisposableBean {
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
	 * The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
 | 
			
		||||
	 */
 | 
			
		||||
| 
						 | 
				
			
			@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 | 
			
		|||
	private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
 | 
			
		||||
	private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
 | 
			
		||||
 | 
			
		||||
	private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
 | 
			
		||||
 | 
			
		||||
	private Clock clock = Clock.systemUTC();
 | 
			
		||||
 | 
			
		||||
	private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
 | 
			
		||||
| 
						 | 
				
			
			@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 | 
			
		|||
 | 
			
		||||
	private String defaultClientRegistrationId;
 | 
			
		||||
 | 
			
		||||
	public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
 | 
			
		||||
	public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	public ServletOAuth2AuthorizedClientExchangeFilterFunction(
 | 
			
		||||
			ClientRegistrationRepository clientRegistrationRepository,
 | 
			
		||||
| 
						 | 
				
			
			@ -132,6 +144,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 | 
			
		|||
		this.authorizedClientRepository = authorizedClientRepository;
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Override
 | 
			
		||||
	public void afterPropertiesSet() throws Exception {
 | 
			
		||||
		Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Override
 | 
			
		||||
	public void destroy() throws Exception {
 | 
			
		||||
		Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	/**
 | 
			
		||||
	 * Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
 | 
			
		||||
	 * client_credentials grant.
 | 
			
		||||
| 
						 | 
				
			
			@ -266,15 +288,36 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 | 
			
		|||
 | 
			
		||||
	@Override
 | 
			
		||||
	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
 | 
			
		||||
		Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
 | 
			
		||||
				.map(OAuth2AuthorizedClient.class::cast);
 | 
			
		||||
		return Mono.justOrEmpty(attribute)
 | 
			
		||||
				.flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient))
 | 
			
		||||
		return Mono.just(request)
 | 
			
		||||
				.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
 | 
			
		||||
				.switchIfEmpty(mergeRequestAttributesFromContext(request))
 | 
			
		||||
				.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
 | 
			
		||||
				.flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes())))
 | 
			
		||||
				.map(authorizedClient -> bearer(request, authorizedClient))
 | 
			
		||||
				.flatMap(next::exchange)
 | 
			
		||||
				.switchIfEmpty(next.exchange(request));
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
 | 
			
		||||
		return Mono.just(ClientRequest.from(request))
 | 
			
		||||
				.flatMap(builder -> Mono.subscriberContext()
 | 
			
		||||
						.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))))
 | 
			
		||||
				.map(ClientRequest.Builder::build);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
 | 
			
		||||
		if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
 | 
			
		||||
			attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
 | 
			
		||||
		}
 | 
			
		||||
		if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
 | 
			
		||||
			attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
 | 
			
		||||
		}
 | 
			
		||||
		if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
 | 
			
		||||
			attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
 | 
			
		||||
		}
 | 
			
		||||
		populateDefaultOAuth2AuthorizedClient(attrs);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private void populateDefaultRequestResponse(Map<String, Object> attrs) {
 | 
			
		||||
		if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(
 | 
			
		||||
				HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
 | 
			
		||||
| 
						 | 
				
			
			@ -435,6 +478,19 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 | 
			
		|||
					.build();
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
 | 
			
		||||
		HttpServletRequest request = null;
 | 
			
		||||
		HttpServletResponse response = null;
 | 
			
		||||
		ServletRequestAttributes requestAttributes =
 | 
			
		||||
				(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
 | 
			
		||||
		if (requestAttributes != null) {
 | 
			
		||||
			request = requestAttributes.getRequest();
 | 
			
		||||
			response = requestAttributes.getResponse();
 | 
			
		||||
		}
 | 
			
		||||
		Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
 | 
			
		||||
		return new RequestContextSubscriber<>(delegate, request, response, authentication);
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
 | 
			
		||||
		return BodyInserters
 | 
			
		||||
				.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
 | 
			
		||||
| 
						 | 
				
			
			@ -508,4 +564,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 | 
			
		|||
			return new UnsupportedOperationException("Not Supported");
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
 | 
			
		||||
		private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
 | 
			
		||||
		private final CoreSubscriber<T> delegate;
 | 
			
		||||
		private final HttpServletRequest request;
 | 
			
		||||
		private final HttpServletResponse response;
 | 
			
		||||
		private final Authentication authentication;
 | 
			
		||||
 | 
			
		||||
		private RequestContextSubscriber(CoreSubscriber<T> delegate,
 | 
			
		||||
											HttpServletRequest request,
 | 
			
		||||
											HttpServletResponse response,
 | 
			
		||||
											Authentication authentication) {
 | 
			
		||||
			this.delegate = delegate;
 | 
			
		||||
			this.request = request;
 | 
			
		||||
			this.response = response;
 | 
			
		||||
			this.authentication = authentication;
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		@Override
 | 
			
		||||
		public Context currentContext() {
 | 
			
		||||
			Context context = this.delegate.currentContext();
 | 
			
		||||
			if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
 | 
			
		||||
				return context;
 | 
			
		||||
			}
 | 
			
		||||
			return Context.of(
 | 
			
		||||
					CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
 | 
			
		||||
					HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
 | 
			
		||||
					HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
 | 
			
		||||
					AUTHENTICATION_ATTR_NAME, this.authentication);
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		@Override
 | 
			
		||||
		public void onSubscribe(Subscription s) {
 | 
			
		||||
			this.delegate.onSubscribe(s);
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		@Override
 | 
			
		||||
		public void onNext(T t) {
 | 
			
		||||
			this.delegate.onNext(t);
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		@Override
 | 
			
		||||
		public void onError(Throwable t) {
 | 
			
		||||
			this.delegate.onError(t);
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		@Override
 | 
			
		||||
		public void onComplete() {
 | 
			
		||||
			this.delegate.onComplete();
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -74,14 +74,11 @@ import java.util.Map;
 | 
			
		|||
import java.util.Optional;
 | 
			
		||||
import java.util.function.Consumer;
 | 
			
		||||
 | 
			
		||||
import static org.assertj.core.api.Assertions.*;
 | 
			
		||||
import static org.assertj.core.api.Assertions.assertThat;
 | 
			
		||||
import static org.assertj.core.api.Assertions.assertThatCode;
 | 
			
		||||
import static org.mockito.ArgumentMatchers.any;
 | 
			
		||||
import static org.mockito.ArgumentMatchers.eq;
 | 
			
		||||
import static org.mockito.Mockito.mock;
 | 
			
		||||
import static org.mockito.Mockito.never;
 | 
			
		||||
import static org.mockito.Mockito.verify;
 | 
			
		||||
import static org.mockito.Mockito.verifyZeroInteractions;
 | 
			
		||||
import static org.mockito.Mockito.when;
 | 
			
		||||
import static org.mockito.Mockito.*;
 | 
			
		||||
import static org.springframework.http.HttpMethod.GET;
 | 
			
		||||
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -647,6 +644,121 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 | 
			
		|||
		assertThat(getBody(request0)).isEmpty();
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// gh-6483
 | 
			
		||||
	@Test
 | 
			
		||||
	public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
 | 
			
		||||
		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
 | 
			
		||||
				this.clientRegistrationRepository, this.authorizedClientRepository);
 | 
			
		||||
		this.function.afterPropertiesSet();			// Hooks.onLastOperator() initialized
 | 
			
		||||
		this.function.setDefaultOAuth2AuthorizedClient(true);
 | 
			
		||||
 | 
			
		||||
		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
 | 
			
		||||
		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
 | 
			
		||||
		RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
 | 
			
		||||
 | 
			
		||||
		OAuth2User user = mock(OAuth2User.class);
 | 
			
		||||
		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
 | 
			
		||||
		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
 | 
			
		||||
				user, authorities, this.registration.getRegistrationId());
 | 
			
		||||
		SecurityContextHolder.getContext().setAuthentication(authentication);
 | 
			
		||||
 | 
			
		||||
		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
 | 
			
		||||
				this.registration, "principalName", this.accessToken);
 | 
			
		||||
		when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
 | 
			
		||||
				eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
 | 
			
		||||
 | 
			
		||||
		// Default request attributes set
 | 
			
		||||
		final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
 | 
			
		||||
				.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
 | 
			
		||||
 | 
			
		||||
		// Default request attributes NOT set
 | 
			
		||||
		final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build();
 | 
			
		||||
 | 
			
		||||
		this.function.filter(request1, this.exchange)
 | 
			
		||||
				.flatMap(response -> this.function.filter(request2, this.exchange))
 | 
			
		||||
				.block();
 | 
			
		||||
 | 
			
		||||
		this.function.destroy();		// Hooks.onLastOperator() released
 | 
			
		||||
 | 
			
		||||
		List<ClientRequest> requests = this.exchange.getRequests();
 | 
			
		||||
		assertThat(requests).hasSize(2);
 | 
			
		||||
 | 
			
		||||
		ClientRequest request = requests.get(0);
 | 
			
		||||
		assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
 | 
			
		||||
		assertThat(request.url().toASCIIString()).isEqualTo("https://example1.com");
 | 
			
		||||
		assertThat(request.method()).isEqualTo(HttpMethod.GET);
 | 
			
		||||
		assertThat(getBody(request)).isEmpty();
 | 
			
		||||
 | 
			
		||||
		request = requests.get(1);
 | 
			
		||||
		assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
 | 
			
		||||
		assertThat(request.url().toASCIIString()).isEqualTo("https://example2.com");
 | 
			
		||||
		assertThat(request.method()).isEqualTo(HttpMethod.GET);
 | 
			
		||||
		assertThat(getBody(request)).isEmpty();
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() throws Exception {
 | 
			
		||||
		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
 | 
			
		||||
				this.clientRegistrationRepository, this.authorizedClientRepository);
 | 
			
		||||
//		this.function.afterPropertiesSet();		// Hooks.onLastOperator() NOT initialized
 | 
			
		||||
		this.function.setDefaultOAuth2AuthorizedClient(true);
 | 
			
		||||
 | 
			
		||||
		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
 | 
			
		||||
		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
 | 
			
		||||
		RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
 | 
			
		||||
 | 
			
		||||
		OAuth2User user = mock(OAuth2User.class);
 | 
			
		||||
		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
 | 
			
		||||
		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
 | 
			
		||||
				user, authorities, this.registration.getRegistrationId());
 | 
			
		||||
		SecurityContextHolder.getContext().setAuthentication(authentication);
 | 
			
		||||
 | 
			
		||||
		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
 | 
			
		||||
 | 
			
		||||
		this.function.filter(request, this.exchange).block();
 | 
			
		||||
 | 
			
		||||
		List<ClientRequest> requests = this.exchange.getRequests();
 | 
			
		||||
		assertThat(requests).hasSize(1);
 | 
			
		||||
 | 
			
		||||
		request = requests.get(0);
 | 
			
		||||
		assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
 | 
			
		||||
		assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
 | 
			
		||||
		assertThat(request.method()).isEqualTo(HttpMethod.GET);
 | 
			
		||||
		assertThat(getBody(request)).isEmpty();
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	@Test
 | 
			
		||||
	public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception {
 | 
			
		||||
		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
 | 
			
		||||
				this.clientRegistrationRepository, this.authorizedClientRepository);
 | 
			
		||||
		this.function.afterPropertiesSet();			// Hooks.onLastOperator() initialized
 | 
			
		||||
		this.function.destroy();					// Hooks.onLastOperator() released
 | 
			
		||||
		this.function.setDefaultOAuth2AuthorizedClient(true);
 | 
			
		||||
 | 
			
		||||
		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
 | 
			
		||||
		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
 | 
			
		||||
		RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
 | 
			
		||||
 | 
			
		||||
		OAuth2User user = mock(OAuth2User.class);
 | 
			
		||||
		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
 | 
			
		||||
		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
 | 
			
		||||
				user, authorities, this.registration.getRegistrationId());
 | 
			
		||||
		SecurityContextHolder.getContext().setAuthentication(authentication);
 | 
			
		||||
 | 
			
		||||
		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
 | 
			
		||||
 | 
			
		||||
		this.function.filter(request, this.exchange).block();
 | 
			
		||||
 | 
			
		||||
		List<ClientRequest> requests = this.exchange.getRequests();
 | 
			
		||||
		assertThat(requests).hasSize(1);
 | 
			
		||||
 | 
			
		||||
		request = requests.get(0);
 | 
			
		||||
		assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
 | 
			
		||||
		assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
 | 
			
		||||
		assertThat(request.method()).isEqualTo(HttpMethod.GET);
 | 
			
		||||
		assertThat(getBody(request)).isEmpty();
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private static String getBody(ClientRequest request) {
 | 
			
		||||
		final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
 | 
			
		||||
		messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue