diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistration.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistration.java new file mode 100644 index 00000000000..eb0bf1534e3 --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistration.java @@ -0,0 +1,101 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.springframework.http.HttpMethod; +import org.springframework.web.bind.annotation.CrossOrigin; +import org.springframework.web.cors.CorsConfiguration; + +/** + * {@code CorsRegistration} assists with the creation of a + * {@link CorsConfiguration} instance mapped to a path pattern. + * + *

If no path pattern is specified, cross-origin request handling is + * mapped to {@code "/**"}. + * + *

By default, all origins, all headers, credentials and {@code GET}, + * {@code HEAD}, and {@code POST} methods are allowed, and the max age is + * set to 30 minutes. + * + * @author Sebastien Deleuze + * @author Sam Brannen + * @since 5.0 + * @see CorsConfiguration + * @see CorsRegistry + */ +public class CorsRegistration { + + private final String pathPattern; + + private final CorsConfiguration config; + + + public CorsRegistration(String pathPattern) { + this.pathPattern = pathPattern; + // Same implicit default values as the @CrossOrigin annotation + allows simple methods + this.config = new CorsConfiguration(); + this.config.setAllowedOrigins(Arrays.asList(CrossOrigin.DEFAULT_ORIGINS)); + this.config.setAllowedMethods(Arrays.asList(HttpMethod.GET.name(), + HttpMethod.HEAD.name(), HttpMethod.POST.name())); + this.config.setAllowedHeaders(Arrays.asList(CrossOrigin.DEFAULT_ALLOWED_HEADERS)); + this.config.setAllowCredentials(CrossOrigin.DEFAULT_ALLOW_CREDENTIALS); + this.config.setMaxAge(CrossOrigin.DEFAULT_MAX_AGE); + } + + + public CorsRegistration allowedOrigins(String... origins) { + this.config.setAllowedOrigins(new ArrayList<>(Arrays.asList(origins))); + return this; + } + + public CorsRegistration allowedMethods(String... methods) { + this.config.setAllowedMethods(new ArrayList<>(Arrays.asList(methods))); + return this; + } + + public CorsRegistration allowedHeaders(String... headers) { + this.config.setAllowedHeaders(new ArrayList<>(Arrays.asList(headers))); + return this; + } + + public CorsRegistration exposedHeaders(String... headers) { + this.config.setExposedHeaders(new ArrayList<>(Arrays.asList(headers))); + return this; + } + + public CorsRegistration maxAge(long maxAge) { + this.config.setMaxAge(maxAge); + return this; + } + + public CorsRegistration allowCredentials(boolean allowCredentials) { + this.config.setAllowCredentials(allowCredentials); + return this; + } + + protected String getPathPattern() { + return this.pathPattern; + } + + protected CorsConfiguration getCorsConfiguration() { + return this.config; + } + +} diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistry.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistry.java new file mode 100644 index 00000000000..f7f11e7a1dc --- /dev/null +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistry.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.springframework.web.cors.CorsConfiguration; + +/** + * {@code CorsRegistry} assists with the registration of {@link CorsConfiguration} + * mapped to a path pattern. + * + * @author Sebastien Deleuze + * @since 5.0 + */ +public class CorsRegistry { + + private final List registrations = new ArrayList<>(); + + + /** + * Enable cross origin request handling for the specified path pattern. + * + *

Exact path mapping URIs (such as {@code "/admin"}) are supported as + * well as Ant-style path patterns (such as {@code "/admin/**"}). + * + *

By default, all origins, all headers, credentials and {@code GET}, + * {@code HEAD}, and {@code POST} methods are allowed, and the max age + * is set to 30 minutes. + */ + public CorsRegistration addMapping(String pathPattern) { + CorsRegistration registration = new CorsRegistration(pathPattern); + this.registrations.add(registration); + return registration; + } + + protected Map getCorsConfigurations() { + Map configs = new LinkedHashMap<>(this.registrations.size()); + for (CorsRegistration registration : this.registrations) { + configs.put(registration.getPathPattern(), registration.getCorsConfiguration()); + } + return configs; + } +} diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/WebReactiveConfiguration.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/WebReactiveConfiguration.java index b73398fd00d..2a92ab463a5 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/WebReactiveConfiguration.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/WebReactiveConfiguration.java @@ -54,6 +54,7 @@ import org.springframework.http.codec.xml.Jaxb2XmlEncoder; import org.springframework.util.ClassUtils; import org.springframework.validation.Errors; import org.springframework.validation.Validator; +import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.accept.CompositeContentTypeResolver; import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder; @@ -95,6 +96,8 @@ public class WebReactiveConfiguration implements ApplicationContextAware { private List> messageWriters; + private Map corsConfigurations; + private ApplicationContext applicationContext; @@ -113,6 +116,7 @@ public class WebReactiveConfiguration implements ApplicationContextAware { RequestMappingHandlerMapping mapping = createRequestMappingHandlerMapping(); mapping.setOrder(0); mapping.setContentTypeResolver(mvcContentTypeResolver()); + mapping.setCorsConfigurations(getCorsConfigurations()); PathMatchConfigurer configurer = getPathMatchConfigurer(); if (configurer.isUseSuffixPatternMatch() != null) { @@ -440,6 +444,22 @@ public class WebReactiveConfiguration implements ApplicationContextAware { protected void configureViewResolvers(ViewResolverRegistry registry) { } + protected final Map getCorsConfigurations() { + if (this.corsConfigurations == null) { + CorsRegistry registry = new CorsRegistry(); + addCorsMappings(registry); + this.corsConfigurations = registry.getCorsConfigurations(); + } + return this.corsConfigurations; + } + + /** + * Override this method to configure cross origin requests processing. + * @see CorsRegistry + */ + protected void addCorsMappings(CorsRegistry registry) { + } + private static final class EmptyHandlerMapping extends AbstractHandlerMapping { diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java index 7dac4edc954..b009b23cc73 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java @@ -15,12 +15,24 @@ */ package org.springframework.web.reactive.handler; +import java.util.Map; + +import reactor.core.publisher.Mono; + import org.springframework.context.support.ApplicationObjectSupport; import org.springframework.core.Ordered; import org.springframework.util.AntPathMatcher; import org.springframework.util.Assert; import org.springframework.util.PathMatcher; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.reactive.CorsConfigurationSource; +import org.springframework.web.cors.reactive.CorsProcessor; +import org.springframework.web.cors.reactive.CorsUtils; +import org.springframework.web.cors.reactive.DefaultCorsProcessor; +import org.springframework.web.cors.reactive.UrlBasedCorsConfigurationSource; import org.springframework.web.reactive.HandlerMapping; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebHandler; import org.springframework.web.util.HttpRequestPathHelper; /** @@ -39,8 +51,9 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport private PathMatcher pathMatcher = new AntPathMatcher(); + protected CorsProcessor corsProcessor = new DefaultCorsProcessor(); - // TODO: CORS + protected final UrlBasedCorsConfigurationSource corsConfigSource = new UrlBasedCorsConfigurationSource(); /** * Specify the order value for this HandlerMapping bean. @@ -91,7 +104,7 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport public void setPathMatcher(PathMatcher pathMatcher) { Assert.notNull(pathMatcher, "PathMatcher must not be null"); this.pathMatcher = pathMatcher; - // this.corsConfigSource.setPathMatcher(pathMatcher); + this.corsConfigSource.setPathMatcher(pathMatcher); } /** @@ -102,4 +115,62 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport return this.pathMatcher; } + /** + * Configure a custom {@link CorsProcessor} to use to apply the matched + * {@link CorsConfiguration} for a request. By default {@link DefaultCorsProcessor} is used. + */ + public void setCorsProcessor(CorsProcessor corsProcessor) { + Assert.notNull(corsProcessor, "CorsProcessor must not be null"); + this.corsProcessor = corsProcessor; + } + + /** + * Return the configured {@link CorsProcessor}. + */ + public CorsProcessor getCorsProcessor() { + return this.corsProcessor; + } + + /** + * Set "global" CORS configuration based on URL patterns. By default the first + * matching URL pattern is combined with the CORS configuration for the + * handler, if any. + */ + public void setCorsConfigurations(Map corsConfigurations) { + this.corsConfigSource.setCorsConfigurations(corsConfigurations); + } + + /** + * Get the CORS configuration. + */ + public Map getCorsConfigurations() { + return this.corsConfigSource.getCorsConfigurations(); + } + + protected CorsConfiguration getCorsConfiguration(Object handler, ServerWebExchange exchange) { + if (handler != null && handler instanceof CorsConfigurationSource) { + return ((CorsConfigurationSource) handler).getCorsConfiguration(exchange); + } + return null; + } + + protected Object processCorsRequest(ServerWebExchange exchange, Object handler) { + if (CorsUtils.isCorsRequest(exchange.getRequest())) { + CorsConfiguration globalConfig = this.corsConfigSource.getCorsConfiguration(exchange); + CorsConfiguration handlerConfig = getCorsConfiguration(handler, exchange); + CorsConfiguration config = (globalConfig != null ? globalConfig.combine(handlerConfig) : handlerConfig); + if (!corsProcessor.processRequest(config, exchange) || CorsUtils.isPreFlightRequest(exchange.getRequest())) { + return new NoOpHandler(); + } + } + return handler; + } + + private class NoOpHandler implements WebHandler { + @Override + public Mono handle(ServerWebExchange exchange) { + return Mono.empty(); + } + } + } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java index 1cc674bb606..bb87cee7078 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java @@ -101,6 +101,7 @@ public abstract class AbstractUrlHandlerMapping extends AbstractHandlerMapping { Object handler = null; try { handler = lookupHandler(lookupPath, exchange); + handler = processCorsRequest(exchange, handler); } catch (Exception ex) { return Mono.error(ex); diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java index e69f12f8f49..37fafc379a6 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java @@ -27,6 +27,7 @@ import java.util.Set; import org.springframework.http.InvalidMediaTypeException; import org.springframework.http.MediaType; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.cors.reactive.CorsUtils; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.UnsupportedMediaTypeStatusException; @@ -43,7 +44,7 @@ import org.springframework.web.server.UnsupportedMediaTypeStatusException; */ public final class ConsumesRequestCondition extends AbstractRequestCondition { -// private final static ConsumesRequestCondition PRE_FLIGHT_MATCH = new ConsumesRequestCondition(); + private final static ConsumesRequestCondition PRE_FLIGHT_MATCH = new ConsumesRequestCondition(); private final List expressions; @@ -160,9 +161,9 @@ public final class ConsumesRequestCondition extends AbstractRequestCondition { -// private final static HeadersRequestCondition PRE_FLIGHT_MATCH = new HeadersRequestCondition(); + private final static HeadersRequestCondition PRE_FLIGHT_MATCH = new HeadersRequestCondition(); private final Set expressions; @@ -107,9 +108,9 @@ public final class HeadersRequestCondition extends AbstractRequestCondition { -// private final static ProducesRequestCondition PRE_FLIGHT_MATCH = new ProducesRequestCondition(); + private final static ProducesRequestCondition PRE_FLIGHT_MATCH = new ProducesRequestCondition(); private final List MEDIA_TYPE_ALL_LIST = @@ -182,9 +183,9 @@ public final class ProducesRequestCondition extends AbstractRequestCondition extends AbstractHandlerMap */ private static final String SCOPED_TARGET_NAME_PREFIX = "scopedTarget."; + private static final HandlerMethod PREFLIGHT_AMBIGUOUS_MATCH = + new HandlerMethod(new EmptyHandler(), ClassUtils.getMethod(EmptyHandler.class, "handle")); + + private static final CorsConfiguration ALLOW_CORS_CONFIG = new CorsConfiguration(); + + static { + ALLOW_CORS_CONFIG.addAllowedOrigin("*"); + ALLOW_CORS_CONFIG.addAllowedMethod("*"); + ALLOW_CORS_CONFIG.addAllowedHeader("*"); + ALLOW_CORS_CONFIG.setAllowCredentials(true); + } + private final MappingRegistry mappingRegistry = new MappingRegistry(); @@ -212,6 +230,13 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap return handlerMethod; } + /** + * Extract and return the CORS configuration for the mapping. + */ + protected CorsConfiguration initCorsConfiguration(Object handler, Method method, T mapping) { + return null; + } + /** * Invoked after all handler methods have been detected. * @param handlerMethods a read-only map with handler methods and mappings. @@ -249,7 +274,10 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap logger.debug("Did not find handler method for [" + lookupPath + "]"); } } - return (handlerMethod != null ? Mono.just(handlerMethod.createWithResolvedBean()) : Mono.empty()); + if (handlerMethod != null) { + handlerMethod = handlerMethod.createWithResolvedBean(); + } + return Mono.justOrEmpty(processCorsRequest(exchange, handlerMethod)); } finally { this.mappingRegistry.releaseReadLock(); @@ -287,6 +315,9 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap } Match bestMatch = matches.get(0); if (matches.size() > 1) { + if (CorsUtils.isPreFlightRequest(exchange.getRequest())) { + return PREFLIGHT_AMBIGUOUS_MATCH; + } Match secondBestMatch = matches.get(1); if (comparator.compare(bestMatch, secondBestMatch) == 0) { Method m1 = bestMatch.handlerMethod.getMethod(); @@ -335,6 +366,22 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap return null; } + @Override + protected CorsConfiguration getCorsConfiguration(Object handler, ServerWebExchange exchange) { + CorsConfiguration corsConfig = super.getCorsConfiguration(handler, exchange); + if (handler instanceof HandlerMethod) { + HandlerMethod handlerMethod = (HandlerMethod) handler; + if (handlerMethod.equals(PREFLIGHT_AMBIGUOUS_MATCH)) { + return AbstractHandlerMethodMapping.ALLOW_CORS_CONFIG; + } + else { + CorsConfiguration corsConfigFromMethod = this.mappingRegistry.getCorsConfiguration(handlerMethod); + corsConfig = (corsConfig != null ? corsConfig.combine(corsConfigFromMethod) : corsConfigFromMethod); + } + } + return corsConfig; + } + // Abstract template methods @@ -392,6 +439,9 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap private final MultiValueMap urlLookup = new LinkedMultiValueMap<>(); + private final Map corsLookup = + new ConcurrentHashMap<>(); + private final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock(); /** @@ -410,6 +460,14 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap return this.urlLookup.get(urlPath); } + /** + * Return CORS configuration. Thread-safe for concurrent use. + */ + public CorsConfiguration getCorsConfiguration(HandlerMethod handlerMethod) { + HandlerMethod original = handlerMethod.getResolvedFromHandlerMethod(); + return this.corsLookup.get(original != null ? original : handlerMethod); + } + /** * Acquire the read lock when using getMappings and getMappingsByUrl. */ @@ -440,6 +498,11 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap this.urlLookup.add(url, mapping); } + CorsConfiguration corsConfig = initCorsConfiguration(handler, method, mapping); + if (corsConfig != null) { + this.corsLookup.put(handlerMethod, corsConfig); + } + this.registry.put(mapping, new MappingRegistration<>(mapping, handlerMethod, directUrls)); } finally { @@ -486,6 +549,7 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap } } } + this.corsLookup.remove(definition.getHandlerMethod()); } finally { this.readWriteLock.writeLock().unlock(); @@ -561,4 +625,11 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap } } + private static class EmptyHandler { + + public void handle() { + throw new UnsupportedOperationException("not implemented"); + } + } + } diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java index b510510cba2..e87cf478553 100644 --- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java +++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java @@ -18,14 +18,20 @@ package org.springframework.web.reactive.result.method.annotation; import java.lang.reflect.AnnotatedElement; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.Set; import org.springframework.context.EmbeddedValueResolverAware; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.stereotype.Controller; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringValueResolver; +import org.springframework.web.bind.annotation.CrossOrigin; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.method.HandlerMethod; import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder; import org.springframework.web.reactive.accept.RequestedContentTypeResolver; import org.springframework.web.reactive.result.condition.RequestCondition; @@ -273,4 +279,76 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi } } + @Override + protected CorsConfiguration initCorsConfiguration(Object handler, Method method, RequestMappingInfo mappingInfo) { + HandlerMethod handlerMethod = createHandlerMethod(handler, method); + CrossOrigin typeAnnotation = AnnotatedElementUtils.findMergedAnnotation(handlerMethod.getBeanType(), CrossOrigin.class); + CrossOrigin methodAnnotation = AnnotatedElementUtils.findMergedAnnotation(method, CrossOrigin.class); + + if (typeAnnotation == null && methodAnnotation == null) { + return null; + } + + CorsConfiguration config = new CorsConfiguration(); + updateCorsConfig(config, typeAnnotation); + updateCorsConfig(config, methodAnnotation); + + if (CollectionUtils.isEmpty(config.getAllowedOrigins())) { + config.setAllowedOrigins(Arrays.asList(CrossOrigin.DEFAULT_ORIGINS)); + } + if (CollectionUtils.isEmpty(config.getAllowedMethods())) { + for (RequestMethod allowedMethod : mappingInfo.getMethodsCondition().getMethods()) { + config.addAllowedMethod(allowedMethod.name()); + } + } + if (CollectionUtils.isEmpty(config.getAllowedHeaders())) { + config.setAllowedHeaders(Arrays.asList(CrossOrigin.DEFAULT_ALLOWED_HEADERS)); + } + if (config.getAllowCredentials() == null) { + config.setAllowCredentials(CrossOrigin.DEFAULT_ALLOW_CREDENTIALS); + } + if (config.getMaxAge() == null) { + config.setMaxAge(CrossOrigin.DEFAULT_MAX_AGE); + } + return config; + } + + private void updateCorsConfig(CorsConfiguration config, CrossOrigin annotation) { + if (annotation == null) { + return; + } + for (String origin : annotation.origins()) { + config.addAllowedOrigin(resolveCorsAnnotationValue(origin)); + } + for (RequestMethod method : annotation.methods()) { + config.addAllowedMethod(method.name()); + } + for (String header : annotation.allowedHeaders()) { + config.addAllowedHeader(resolveCorsAnnotationValue(header)); + } + for (String header : annotation.exposedHeaders()) { + config.addExposedHeader(resolveCorsAnnotationValue(header)); + } + + String allowCredentials = resolveCorsAnnotationValue(annotation.allowCredentials()); + if ("true".equalsIgnoreCase(allowCredentials)) { + config.setAllowCredentials(true); + } + else if ("false".equalsIgnoreCase(allowCredentials)) { + config.setAllowCredentials(false); + } + else if (!allowCredentials.isEmpty()) { + throw new IllegalStateException("@CrossOrigin's allowCredentials value must be \"true\", \"false\", " + + "or an empty string (\"\"): current value is [" + allowCredentials + "]"); + } + + if (annotation.maxAge() >= 0 && config.getMaxAge() == null) { + config.setMaxAge(annotation.maxAge()); + } + } + + private String resolveCorsAnnotationValue(String value) { + return (this.embeddedValueResolver != null ? this.embeddedValueResolver.resolveStringValue(value) : value); + } + } diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/config/CorsRegistryTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/config/CorsRegistryTests.java new file mode 100644 index 00000000000..c131c28b02f --- /dev/null +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/config/CorsRegistryTests.java @@ -0,0 +1,71 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.config; + +import java.util.Arrays; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.web.cors.CorsConfiguration; + +/** + * Test fixture with a {@link CorsRegistry}. + * + * @author Sebastien Deleuze + */ +public class CorsRegistryTests { + + private CorsRegistry registry; + + @Before + public void setUp() { + this.registry = new CorsRegistry(); + } + + @Test + public void noMapping() { + assertTrue(this.registry.getCorsConfigurations().isEmpty()); + } + + @Test + public void multipleMappings() { + this.registry.addMapping("/foo"); + this.registry.addMapping("/bar"); + assertEquals(2, this.registry.getCorsConfigurations().size()); + } + + @Test + public void customizedMapping() { + this.registry.addMapping("/foo").allowedOrigins("http://domain2.com", "http://domain2.com") + .allowedMethods("DELETE").allowCredentials(false).allowedHeaders("header1", "header2") + .exposedHeaders("header3", "header4").maxAge(3600); + Map configs = this.registry.getCorsConfigurations(); + assertEquals(1, configs.size()); + CorsConfiguration config = configs.get("/foo"); + assertEquals(Arrays.asList("http://domain2.com", "http://domain2.com"), config.getAllowedOrigins()); + assertEquals(Arrays.asList("DELETE"), config.getAllowedMethods()); + assertEquals(Arrays.asList("header1", "header2"), config.getAllowedHeaders()); + assertEquals(Arrays.asList("header3", "header4"), config.getExposedHeaders()); + assertEquals(false, config.getAllowCredentials()); + assertEquals(Long.valueOf(3600), config.getMaxAge()); + } + +} diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/handler/CorsAbstractUrlHandlerMappingTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/handler/CorsAbstractUrlHandlerMappingTests.java new file mode 100644 index 00000000000..3ed1bc6efc5 --- /dev/null +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/handler/CorsAbstractUrlHandlerMappingTests.java @@ -0,0 +1,182 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.reactive.handler; + +import java.net.URISyntaxException; +import java.util.Collections; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertSame; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.reactive.CorsConfigurationSource; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.adapter.DefaultServerWebExchange; +import org.springframework.web.server.session.MockWebSessionManager; +import org.springframework.web.server.session.WebSessionManager; + +/** + * Unit tests for CORS support at {@link AbstractUrlHandlerMapping} level. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + */ +public class CorsAbstractUrlHandlerMappingTests { + + private AnnotationConfigApplicationContext wac; + + private TestUrlHandlerMapping handlerMapping; + + private Object mainController; + + private CorsAwareHandler corsConfigurationSourceController; + + @Before + public void setup() { + wac = new AnnotationConfigApplicationContext(); + wac.register(WebConfig.class); + wac.refresh(); + + handlerMapping = (TestUrlHandlerMapping) wac.getBean("handlerMapping"); + mainController = wac.getBean("mainController"); + corsConfigurationSourceController = (CorsAwareHandler) wac.getBean("corsConfigurationSourceController"); + } + + @Test + public void actualRequestWithoutCorsConfigurationProvider() throws Exception { + ServerWebExchange exchange = createExchange(HttpMethod.GET, "/welcome.html", "http://domain2.com", "GET"); + Object actual = handlerMapping.getHandler(exchange).block(); + assertNotNull(actual); + assertSame(mainController, actual); + } + + @Test + public void preflightRequestWithoutCorsConfigurationProvider() throws Exception { + ServerWebExchange exchange = createExchange(HttpMethod.OPTIONS, "/welcome.html", "http://domain2.com", "GET"); + Object actual = handlerMapping.getHandler(exchange).block(); + assertNotNull(actual); + assertEquals("NoOpHandler", actual.getClass().getSimpleName()); + assertNull(exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + } + + @Test + public void actualRequestWithCorsConfigurationProvider() throws Exception { + ServerWebExchange exchange = createExchange(HttpMethod.GET, "/cors.html", "http://domain2.com", "GET"); + Object actual = handlerMapping.getHandler(exchange).block(); + assertNotNull(actual); + assertSame(corsConfigurationSourceController, actual); + CorsConfiguration config = ((CorsConfigurationSource)actual).getCorsConfiguration(createExchange(HttpMethod.GET, "", "","")); + assertNotNull(config); + assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"}); + assertEquals("*", exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + } + + @Test + public void preflightRequestWithCorsConfigurationProvider() throws Exception { + ServerWebExchange exchange = createExchange(HttpMethod.OPTIONS, "/cors.html", "http://domain2.com", "GET"); + Object actual = handlerMapping.getHandler(exchange).block(); + assertNotNull(actual); + assertEquals("NoOpHandler", actual.getClass().getSimpleName()); + assertEquals("*", exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + } + + @Test + public void actualRequestWithMappedCorsConfiguration() throws Exception { + CorsConfiguration mappedConfig = new CorsConfiguration(); + mappedConfig.addAllowedOrigin("*"); + this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/welcome.html", mappedConfig)); + + ServerWebExchange exchange = createExchange(HttpMethod.GET, "/welcome.html", "http://domain2.com", "GET"); + Object actual = handlerMapping.getHandler(exchange).block(); + assertNotNull(actual); + assertSame(mainController, actual); + assertEquals("*", exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + } + + @Test + public void preflightRequestWithMappedCorsConfiguration() throws Exception { + CorsConfiguration mappedConfig = new CorsConfiguration(); + mappedConfig.addAllowedOrigin("*"); + this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/welcome.html", mappedConfig)); + + ServerWebExchange exchange = createExchange(HttpMethod.OPTIONS, "/welcome.html", "http://domain2.com", "GET"); + Object actual = handlerMapping.getHandler(exchange).block(); + assertNotNull(actual); + assertEquals("NoOpHandler", actual.getClass().getSimpleName()); + assertEquals("*", exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + } + + + private ServerWebExchange createExchange(HttpMethod method, String path, String origin, + String accessControlRequestMethod) throws URISyntaxException { + + ServerHttpRequest request = new MockServerHttpRequest(method, "http://localhost" + path); + request.getHeaders().add(HttpHeaders.ORIGIN, origin); + request.getHeaders().add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, accessControlRequestMethod); + WebSessionManager sessionManager = new MockWebSessionManager(); + return new DefaultServerWebExchange(request, new MockServerHttpResponse(), sessionManager); + } + + + @Configuration + static class WebConfig { + + @Bean @SuppressWarnings("unused") + public TestUrlHandlerMapping handlerMapping() { + TestUrlHandlerMapping hm = new TestUrlHandlerMapping(); + hm.setUseTrailingSlashMatch(true); + hm.registerHandler("/welcome.html", mainController()); + hm.registerHandler("/cors.html", corsConfigurationSourceController()); + return hm; + } + + @Bean + public Object mainController() { + return new Object(); + } + + @Bean + public CorsAwareHandler corsConfigurationSourceController() { + return new CorsAwareHandler(); + } + + } + + static class TestUrlHandlerMapping extends AbstractUrlHandlerMapping { + + } + + static class CorsAwareHandler implements CorsConfigurationSource { + + @Override + public CorsConfiguration getCorsConfiguration(ServerWebExchange exchange) { + CorsConfiguration config = new CorsConfiguration(); + config.addAllowedOrigin("*"); + return config; + } + } + +} diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/AbstractRequestMappingIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/AbstractRequestMappingIntegrationTests.java index 1d79704359c..b2aa35a3bfd 100644 --- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/AbstractRequestMappingIntegrationTests.java +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/AbstractRequestMappingIntegrationTests.java @@ -27,6 +27,7 @@ import org.springframework.http.server.reactive.HttpHandler; import org.springframework.web.client.RestTemplate; import org.springframework.web.reactive.DispatcherHandler; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; +import org.springframework.web.server.handler.ResponseStatusExceptionHandler; import static org.springframework.http.RequestEntity.get; @@ -46,6 +47,7 @@ public abstract class AbstractRequestMappingIntegrationTests extends AbstractHtt this.applicationContext = initApplicationContext(); return WebHttpHandlerBuilder .webHandler(new DispatcherHandler(this.applicationContext)) + .exceptionHandlers(new ResponseStatusExceptionHandler()) .build(); } diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CorsConfigurationIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CorsConfigurationIntegrationTests.java new file mode 100644 index 00000000000..b461559e2a2 --- /dev/null +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CorsConfigurationIntegrationTests.java @@ -0,0 +1,183 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.result.method.annotation; + +import static org.junit.Assert.*; +import org.junit.Test; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.client.HttpClientErrorException; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.reactive.config.CorsRegistry; +import org.springframework.web.reactive.config.WebReactiveConfiguration; + +/** + * @author Sebastien Deleuze + */ +public class CorsConfigurationIntegrationTests extends AbstractRequestMappingIntegrationTests { + + // JDK default HTTP client blacklist headers like Origin + private RestTemplate restTemplate = new RestTemplate(new HttpComponentsClientHttpRequestFactory()); + + @Override + protected ApplicationContext initApplicationContext() { + AnnotationConfigApplicationContext wac = new AnnotationConfigApplicationContext(); + wac.register(WebConfig.class); + wac.refresh(); + return wac; + } + + @Override + RestTemplate getRestTemplate() { + return this.restTemplate; + } + + @Test + public void actualRequestWithCorsEnabled() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://localhost:9000"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/cors"), + HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://localhost:9000", entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals("cors", entity.getBody()); + } + + @Test + public void actualRequestWithCorsRejected() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://localhost:9000"); + HttpEntity requestEntity = new HttpEntity(headers); + try { + this.restTemplate.exchange(getUrl("/cors-restricted"), HttpMethod.GET, + requestEntity, String.class); + } + catch (HttpClientErrorException e) { + assertEquals(HttpStatus.FORBIDDEN, e.getStatusCode()); + return; + } + fail(); + } + + @Test + public void actualRequestWithoutCorsEnabled() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://localhost:9000"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/welcome"), + HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertNull(entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals("welcome", entity.getBody()); + } + + @Test + public void preflightRequestWithCorsEnabled() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://localhost:9000"); + headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/cors"), + HttpMethod.OPTIONS, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://localhost:9000", entity.getHeaders().getAccessControlAllowOrigin()); + } + + @Test + public void preflightRequestWithCorsRejected() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://localhost:9000"); + headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + HttpEntity requestEntity = new HttpEntity(headers); + try { + this.restTemplate.exchange(getUrl("/cors-restricted"), HttpMethod.OPTIONS, + requestEntity, String.class); + } + catch (HttpClientErrorException e) { + assertEquals(HttpStatus.FORBIDDEN, e.getStatusCode()); + return; + } + fail(); + } + + @Test + public void preflightRequestWithoutCorsEnabled() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://localhost:9000"); + headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + HttpEntity requestEntity = new HttpEntity(headers); + try { + this.restTemplate.exchange(getUrl("/welcome"), HttpMethod.OPTIONS, + requestEntity, String.class); + } + catch (HttpClientErrorException e) { + assertEquals(HttpStatus.FORBIDDEN, e.getStatusCode()); + return; + } + fail(); + } + + private String getUrl(String path) { + return "http://localhost:" + this.port + path; + } + + + @Configuration + @ComponentScan(resourcePattern = "**/CorsConfigurationIntegrationTests*.class") + @SuppressWarnings({"unused", "WeakerAccess"}) + static class WebConfig extends WebReactiveConfiguration { + + @Override + protected void addCorsMappings(CorsRegistry registry) { + registry.addMapping("/cors-restricted").allowedOrigins("http://foo"); + registry.addMapping("/cors"); + } + } + + @RestController + static class TestController { + + @GetMapping("/welcome") + public String welcome() { + return "welcome"; + } + + @GetMapping("/cors") + public String cors() { + return "cors"; + } + + @GetMapping("/cors-restricted") + public String corsRestricted() { + return "corsRestricted"; + } + + } + +} diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java new file mode 100644 index 00000000000..1e5a9fdc0f1 --- /dev/null +++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java @@ -0,0 +1,345 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.result.method.annotation; + +import java.util.Properties; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertArrayEquals; +import org.junit.Test; + +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.support.PropertySourcesPlaceholderConfigurer; +import org.springframework.core.env.PropertiesPropertySource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; +import org.springframework.web.bind.annotation.CrossOrigin; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.reactive.config.WebReactiveConfiguration; + +/** + * @author Sebastien Deleuze + */ +public class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegrationTests { + + // JDK default HTTP client blacklist headers like Origin + private RestTemplate restTemplate = new RestTemplate(new HttpComponentsClientHttpRequestFactory()); + + + @Override + protected ApplicationContext initApplicationContext() { + AnnotationConfigApplicationContext wac = new AnnotationConfigApplicationContext(); + wac.register(WebConfig.class); + Properties props = new Properties(); + props.setProperty("myOrigin", "http://site1.com"); + wac.getEnvironment().getPropertySources().addFirst(new PropertiesPropertySource("ps", props)); + wac.register(PropertySourcesPlaceholderConfigurer.class); + wac.refresh(); + return wac; + } + + @Override + RestTemplate getRestTemplate() { + return this.restTemplate; + } + + @Test + public void actualGetRequestWithoutAnnotation() { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/no"), + HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertNull(entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals("no", entity.getBody()); + } + + @Test + public void actualPostRequestWithoutAnnotation() { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/no"), + HttpMethod.POST, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertNull(entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals("no-post", entity.getBody()); + } + + @Test + public void actualRequestWithDefaultAnnotation() { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/default"), + HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials()); + assertEquals("default", entity.getBody()); + } + + @Test + public void preflightRequestWithDefaultAnnotation() { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/default"), + HttpMethod.OPTIONS, requestEntity, Void.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals(1800, entity.getHeaders().getAccessControlMaxAge()); + assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials()); + } + + @Test + public void actualRequestWithDefaultAnnotationAndNoOrigin() { + HttpHeaders headers = new HttpHeaders(); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/default"), + HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertNull(entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals("default", entity.getBody()); + } + + @Test + public void actualRequestWithCustomizedAnnotation() { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/customized"), + HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals(false, entity.getHeaders().getAccessControlAllowCredentials()); + assertEquals(-1, entity.getHeaders().getAccessControlMaxAge()); + assertEquals("customized", entity.getBody()); + } + + @Test + public void preflightRequestWithCustomizedAnnotation() { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/customized"), + HttpMethod.OPTIONS, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin()); + assertArrayEquals(new HttpMethod[] {HttpMethod.GET}, entity.getHeaders().getAccessControlAllowMethods().toArray()); + assertEquals(false, entity.getHeaders().getAccessControlAllowCredentials()); + assertArrayEquals(new String[] {"header1", "header2"}, entity.getHeaders().getAccessControlAllowHeaders().toArray()); + assertArrayEquals(new String[] {"header3", "header4"}, entity.getHeaders().getAccessControlExposeHeaders().toArray()); + assertEquals(123, entity.getHeaders().getAccessControlMaxAge()); + } + + @Test + public void customOriginDefinedViaValueAttribute() { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/origin-value-attribute"), + HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals("value-attribute", entity.getBody()); + } + + @Test + public void customOriginDefinedViaPlaceholder() { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/origin-placeholder"), + HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals("placeholder", entity.getBody()); + } + + @Test + public void classLevel() { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + HttpEntity requestEntity = new HttpEntity(headers); + + ResponseEntity entity = this.restTemplate.exchange(getUrl("/foo"), + HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("*", entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals(false, entity.getHeaders().getAccessControlAllowCredentials()); + assertEquals("foo", entity.getBody()); + + entity = this.restTemplate.exchange(getUrl("/bar"), HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("*", entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals(false, entity.getHeaders().getAccessControlAllowCredentials()); + assertEquals("bar", entity.getBody()); + + entity = this.restTemplate.exchange(getUrl("/baz"), HttpMethod.GET, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin()); + assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials()); + assertEquals("baz", entity.getBody()); + } + + @Test + public void ambiguousHeaderPreflightRequest() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "header1"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/ambiguous-header"), + HttpMethod.OPTIONS, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin()); + assertArrayEquals(new HttpMethod[] {HttpMethod.GET}, entity.getHeaders().getAccessControlAllowMethods().toArray()); + assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials()); + assertArrayEquals(new String[] {"header1"}, entity.getHeaders().getAccessControlAllowHeaders().toArray()); + } + + @Test + public void ambiguousProducesPreflightRequest() throws Exception { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.ORIGIN, "http://site1.com"); + headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + HttpEntity requestEntity = new HttpEntity(headers); + ResponseEntity entity = this.restTemplate.exchange(getUrl("/ambiguous-produces"), + HttpMethod.OPTIONS, requestEntity, String.class); + assertEquals(HttpStatus.OK, entity.getStatusCode()); + assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin()); + assertArrayEquals(new HttpMethod[] {HttpMethod.GET}, entity.getHeaders().getAccessControlAllowMethods().toArray()); + assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials()); + } + + private String getUrl(String path) { + return "http://localhost:" + this.port + path; + } + + + @Configuration + @ComponentScan(resourcePattern = "**/CrossOriginAnnotationIntegrationTests*") + @SuppressWarnings({"unused", "WeakerAccess"}) + static class WebConfig extends WebReactiveConfiguration { + + } + + @RestController + private static class MethodLevelController { + + @RequestMapping(path = "/no", method = RequestMethod.GET) + public String noAnnotation() { + return "no"; + } + + @RequestMapping(path = "/no", method = RequestMethod.POST) + public String noAnnotationPost() { + return "no-post"; + } + + @CrossOrigin + @RequestMapping(path = "/default", method = RequestMethod.GET) + public String defaultAnnotation() { + return "default"; + } + + @CrossOrigin + @RequestMapping(path = "/default", method = RequestMethod.GET, params = "q") + public void defaultAnnotationWithParams() { + } + + @CrossOrigin + @RequestMapping(path = "/ambiguous-header", method = RequestMethod.GET, headers = "header1=a") + public void ambigousHeader1a() { + } + + @CrossOrigin + @RequestMapping(path = "/ambiguous-header", method = RequestMethod.GET, headers = "header1=b") + public void ambigousHeader1b() { + } + + @CrossOrigin + @RequestMapping(path = "/ambiguous-produces", method = RequestMethod.GET, produces = "application/xml") + public String ambigousProducesXml() { + return ""; + } + + @CrossOrigin + @RequestMapping(path = "/ambiguous-produces", method = RequestMethod.GET, produces = "application/json") + public String ambigousProducesJson() { + return "{}"; + } + + @CrossOrigin(origins = { "http://site1.com", "http://site2.com" }, allowedHeaders = { "header1", "header2" }, + exposedHeaders = { "header3", "header4" }, methods = RequestMethod.GET, maxAge = 123, allowCredentials = "false") + @RequestMapping(path = "/customized", method = { RequestMethod.GET, RequestMethod.POST }) + public String customized() { + return "customized"; + } + + @CrossOrigin("http://site1.com") + @RequestMapping("/origin-value-attribute") + public String customOriginDefinedViaValueAttribute() { + return "value-attribute"; + } + + @CrossOrigin("${myOrigin}") + @RequestMapping("/origin-placeholder") + public String customOriginDefinedViaPlaceholder() { + return "placeholder"; + } + } + + @RestController + @CrossOrigin(allowCredentials = "false") + private static class ClassLevelController { + + @RequestMapping(path = "/foo", method = RequestMethod.GET) + public String foo() { + return "foo"; + } + + @CrossOrigin + @RequestMapping(path = "/bar", method = RequestMethod.GET) + public String bar() { + return "bar"; + } + + @CrossOrigin(allowCredentials = "true") + @RequestMapping(path = "/baz", method = RequestMethod.GET) + public String baz() { + return "baz"; + } + + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsConfigurationSource.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsConfigurationSource.java new file mode 100644 index 00000000000..c0fc6e8f92f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsConfigurationSource.java @@ -0,0 +1,37 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; + +/** + * Interface to be implemented by classes (usually HTTP request handlers) that + * provides a {@link CorsConfiguration} instance based on the provided reactive request. + * + * @author Sebastien Deleuze + * @since 5.0 + */ +public interface CorsConfigurationSource { + + /** + * Return a {@link CorsConfiguration} based on the incoming request. + * @return the associated {@link CorsConfiguration}, or {@code null} if none + */ + CorsConfiguration getCorsConfiguration(ServerWebExchange exchange); + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsProcessor.java new file mode 100644 index 00000000000..b516b77893e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsProcessor.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import reactor.core.publisher.Mono; + +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; + +/** + * A strategy that takes a reactive request and a {@link CorsConfiguration} and updates + * the response. + * + *

This component is not concerned with how a {@code CorsConfiguration} is + * selected but rather takes follow-up actions such as applying CORS validation + * checks and either rejecting the response or adding CORS headers to the + * response. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + * @see CORS W3C recommandation + */ +public interface CorsProcessor { + + /** + * Process a request given a {@code CorsConfiguration}. + * @param configuration the applicable CORS configuration (possibly {@code null}) + * @param exchange the current HTTP request / response + * @return a {@link Mono} emitting {@code false} if the request is rejected, {@code true} otherwise + */ + boolean processRequest(CorsConfiguration configuration, ServerWebExchange exchange); + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java new file mode 100644 index 00000000000..4431c40e0c6 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.util.Assert; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +; + +/** + * Utility class for CORS reactive request handling based on the + * CORS W3C recommendation. + * + * @author Sebastien Deleuze + * @since 5.0 + */ +public abstract class CorsUtils { + + /** + * Returns {@code true} if the request is a valid CORS one. + */ + public static boolean isCorsRequest(ServerHttpRequest request) { + return (request.getHeaders().get(HttpHeaders.ORIGIN) != null); + } + + /** + * Returns {@code true} if the request is a valid CORS pre-flight one. + */ + public static boolean isPreFlightRequest(ServerHttpRequest request) { + return (isCorsRequest(request) && HttpMethod.OPTIONS == request.getMethod() && + request.getHeaders().get(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD) != null); + } + + /** + * Check if the request is a same-origin one, based on {@code Origin}, {@code Host}, + * {@code Forwarded} and {@code X-Forwarded-Host} headers. + * @return {@code true} if the request is a same-origin one, {@code false} in case + * of cross-origin request. + */ + public static boolean isSameOrigin(ServerHttpRequest request) { + String origin = request.getHeaders().getOrigin(); + if (origin == null) { + return true; + } + UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request); + UriComponents actualUrl = urlBuilder.build(); + String actualHost = actualUrl.getHost(); + int actualPort = getPort(actualUrl); + Assert.notNull(actualHost, "Actual request host must not be null"); + Assert.isTrue(actualPort != -1, "Actual request port must not be undefined"); + UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); + return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl)); + } + + private static int getPort(UriComponents uri) { + int port = uri.getPort(); + if (port == -1) { + if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) { + port = 80; + } + else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) { + port = 443; + } + } + return port; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java new file mode 100644 index 00000000000..f2b7c57858f --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java @@ -0,0 +1,187 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.util.CollectionUtils; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.WebUtils; + +/** + * The default implementation of {@link CorsProcessor}, + * as defined by the CORS W3C recommendation. + * + *

Note that when input {@link CorsConfiguration} is {@code null}, this + * implementation does not reject simple or actual requests outright but simply + * avoid adding CORS headers to the response. CORS processing is also skipped + * if the response already contains CORS headers, or if the request is detected + * as a same-origin one. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @since 5.0 + */ +public class DefaultCorsProcessor implements CorsProcessor { + + private static final Log logger = LogFactory.getLog(DefaultCorsProcessor.class); + + + @Override + @SuppressWarnings("resource") + public boolean processRequest(CorsConfiguration config, ServerWebExchange exchange) { + + ServerHttpRequest request = exchange.getRequest(); + ServerHttpResponse response = exchange.getResponse(); + + if (!CorsUtils.isCorsRequest(request)) { + return true; + } + + if (responseHasCors(response)) { + logger.debug("Skip CORS processing: response already contains \"Access-Control-Allow-Origin\" header"); + return true; + } + + if (CorsUtils.isSameOrigin(request)) { + logger.debug("Skip CORS processing: request is from same origin"); + return true; + } + + boolean preFlightRequest = CorsUtils.isPreFlightRequest(request); + if (config == null) { + if (preFlightRequest) { + rejectRequest(response); + return false; + } + else { + return true; + } + } + + return handleInternal(exchange, config, preFlightRequest); + } + + private boolean responseHasCors(ServerHttpResponse response) { + return (response.getHeaders().getAccessControlAllowOrigin() != null); + } + + /** + * Invoked when one of the CORS checks failed. + */ + protected void rejectRequest(ServerHttpResponse response) { + response.setStatusCode(HttpStatus.FORBIDDEN); + logger.debug("Invalid CORS request"); + } + + /** + * Handle the given request. + */ + protected boolean handleInternal(ServerWebExchange exchange, + CorsConfiguration config, boolean preFlightRequest) { + + ServerHttpRequest request = exchange.getRequest(); + ServerHttpResponse response = exchange.getResponse(); + + String requestOrigin = request.getHeaders().getOrigin(); + String allowOrigin = checkOrigin(config, requestOrigin); + + HttpMethod requestMethod = getMethodToUse(request, preFlightRequest); + List allowMethods = checkMethods(config, requestMethod); + + List requestHeaders = getHeadersToUse(request, preFlightRequest); + List allowHeaders = checkHeaders(config, requestHeaders); + + if (allowOrigin == null || allowMethods == null || (preFlightRequest && allowHeaders == null)) { + rejectRequest(response); + return false; + } + + HttpHeaders responseHeaders = response.getHeaders(); + responseHeaders.setAccessControlAllowOrigin(allowOrigin); + responseHeaders.add(HttpHeaders.VARY, HttpHeaders.ORIGIN); + + if (preFlightRequest) { + responseHeaders.setAccessControlAllowMethods(allowMethods); + } + + if (preFlightRequest && !allowHeaders.isEmpty()) { + responseHeaders.setAccessControlAllowHeaders(allowHeaders); + } + + if (!CollectionUtils.isEmpty(config.getExposedHeaders())) { + responseHeaders.setAccessControlExposeHeaders(config.getExposedHeaders()); + } + + if (Boolean.TRUE.equals(config.getAllowCredentials())) { + responseHeaders.setAccessControlAllowCredentials(true); + } + + if (preFlightRequest && config.getMaxAge() != null) { + responseHeaders.setAccessControlMaxAge(config.getMaxAge()); + } + + return true; + } + + /** + * Check the origin and determine the origin for the response. The default + * implementation simply delegates to + * {@link CorsConfiguration#checkOrigin(String)}. + */ + protected String checkOrigin(CorsConfiguration config, String requestOrigin) { + return config.checkOrigin(requestOrigin); + } + + /** + * Check the HTTP method and determine the methods for the response of a + * pre-flight request. The default implementation simply delegates to + * {@link CorsConfiguration#checkOrigin(String)}. + */ + protected List checkMethods(CorsConfiguration config, HttpMethod requestMethod) { + return config.checkHttpMethod(requestMethod); + } + + private HttpMethod getMethodToUse(ServerHttpRequest request, boolean isPreFlight) { + return (isPreFlight ? request.getHeaders().getAccessControlRequestMethod() : request.getMethod()); + } + + /** + * Check the headers and determine the headers for the response of a + * pre-flight request. The default implementation simply delegates to + * {@link CorsConfiguration#checkOrigin(String)}. + */ + protected List checkHeaders(CorsConfiguration config, List requestHeaders) { + return config.checkHeaders(requestHeaders); + } + + private List getHeadersToUse(ServerHttpRequest request, boolean isPreFlight) { + HttpHeaders headers = request.getHeaders(); + return (isPreFlight ? headers.getAccessControlRequestHeaders() : new ArrayList<>(headers.keySet())); + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSource.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSource.java new file mode 100644 index 00000000000..c6b78d77e8e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSource.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.util.AntPathMatcher; +import org.springframework.util.Assert; +import org.springframework.util.PathMatcher; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.HttpRequestPathHelper; + +/** + * Provide a per reactive request {@link CorsConfiguration} instance based on a + * collection of {@link CorsConfiguration} mapped on path patterns. + * + *

Exact path mapping URIs (such as {@code "/admin"}) are supported + * as well as Ant-style path patterns (such as {@code "/admin/**"}). + * + * @author Sebastien Deleuze + * @since 5.0 + */ +public class UrlBasedCorsConfigurationSource implements CorsConfigurationSource { + + private final Map corsConfigurations = new LinkedHashMap<>(); + + private PathMatcher pathMatcher = new AntPathMatcher(); + + private HttpRequestPathHelper pathHelper = new HttpRequestPathHelper(); + + + /** + * Set the PathMatcher implementation to use for matching URL paths + * against registered URL patterns. Default is AntPathMatcher. + * @see AntPathMatcher + */ + public void setPathMatcher(PathMatcher pathMatcher) { + Assert.notNull(pathMatcher, "PathMatcher must not be null"); + this.pathMatcher = pathMatcher; + } + + /** + * Set if context path and request URI should be URL-decoded. Both are returned + * undecoded by the Servlet API, in contrast to the servlet path. + *

Uses either the request encoding or the default encoding according + * to the Servlet spec (ISO-8859-1). + * @see HttpRequestPathHelper#setUrlDecode + */ + public void setUrlDecode(boolean urlDecode) { + this.pathHelper.setUrlDecode(urlDecode); + } + + /** + * Set the UrlPathHelper to use for resolution of lookup paths. + *

Use this to override the default UrlPathHelper with a custom subclass. + */ + public void setHttpRequestPathHelper(HttpRequestPathHelper pathHelper) { + Assert.notNull(pathHelper, "HttpRequestPathHelper must not be null"); + this.pathHelper = pathHelper; + } + + /** + * Set CORS configuration based on URL patterns. + */ + public void setCorsConfigurations(Map corsConfigurations) { + this.corsConfigurations.clear(); + if (corsConfigurations != null) { + this.corsConfigurations.putAll(corsConfigurations); + } + } + + /** + * Get the CORS configuration. + */ + public Map getCorsConfigurations() { + return Collections.unmodifiableMap(this.corsConfigurations); + } + + /** + * Register a {@link CorsConfiguration} for the specified path pattern. + */ + public void registerCorsConfiguration(String path, CorsConfiguration config) { + this.corsConfigurations.put(path, config); + } + + + @Override + public CorsConfiguration getCorsConfiguration(ServerWebExchange exchange) { + String lookupPath = this.pathHelper.getLookupPathForRequest(exchange); + for (Map.Entry entry : this.corsConfigurations.entrySet()) { + if (this.pathMatcher.match(entry.getKey(), lookupPath)) { + return entry.getValue(); + } + } + return null; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java index 98596580adf..09495d39777 100644 --- a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java @@ -710,8 +710,8 @@ public class UriComponentsBuilder implements Cloneable { } } - if ((this.scheme.equals("http") && "80".equals(this.port)) || - (this.scheme.equals("https") && "443".equals(this.port))) { + if ((this.scheme != null) && ((this.scheme.equals("http") && "80".equals(this.port)) || + (this.scheme.equals("https") && "443".equals(this.port)))) { this.port = null; } diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java new file mode 100644 index 00000000000..a983350ed06 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.web.cors.reactive.CorsUtils; + +/** + * Test case for reactive {@link CorsUtils}. + * + * @author Sebastien Deleuze + */ +public class CorsUtilsTests { + + @Test + public void isCorsRequest() { + MockServerHttpRequest request = new MockServerHttpRequest(); + request.addHeader(HttpHeaders.ORIGIN, "http://domain.com"); + assertTrue(CorsUtils.isCorsRequest(request)); + } + + @Test + public void isNotCorsRequest() { + MockServerHttpRequest request = new MockServerHttpRequest(); + assertFalse(CorsUtils.isCorsRequest(request)); + } + + @Test + public void isPreFlightRequest() { + MockServerHttpRequest request = new MockServerHttpRequest(); + request.setHttpMethod(HttpMethod.OPTIONS); + request.addHeader(HttpHeaders.ORIGIN, "http://domain.com"); + request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + assertTrue(CorsUtils.isPreFlightRequest(request)); + } + + @Test + public void isNotPreFlightRequest() { + MockServerHttpRequest request = new MockServerHttpRequest(); + assertFalse(CorsUtils.isPreFlightRequest(request)); + + request = new MockServerHttpRequest(); + request.setHttpMethod(HttpMethod.OPTIONS); + request.addHeader(HttpHeaders.ORIGIN, "http://domain.com"); + assertFalse(CorsUtils.isPreFlightRequest(request)); + + request = new MockServerHttpRequest(); + request.setHttpMethod(HttpMethod.OPTIONS); + request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + assertFalse(CorsUtils.isPreFlightRequest(request)); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java new file mode 100644 index 00000000000..a2cb70958d4 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java @@ -0,0 +1,351 @@ +/* + * Copyright 2002-2016 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import static org.junit.Assert.*; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.reactive.DefaultCorsProcessor; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.adapter.DefaultServerWebExchange; +import org.springframework.web.server.session.MockWebSessionManager; + +/** + * Test reactive {@link DefaultCorsProcessor} with simple or preflight CORS request. + * + * @author Sebastien Deleuze + * @author Rossen Stoyanchev + * @author Juergen Hoeller + */ +public class DefaultCorsProcessorTests { + + private MockServerHttpRequest request; + + private MockServerHttpResponse response; + + private ServerWebExchange exchange; + + private DefaultCorsProcessor processor; + + private CorsConfiguration conf; + + + @Before + public void setup() { + this.request = new MockServerHttpRequest(); + this.request.setUri("http://localhost/test.html"); + this.conf = new CorsConfiguration(); + this.response = new MockServerHttpResponse(); + this.response.setStatusCode(HttpStatus.OK); + this.processor = new DefaultCorsProcessor(); + this.exchange = new DefaultServerWebExchange(this.request, this.response, new MockWebSessionManager()); + } + + + @Test + public void actualRequestWithOriginHeader() throws Exception { + this.request.setHttpMethod(HttpMethod.GET); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + + this.processor.processRequest(this.conf, this.exchange); + assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode()); + } + + @Test + public void actualRequestWithOriginHeaderAndNullConfig() throws Exception { + this.request.setHttpMethod(HttpMethod.GET); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + + this.processor.processRequest(null, this.exchange); + assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void actualRequestWithOriginHeaderAndAllowedOrigin() throws Exception { + this.request.setHttpMethod(HttpMethod.GET); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.conf.addAllowedOrigin("*"); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("*", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); + assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void actualRequestCredentials() throws Exception { + this.request.setHttpMethod(HttpMethod.GET); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.conf.addAllowedOrigin("http://domain1.com"); + this.conf.addAllowedOrigin("http://domain2.com"); + this.conf.addAllowedOrigin("http://domain3.com"); + this.conf.setAllowCredentials(true); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("http://domain2.com", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals("true", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void actualRequestCredentialsWithOriginWildcard() throws Exception { + this.request.setHttpMethod(HttpMethod.GET); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.conf.addAllowedOrigin("*"); + this.conf.setAllowCredentials(true); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("http://domain2.com", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals("true", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void actualRequestCaseInsensitiveOriginMatch() throws Exception { + this.request.setHttpMethod(HttpMethod.GET); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.conf.addAllowedOrigin("http://DOMAIN2.com"); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void actualRequestExposedHeaders() throws Exception { + this.request.setHttpMethod(HttpMethod.GET); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.conf.addExposedHeader("header1"); + this.conf.addExposedHeader("header2"); + this.conf.addAllowedOrigin("http://domain2.com"); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("http://domain2.com", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1")); + assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2")); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void preflightRequestAllOriginsAllowed() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.conf.addAllowedOrigin("*"); + + this.processor.processRequest(this.conf, this.exchange); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void preflightRequestWrongAllowedMethod() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "DELETE"); + this.conf.addAllowedOrigin("*"); + + this.processor.processRequest(this.conf, this.exchange); + assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode()); + } + + @Test + public void preflightRequestMatchedAllowedMethod() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.conf.addAllowedOrigin("*"); + + this.processor.processRequest(this.conf, this.exchange); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + assertEquals("GET,HEAD", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); + } + + @Test + public void preflightRequestTestWithOriginButWithoutOtherHeaders() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + + this.processor.processRequest(this.conf, this.exchange); + assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode()); + } + + @Test + public void preflightRequestWithoutRequestMethod() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); + + this.processor.processRequest(this.conf, this.exchange); + assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode()); + } + + @Test + public void preflightRequestWithRequestAndMethodHeaderButNoConfig() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); + + this.processor.processRequest(this.conf, this.exchange); + assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode()); + } + + @Test + public void preflightRequestValidRequestAndConfig() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); + this.conf.addAllowedOrigin("*"); + this.conf.addAllowedMethod("GET"); + this.conf.addAllowedMethod("PUT"); + this.conf.addAllowedHeader("header1"); + this.conf.addAllowedHeader("header2"); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("*", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); + assertEquals("GET,PUT", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS)); + assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE)); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void preflightRequestCredentials() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); + this.conf.addAllowedOrigin("http://domain1.com"); + this.conf.addAllowedOrigin("http://domain2.com"); + this.conf.addAllowedOrigin("http://domain3.com"); + this.conf.addAllowedHeader("Header1"); + this.conf.setAllowCredentials(true); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("http://domain2.com", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals("true", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void preflightRequestCredentialsWithOriginWildcard() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1"); + this.conf.addAllowedOrigin("http://domain1.com"); + this.conf.addAllowedOrigin("*"); + this.conf.addAllowedOrigin("http://domain3.com"); + this.conf.addAllowedHeader("Header1"); + this.conf.setAllowCredentials(true); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals("http://domain2.com", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void preflightRequestAllowedHeaders() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1, Header2"); + this.conf.addAllowedHeader("Header1"); + this.conf.addAllowedHeader("Header2"); + this.conf.addAllowedHeader("Header3"); + this.conf.addAllowedOrigin("http://domain2.com"); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); + assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); + assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); + assertFalse(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3")); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void preflightRequestAllowsAllHeaders() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1, Header2"); + this.conf.addAllowedHeader("*"); + this.conf.addAllowedOrigin("http://domain2.com"); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); + assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1")); + assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2")); + assertFalse(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("*")); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void preflightRequestWithEmptyHeaders() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, ""); + this.conf.addAllowedHeader("*"); + this.conf.addAllowedOrigin("http://domain2.com"); + + this.processor.processRequest(this.conf, this.exchange); + assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS)); + assertEquals(HttpStatus.OK, this.response.getStatusCode()); + } + + @Test + public void preflightRequestWithNullConfig() throws Exception { + this.request.setHttpMethod(HttpMethod.OPTIONS); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + this.conf.addAllowedOrigin("*"); + + this.processor.processRequest(null, this.exchange); + assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode()); + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSourceTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSourceTests.java new file mode 100644 index 00000000000..f0be0bfaded --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSourceTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.cors.reactive; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import org.junit.Test; + +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.adapter.DefaultServerWebExchange; +import org.springframework.web.server.session.MockWebSessionManager; + +/** + * Unit tests for reactive {@link UrlBasedCorsConfigurationSource}. + * @author Sebastien Deleuze + */ +public class UrlBasedCorsConfigurationSourceTests { + + private final UrlBasedCorsConfigurationSource configSource = new UrlBasedCorsConfigurationSource(); + + @Test + public void empty() { + ServerHttpRequest request = new MockServerHttpRequest(HttpMethod.GET, "/bar/test.html"); + ServerWebExchange exchange = new DefaultServerWebExchange(request, + new MockServerHttpResponse(), new MockWebSessionManager()); + assertNull(this.configSource.getCorsConfiguration(exchange)); + } + + @Test + public void registerAndMatch() { + CorsConfiguration config = new CorsConfiguration(); + this.configSource.registerCorsConfiguration("/bar/**", config); + assertNull(this.configSource.getCorsConfiguration( + new DefaultServerWebExchange( + new MockServerHttpRequest(HttpMethod.GET, "/foo/test.html"), + new MockServerHttpResponse(), + new MockWebSessionManager()))); + assertEquals(config, this.configSource.getCorsConfiguration(new DefaultServerWebExchange( + new MockServerHttpRequest(HttpMethod.GET, "/bar/test.html"), + new MockServerHttpResponse(), + new MockWebSessionManager()))); + } + + @Test(expected = UnsupportedOperationException.class) + public void unmodifiableConfigurationsMap() { + this.configSource.getCorsConfigurations().put("/**", new CorsConfiguration()); + } + +}