Add Servlet and ServerBearerExchangeFilterFunction
Fixes gh-5334 Fixes gh-7284
This commit is contained in:
		
							parent
							
								
									dbd1819ea4
								
							
						
					
					
						commit
						f350988285
					
				|  | @ -0,0 +1,248 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2002-2019 the original author or authors. | ||||||
|  |  * | ||||||
|  |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |  * you may not use this file except in compliance with the License. | ||||||
|  |  * You may obtain a copy of the License at | ||||||
|  |  * | ||||||
|  |  *      https://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |  * See the License for the specific language governing permissions and | ||||||
|  |  * limitations under the License. | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.springframework.security.oauth2.server.resource.web; | ||||||
|  | 
 | ||||||
|  | import java.util.Map; | ||||||
|  | import java.util.function.Consumer; | ||||||
|  | 
 | ||||||
|  | import org.reactivestreams.Subscription; | ||||||
|  | import reactor.core.CoreSubscriber; | ||||||
|  | import reactor.core.publisher.Hooks; | ||||||
|  | import reactor.core.publisher.Mono; | ||||||
|  | import reactor.core.publisher.Operators; | ||||||
|  | import reactor.util.context.Context; | ||||||
|  | 
 | ||||||
|  | import org.springframework.beans.factory.DisposableBean; | ||||||
|  | import org.springframework.beans.factory.InitializingBean; | ||||||
|  | import org.springframework.lang.Nullable; | ||||||
|  | import org.springframework.security.core.Authentication; | ||||||
|  | import org.springframework.security.core.context.SecurityContextHolder; | ||||||
|  | import org.springframework.security.oauth2.core.AbstractOAuth2Token; | ||||||
|  | import org.springframework.web.reactive.function.client.ClientRequest; | ||||||
|  | import org.springframework.web.reactive.function.client.ClientResponse; | ||||||
|  | import org.springframework.web.reactive.function.client.ExchangeFilterFunction; | ||||||
|  | import org.springframework.web.reactive.function.client.ExchangeFunction; | ||||||
|  | import org.springframework.web.reactive.function.client.WebClient; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * An {@link ExchangeFilterFunction} that adds the | ||||||
|  |  * <a href="https://tools.ietf.org/html/rfc6750#section-1.2" target="_blank">Bearer Token</a> | ||||||
|  |  * from an existing {@link AbstractOAuth2Token} tied to the current {@link Authentication}. | ||||||
|  |  * | ||||||
|  |  * Suitable for Servlet applications, applying it to a typical {@link org.springframework.web.reactive.function.client.WebClient} | ||||||
|  |  * configuration: | ||||||
|  |  * | ||||||
|  |  * <pre> | ||||||
|  |  *  @Bean | ||||||
|  |  *  WebClient webClient() { | ||||||
|  |  *      ServletBearerExchangeFilterFunction bearer = new ServletBearerExchangeFilterFunction(); | ||||||
|  |  *      return WebClient.builder() | ||||||
|  |  *              .apply(bearer.oauth2Configuration()) | ||||||
|  |  *              .build(); | ||||||
|  |  *  } | ||||||
|  |  * </pre> | ||||||
|  |  * | ||||||
|  |  * @author Josh Cummings | ||||||
|  |  * @since 5.2 | ||||||
|  |  */ | ||||||
|  | public class ServletBearerExchangeFilterFunction | ||||||
|  | 		implements ExchangeFilterFunction, InitializingBean, DisposableBean { | ||||||
|  | 
 | ||||||
|  | 	private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName(); | ||||||
|  | 
 | ||||||
|  | 	private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName(); | ||||||
|  | 
 | ||||||
|  | 	/** | ||||||
|  | 	 * {@inheritDoc} | ||||||
|  | 	 */ | ||||||
|  | 	@Override | ||||||
|  | 	public void afterPropertiesSet() throws Exception { | ||||||
|  | 		Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, | ||||||
|  | 				Operators.liftPublisher((s, sub) -> createRequestContextSubscriber(sub))); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	/** | ||||||
|  | 	 * {@inheritDoc} | ||||||
|  | 	 */ | ||||||
|  | 	@Override | ||||||
|  | 	public void destroy() throws Exception { | ||||||
|  | 		Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	/** | ||||||
|  | 	 * Configures the builder with {@link #defaultRequest()} and adds this as a {@link ExchangeFilterFunction} | ||||||
|  | 	 * @return the {@link Consumer} to configure the builder | ||||||
|  | 	 */ | ||||||
|  | 	public Consumer<WebClient.Builder> oauth2Configuration() { | ||||||
|  | 		return builder -> builder.defaultRequest(defaultRequest()).filter(this); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	/** | ||||||
|  | 	 * Provides defaults for the {@link Authentication} using | ||||||
|  | 	 * {@link SecurityContextHolder}. It also can default the {@link AbstractOAuth2Token} using the | ||||||
|  | 	 * {@link #authentication(Authentication)}. | ||||||
|  | 	 * @return the {@link Consumer} to populate the attributes | ||||||
|  | 	 */ | ||||||
|  | 	public Consumer<WebClient.RequestHeadersSpec<?>> defaultRequest() { | ||||||
|  | 		return spec -> spec.attributes(attrs -> { | ||||||
|  | 			populateDefaultAuthentication(attrs); | ||||||
|  | 		}); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	/** | ||||||
|  | 	 * Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} used to | ||||||
|  | 	 * look up and save the {@link AbstractOAuth2Token}. The value is defaulted in | ||||||
|  | 	 * {@link ServletBearerExchangeFilterFunction#defaultRequest()} | ||||||
|  | 	 * | ||||||
|  | 	 * @param authentication the {@link Authentication} to use. | ||||||
|  | 	 * @return the {@link Consumer} to populate the attributes | ||||||
|  | 	 */ | ||||||
|  | 	public static Consumer<Map<String, Object>> authentication(Authentication authentication) { | ||||||
|  | 		return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	/** | ||||||
|  | 	 * {@inheritDoc} | ||||||
|  | 	 */ | ||||||
|  | 	@Override | ||||||
|  | 	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) { | ||||||
|  | 		return mergeRequestAttributesIfNecessary(request) | ||||||
|  | 				.filter(req -> req.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) | ||||||
|  | 				.map(req -> getOAuth2Token(req.attributes())) | ||||||
|  | 				.map(token -> bearer(request, token)) | ||||||
|  | 				.flatMap(next::exchange) | ||||||
|  | 				.switchIfEmpty(Mono.defer(() -> next.exchange(request))); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private Mono<ClientRequest> mergeRequestAttributesIfNecessary(ClientRequest request) { | ||||||
|  | 		if (request.attribute(AUTHENTICATION_ATTR_NAME).isPresent()) { | ||||||
|  | 			return Mono.just(request); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return mergeRequestAttributesFromContext(request); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) { | ||||||
|  | 		ClientRequest.Builder builder = ClientRequest.from(request); | ||||||
|  | 		return Mono.subscriberContext() | ||||||
|  | 				.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))) | ||||||
|  | 				.map(ClientRequest.Builder::build); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) { | ||||||
|  | 		RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx); | ||||||
|  | 		if (holder == null) { | ||||||
|  | 			return; | ||||||
|  | 		} | ||||||
|  | 		if (holder.getAuthentication() != null) { | ||||||
|  | 			attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, holder.getAuthentication()); | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private AbstractOAuth2Token getOAuth2Token(Map<String, Object> attrs) { | ||||||
|  | 		Authentication authentication = (Authentication) attrs.get(AUTHENTICATION_ATTR_NAME); | ||||||
|  | 		if (authentication.getCredentials() instanceof AbstractOAuth2Token) { | ||||||
|  | 			return (AbstractOAuth2Token) authentication.getCredentials(); | ||||||
|  | 		} | ||||||
|  | 		return null; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) { | ||||||
|  | 		return ClientRequest.from(request) | ||||||
|  | 				.headers(headers -> headers.setBearerAuth(token.getTokenValue())) | ||||||
|  | 				.build(); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) { | ||||||
|  | 		Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); | ||||||
|  | 		return new RequestContextSubscriber<>(delegate, authentication); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private void populateDefaultAuthentication(Map<String, Object> attrs) { | ||||||
|  | 		if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) { | ||||||
|  | 			return; | ||||||
|  | 		} | ||||||
|  | 		Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); | ||||||
|  | 		attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private static class RequestContextDataHolder { | ||||||
|  | 		private final Authentication authentication; | ||||||
|  | 
 | ||||||
|  | 		RequestContextDataHolder(Authentication authentication) { | ||||||
|  | 			this.authentication = authentication; | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		public Authentication getAuthentication() { | ||||||
|  | 			return this.authentication; | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private static class RequestContextSubscriber<T> implements CoreSubscriber<T> { | ||||||
|  | 		private static final String REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME = | ||||||
|  | 				RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER"); | ||||||
|  | 
 | ||||||
|  | 		private CoreSubscriber<T> delegate; | ||||||
|  | 		private final Context context; | ||||||
|  | 
 | ||||||
|  | 		private RequestContextSubscriber(CoreSubscriber<T> delegate, | ||||||
|  | 				Authentication authentication) { | ||||||
|  | 
 | ||||||
|  | 			this.delegate = delegate; | ||||||
|  | 			Context parentContext = this.delegate.currentContext(); | ||||||
|  | 			Context context; | ||||||
|  | 			if (authentication == null || parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME)) { | ||||||
|  | 				context = parentContext; | ||||||
|  | 			} else { | ||||||
|  | 				context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME, | ||||||
|  | 						new RequestContextDataHolder(authentication)); | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			this.context = context; | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		@Nullable | ||||||
|  | 		static RequestContextDataHolder getRequestContext(Context ctx) { | ||||||
|  | 			return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER_ATTR_NAME, null); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		@Override | ||||||
|  | 		public Context currentContext() { | ||||||
|  | 			return this.context; | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		@Override | ||||||
|  | 		public void onSubscribe(Subscription s) { | ||||||
|  | 			this.delegate.onSubscribe(s); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		@Override | ||||||
|  | 		public void onNext(T t) { | ||||||
|  | 			this.delegate.onNext(t); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		@Override | ||||||
|  | 		public void onError(Throwable t) { | ||||||
|  | 			this.delegate.onError(t); | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		@Override | ||||||
|  | 		public void onComplete() { | ||||||
|  | 			this.delegate.onComplete(); | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,117 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2002-2019 the original author or authors. | ||||||
|  |  * | ||||||
|  |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |  * you may not use this file except in compliance with the License. | ||||||
|  |  * You may obtain a copy of the License at | ||||||
|  |  * | ||||||
|  |  *      https://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |  * See the License for the specific language governing permissions and | ||||||
|  |  * limitations under the License. | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.springframework.security.oauth2.server.resource.web.server; | ||||||
|  | 
 | ||||||
|  | import java.util.Map; | ||||||
|  | import java.util.function.Consumer; | ||||||
|  | 
 | ||||||
|  | import reactor.core.publisher.Mono; | ||||||
|  | 
 | ||||||
|  | import org.springframework.security.authentication.AnonymousAuthenticationToken; | ||||||
|  | import org.springframework.security.core.Authentication; | ||||||
|  | import org.springframework.security.core.authority.AuthorityUtils; | ||||||
|  | import org.springframework.security.core.context.ReactiveSecurityContextHolder; | ||||||
|  | import org.springframework.security.core.context.SecurityContext; | ||||||
|  | import org.springframework.security.oauth2.core.AbstractOAuth2Token; | ||||||
|  | import org.springframework.web.reactive.function.client.ClientRequest; | ||||||
|  | import org.springframework.web.reactive.function.client.ClientResponse; | ||||||
|  | import org.springframework.web.reactive.function.client.ExchangeFilterFunction; | ||||||
|  | import org.springframework.web.reactive.function.client.ExchangeFunction; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * An {@link ExchangeFilterFunction} that adds the | ||||||
|  |  * <a href="https://tools.ietf.org/html/rfc6750#section-1.2" target="_blank">Bearer Token</a> | ||||||
|  |  * from an existing {@link AbstractOAuth2Token} tied to the current {@link Authentication}. | ||||||
|  |  * | ||||||
|  |  * Suitable for Reactive applications, applying it to a typical {@link org.springframework.web.reactive.function.client.WebClient} | ||||||
|  |  * configuration: | ||||||
|  |  * | ||||||
|  |  * <pre> | ||||||
|  |  *  @Bean | ||||||
|  |  *  WebClient webClient() { | ||||||
|  |  *      ServerBearerExchangeFilterFunction bearer = new ServerBearerExchangeFilterFunction(); | ||||||
|  |  *      return WebClient.builder() | ||||||
|  |  *              .filter(bearer).build(); | ||||||
|  |  *  } | ||||||
|  |  * </pre> | ||||||
|  |  * | ||||||
|  |  * @author Josh Cummings | ||||||
|  |  * @since 5.2 | ||||||
|  |  */ | ||||||
|  | public class ServerBearerExchangeFilterFunction | ||||||
|  | 		implements ExchangeFilterFunction { | ||||||
|  | 
 | ||||||
|  | 	private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName(); | ||||||
|  | 
 | ||||||
|  | 	private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", | ||||||
|  | 			AuthorityUtils.createAuthorityList("ROLE_USER")); | ||||||
|  | 
 | ||||||
|  | 	/** | ||||||
|  | 	 * Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} to be used for | ||||||
|  | 	 * providing the Bearer Token. Example usage: | ||||||
|  | 	 * | ||||||
|  | 	 * <pre> | ||||||
|  | 	 * WebClient webClient = WebClient.builder() | ||||||
|  | 	 *    .filter(new ServerBearerExchangeFilterFunction()) | ||||||
|  | 	 *    .build(); | ||||||
|  | 	 * Mono<String> response = webClient | ||||||
|  | 	 *    .get() | ||||||
|  | 	 *    .uri(uri) | ||||||
|  | 	 *    .attributes(authentication(authentication)) | ||||||
|  | 	 *    // ... | ||||||
|  | 	 *    .retrieve() | ||||||
|  | 	 *    .bodyToMono(String.class); | ||||||
|  | 	 * </pre> | ||||||
|  | 	 * @param authentication the {@link Authentication} to use | ||||||
|  | 	 * @return the {@link Consumer} to populate the client request attributes | ||||||
|  | 	 */ | ||||||
|  | 	public static Consumer<Map<String, Object>> authentication(Authentication authentication) { | ||||||
|  | 		return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	/** | ||||||
|  | 	 * {@inheritDoc} | ||||||
|  | 	 */ | ||||||
|  | 	@Override | ||||||
|  | 	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) { | ||||||
|  | 		return oauth2Token(request.attributes()) | ||||||
|  | 				.map(oauth2Token -> bearer(request, oauth2Token)) | ||||||
|  | 				.defaultIfEmpty(request) | ||||||
|  | 				.flatMap(next::exchange); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private Mono<AbstractOAuth2Token> oauth2Token(Map<String, Object> attrs) { | ||||||
|  | 		return Mono.justOrEmpty(attrs.get(AUTHENTICATION_ATTR_NAME)) | ||||||
|  | 				.cast(Authentication.class) | ||||||
|  | 				.switchIfEmpty(currentAuthentication()) | ||||||
|  | 				.filter(authentication -> authentication.getCredentials() instanceof AbstractOAuth2Token) | ||||||
|  | 				.map(Authentication::getCredentials) | ||||||
|  | 				.cast(AbstractOAuth2Token.class); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private Mono<Authentication> currentAuthentication() { | ||||||
|  | 		return ReactiveSecurityContextHolder.getContext() | ||||||
|  | 				.map(SecurityContext::getAuthentication) | ||||||
|  | 				.defaultIfEmpty(ANONYMOUS_USER_TOKEN); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	private ClientRequest bearer(ClientRequest request, AbstractOAuth2Token token) { | ||||||
|  | 		return ClientRequest.from(request) | ||||||
|  | 				.headers(headers -> headers.setBearerAuth(token.getTokenValue())) | ||||||
|  | 				.build(); | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,58 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2002-2019 the original author or authors. | ||||||
|  |  * | ||||||
|  |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |  * you may not use this file except in compliance with the License. | ||||||
|  |  * You may obtain a copy of the License at | ||||||
|  |  * | ||||||
|  |  *      https://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |  * See the License for the specific language governing permissions and | ||||||
|  |  * limitations under the License. | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.springframework.security.oauth2.server.resource.web; | ||||||
|  | 
 | ||||||
|  | import java.util.ArrayList; | ||||||
|  | import java.util.List; | ||||||
|  | 
 | ||||||
|  | import reactor.core.publisher.Mono; | ||||||
|  | 
 | ||||||
|  | import org.springframework.web.reactive.function.client.ClientRequest; | ||||||
|  | import org.springframework.web.reactive.function.client.ClientResponse; | ||||||
|  | import org.springframework.web.reactive.function.client.ExchangeFunction; | ||||||
|  | 
 | ||||||
|  | import static org.mockito.Mockito.mock; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * @author Rob Winch | ||||||
|  |  * @since 5.1 | ||||||
|  |  */ | ||||||
|  | public class MockExchangeFunction implements ExchangeFunction { | ||||||
|  | 	private List<ClientRequest> requests = new ArrayList<>(); | ||||||
|  | 
 | ||||||
|  | 	private ClientResponse response = mock(ClientResponse.class); | ||||||
|  | 
 | ||||||
|  | 	public ClientRequest getRequest() { | ||||||
|  | 		return this.requests.get(this.requests.size() - 1); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	public List<ClientRequest> getRequests() { | ||||||
|  | 		return this.requests; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	public ClientResponse getResponse() { | ||||||
|  | 		return this.response; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	@Override | ||||||
|  | 	public Mono<ClientResponse> exchange(ClientRequest request) { | ||||||
|  | 		return Mono.defer(() -> { | ||||||
|  | 			this.requests.add(request); | ||||||
|  | 			return Mono.just(this.response); | ||||||
|  | 		}); | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,116 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2002-2019 the original author or authors. | ||||||
|  |  * | ||||||
|  |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |  * you may not use this file except in compliance with the License. | ||||||
|  |  * You may obtain a copy of the License at | ||||||
|  |  * | ||||||
|  |  *      https://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |  * See the License for the specific language governing permissions and | ||||||
|  |  * limitations under the License. | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.springframework.security.oauth2.server.resource.web; | ||||||
|  | 
 | ||||||
|  | import java.net.URI; | ||||||
|  | import java.time.Duration; | ||||||
|  | import java.time.Instant; | ||||||
|  | import java.util.Collections; | ||||||
|  | import java.util.Map; | ||||||
|  | 
 | ||||||
|  | import org.junit.After; | ||||||
|  | import org.junit.Test; | ||||||
|  | import org.junit.runner.RunWith; | ||||||
|  | import org.mockito.junit.MockitoJUnitRunner; | ||||||
|  | 
 | ||||||
|  | import org.springframework.http.HttpHeaders; | ||||||
|  | import org.springframework.security.core.Authentication; | ||||||
|  | import org.springframework.security.core.context.SecurityContextHolder; | ||||||
|  | import org.springframework.security.oauth2.core.OAuth2AccessToken; | ||||||
|  | import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; | ||||||
|  | import org.springframework.web.reactive.function.client.ClientRequest; | ||||||
|  | 
 | ||||||
|  | import static org.assertj.core.api.Assertions.assertThat; | ||||||
|  | import static org.springframework.http.HttpMethod.GET; | ||||||
|  | import static org.springframework.security.oauth2.server.resource.web.ServletBearerExchangeFilterFunction.authentication; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * Tests for {@link ServletBearerExchangeFilterFunction} | ||||||
|  |  * | ||||||
|  |  * @author Josh Cummings | ||||||
|  |  */ | ||||||
|  | @RunWith(MockitoJUnitRunner.class) | ||||||
|  | public class ServletBearerExchangeFilterFunctionTests { | ||||||
|  | 	private ServletBearerExchangeFilterFunction function = new ServletBearerExchangeFilterFunction(); | ||||||
|  | 
 | ||||||
|  | 	private MockExchangeFunction exchange = new MockExchangeFunction(); | ||||||
|  | 
 | ||||||
|  | 	private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, | ||||||
|  | 			"token-0", | ||||||
|  | 			Instant.now(), | ||||||
|  | 			Instant.now().plus(Duration.ofDays(1))); | ||||||
|  | 	private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken<OAuth2AccessToken>(accessToken) { | ||||||
|  | 		@Override | ||||||
|  | 		public Map<String, Object> getTokenAttributes() { | ||||||
|  | 			return Collections.emptyMap(); | ||||||
|  | 		} | ||||||
|  | 	}; | ||||||
|  | 
 | ||||||
|  | 	@After | ||||||
|  | 	public void cleanup() { | ||||||
|  | 		SecurityContextHolder.clearContext(); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void filterWhenUnauthenticatedThenAuthorizationHeaderNull() { | ||||||
|  | 		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) | ||||||
|  | 				.build(); | ||||||
|  | 
 | ||||||
|  | 		this.function.filter(request, this.exchange).block(); | ||||||
|  | 
 | ||||||
|  | 		assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void filterWhenAuthenticatedThenAuthorizationHeaderNull() throws Exception { | ||||||
|  | 		this.function.afterPropertiesSet(); | ||||||
|  | 		SecurityContextHolder.getContext().setAuthentication(this.authentication); | ||||||
|  | 
 | ||||||
|  | 		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) | ||||||
|  | 				.build(); | ||||||
|  | 
 | ||||||
|  | 		this.function.filter(request, this.exchange).block(); | ||||||
|  | 
 | ||||||
|  | 		assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) | ||||||
|  | 				.isEqualTo("Bearer " + this.accessToken.getTokenValue()); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void filterWhenAuthenticationAttributeThenAuthorizationHeader() { | ||||||
|  | 		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) | ||||||
|  | 				.attributes(authentication(this.authentication)) | ||||||
|  | 				.build(); | ||||||
|  | 
 | ||||||
|  | 		this.function.filter(request, this.exchange).block(); | ||||||
|  | 
 | ||||||
|  | 		assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) | ||||||
|  | 				.isEqualTo("Bearer " + this.accessToken.getTokenValue()); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { | ||||||
|  | 		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) | ||||||
|  | 				.header(HttpHeaders.AUTHORIZATION, "Existing") | ||||||
|  | 				.attributes(authentication(this.authentication)) | ||||||
|  | 				.build(); | ||||||
|  | 
 | ||||||
|  | 		this.function.filter(request, this.exchange).block(); | ||||||
|  | 
 | ||||||
|  | 		HttpHeaders headers = this.exchange.getRequest().headers(); | ||||||
|  | 		assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,107 @@ | ||||||
|  | /* | ||||||
|  |  * Copyright 2002-2019 the original author or authors. | ||||||
|  |  * | ||||||
|  |  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |  * you may not use this file except in compliance with the License. | ||||||
|  |  * You may obtain a copy of the License at | ||||||
|  |  * | ||||||
|  |  *      https://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |  * See the License for the specific language governing permissions and | ||||||
|  |  * limitations under the License. | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | package org.springframework.security.oauth2.server.resource.web.server; | ||||||
|  | 
 | ||||||
|  | import java.net.URI; | ||||||
|  | import java.time.Duration; | ||||||
|  | import java.time.Instant; | ||||||
|  | import java.util.Collections; | ||||||
|  | import java.util.Map; | ||||||
|  | 
 | ||||||
|  | import org.junit.Test; | ||||||
|  | 
 | ||||||
|  | import org.springframework.http.HttpHeaders; | ||||||
|  | import org.springframework.security.core.Authentication; | ||||||
|  | import org.springframework.security.core.context.ReactiveSecurityContextHolder; | ||||||
|  | import org.springframework.security.oauth2.core.OAuth2AccessToken; | ||||||
|  | import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; | ||||||
|  | import org.springframework.security.oauth2.server.resource.web.MockExchangeFunction; | ||||||
|  | import org.springframework.web.reactive.function.client.ClientRequest; | ||||||
|  | 
 | ||||||
|  | import static org.assertj.core.api.Assertions.assertThat; | ||||||
|  | import static org.springframework.http.HttpMethod.GET; | ||||||
|  | import static org.springframework.security.oauth2.server.resource.web.ServletBearerExchangeFilterFunction.authentication; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * Tests for {@link ServerBearerExchangeFilterFunction} | ||||||
|  |  * | ||||||
|  |  * @author Josh Cummings | ||||||
|  |  */ | ||||||
|  | public class ServerBearerExchangeFilterFunctionTests { | ||||||
|  | 	private ServerBearerExchangeFilterFunction function = new ServerBearerExchangeFilterFunction(); | ||||||
|  | 
 | ||||||
|  | 	private MockExchangeFunction exchange = new MockExchangeFunction(); | ||||||
|  | 
 | ||||||
|  | 	private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, | ||||||
|  | 			"token-0", | ||||||
|  | 			Instant.now(), | ||||||
|  | 			Instant.now().plus(Duration.ofDays(1))); | ||||||
|  | 	private Authentication authentication = new AbstractOAuth2TokenAuthenticationToken<OAuth2AccessToken>(accessToken) { | ||||||
|  | 		@Override | ||||||
|  | 		public Map<String, Object> getTokenAttributes() { | ||||||
|  | 			return Collections.emptyMap(); | ||||||
|  | 		} | ||||||
|  | 	}; | ||||||
|  | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void filterWhenUnauthenticatedThenAuthorizationHeaderNull() { | ||||||
|  | 		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) | ||||||
|  | 				.build(); | ||||||
|  | 
 | ||||||
|  | 		this.function.filter(request, this.exchange).block(); | ||||||
|  | 
 | ||||||
|  | 		assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull(); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void filterWhenAuthenticatedThenAuthorizationHeaderNull() throws Exception { | ||||||
|  | 		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) | ||||||
|  | 				.build(); | ||||||
|  | 
 | ||||||
|  | 		this.function.filter(request, this.exchange) | ||||||
|  | 				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) | ||||||
|  | 				.block(); | ||||||
|  | 
 | ||||||
|  | 		assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) | ||||||
|  | 				.isEqualTo("Bearer " + this.accessToken.getTokenValue()); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void filterWhenAuthenticationAttributeThenAuthorizationHeader() { | ||||||
|  | 		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) | ||||||
|  | 				.attributes(authentication(this.authentication)) | ||||||
|  | 				.build(); | ||||||
|  | 
 | ||||||
|  | 		this.function.filter(request, this.exchange).block(); | ||||||
|  | 
 | ||||||
|  | 		assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) | ||||||
|  | 				.isEqualTo("Bearer " + this.accessToken.getTokenValue()); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	@Test | ||||||
|  | 	public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { | ||||||
|  | 		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) | ||||||
|  | 				.header(HttpHeaders.AUTHORIZATION, "Existing") | ||||||
|  | 				.attributes(authentication(this.authentication)) | ||||||
|  | 				.build(); | ||||||
|  | 
 | ||||||
|  | 		this.function.filter(request, this.exchange).block(); | ||||||
|  | 
 | ||||||
|  | 		HttpHeaders headers = this.exchange.getRequest().headers(); | ||||||
|  | 		assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue