Add CORS support

This commit introduces support for CORS in Spring Framework.

Cross-origin resource sharing (CORS) is a mechanism that allows
many resources (e.g. fonts, JavaScript, etc.) on a web page to
be requested from another domain outside the domain from which
the resource originated. It is defined by the CORS W3C
recommandation (http://www.w3.org/TR/cors/).

A new annotation @CrossOrigin allows to enable CORS support
on Controller type or method level. By default all origins
("*") are allowed.

@RestController
public class SampleController {

	@CrossOrigin
	@RequestMapping("/foo")
	public String foo() {
		// ...
	}
}

Various @CrossOrigin attributes allow to customize the CORS configuration.

@RestController
public class SampleController {

	@CrossOrigin(origin = { "http://site1.com", "http://site2.com" },
				 allowedHeaders = { "header1", "header2" },
				 exposedHeaders = { "header1", "header2" },
				 method = RequestMethod.DELETE,
				 maxAge = 123, allowCredentials = "true")
	@RequestMapping(value = "/foo", method = { RequestMethod.GET, RequestMethod.POST} )
	public String foo() {
		// ...
	}
}

A CorsConfigurationSource interface can be implemented by HTTP request
handlers that want to support CORS by providing a CorsConfiguration
that will be detected at AbstractHandlerMapping level. See for
example ResourceHttpRequestHandler that implements this interface.

Global CORS configuration should be supported through ControllerAdvice
(with type level @CrossOrigin annotated class or class implementing
CorsConfigurationSource), or with XML namespace and JavaConfig
configuration, but this is not implemented yet.

Issue: SPR-9278
This commit is contained in:
Sebastien Deleuze 2015-04-02 15:46:30 +02:00
parent 35f40ae654
commit b0e1e66b7f
24 changed files with 1942 additions and 196 deletions

View File

@ -0,0 +1,82 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.bind.annotation;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Marks the annotated method as permitting cross origin requests.
* By default, all origins and headers are permitted.
*
* @since 4.2
* @author Russell Allen
* @author Sebastien Deleuze
*/
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface CrossOrigin {
/**
* List of allowed origins. {@code "*"} means that all origins are allowed. These values
* are placed in the {@code Access-Control-Allow-Origin } header of both the pre-flight
* and actual responses. Default value is <b>"*"</b>.
*/
String[] origin() default {"*"};
/**
* Indicates which request headers can be used during the actual request. {@code "*"} means
* that all headers asked by the client are allowed. This property controls the value of
* pre-flight response's {@code Access-Control-Allow-Headers} header. Default value is
* <b>"*"</b>.
*/
String[] allowedHeaders() default {"*"};
/**
* List of response headers that the user-agent will allow the client to access. This property
* controls the value of actual response's {@code Access-Control-Expose-Headers} header.
*/
String[] exposedHeaders() default {};
/**
* The HTTP request methods to allow: GET, POST, HEAD, OPTIONS, PUT, PATCH, DELETE, TRACE.
* Methods specified here overrides {@code RequestMapping} ones.
*/
RequestMethod[] method() default {};
/**
* Set to {@code "true"} if the the browser should include any cookies associated to the domain
* of the request being annotated, or "false" if it should not. Empty string "" means undefined.
* If true, the pre-flight response will include the header
* {@code Access-Control-Allow-Credentials=true}. Default value is <b>"true"</b>.
*/
String allowCredentials() default "true";
/**
* Controls the cache duration for pre-flight responses. Setting this to a reasonable
* value can reduce the number of pre-flight request/response interaction required by
* the browser. This property controls the value of the {@code Access-Control-Max-Age header}
* in the pre-flight response. Value set to -1 means undefined. Default value is
* <b>1800</b> seconds, or 30 minutes.
*/
long maxAge() default 1800;
}

View File

@ -0,0 +1,229 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.cors;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* Represents the CORS configuration that stores various properties used to check if a
* CORS request is allowed and to generate CORS response headers.
*
* @author Sebastien Deleuze
* @author Rossen Stoyanchev
* @since 4.2
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>
*/
public class CorsConfiguration {
private List<String> allowedOrigins;
private List<String> allowedMethods;
private List<String> allowedHeaders;
private List<String> exposedHeaders;
private Boolean allowCredentials;
private Long maxAge;
public CorsConfiguration() {
}
public CorsConfiguration(CorsConfiguration config) {
if (config.allowedOrigins != null) {
this.allowedOrigins = new ArrayList<String>(config.allowedOrigins);
}
if (config.allowCredentials != null) {
this.allowCredentials = new Boolean(config.allowCredentials);
}
if (config.exposedHeaders != null) {
this.exposedHeaders = new ArrayList<String>(config.exposedHeaders);
}
if (config.allowedMethods != null) {
this.allowedMethods = new ArrayList<String>(config.allowedMethods);
}
if (config.allowedHeaders != null) {
this.allowedHeaders = new ArrayList<String>(config.allowedHeaders);
}
if (config.maxAge != null) {
this.maxAge = new Long(config.maxAge);
}
}
public CorsConfiguration combine(CorsConfiguration other) {
CorsConfiguration config = new CorsConfiguration(this);
if (other.getAllowedOrigins() != null) {
config.setAllowedOrigins(other.getAllowedOrigins());
}
if (other.getAllowedMethods() != null) {
config.setAllowedMethods(other.getAllowedMethods());
}
if (other.getAllowedHeaders() != null) {
config.setAllowedHeaders(other.getAllowedHeaders());
}
if (other.getExposedHeaders() != null) {
config.setExposedHeaders(other.getExposedHeaders());
}
if (other.getMaxAge() != null) {
config.setMaxAge(other.getMaxAge());
}
if (other.isAllowCredentials() != null) {
config.setAllowCredentials(other.isAllowCredentials());
}
return config;
}
/**
* @see #setAllowedOrigins(java.util.List)
*/
public List<String> getAllowedOrigins() {
if (this.allowedOrigins != null) {
return this.allowedOrigins.contains("*") ? Arrays.asList("*") : Collections.unmodifiableList(this.allowedOrigins);
}
return null;
}
/**
* Set allowed origins that will define Access-Control-Allow-Origin response
* header values (mandatory). For example "http://domain1.com", "http://domain2.com" ...
* "*" means that all domains are allowed.
*/
public void setAllowedOrigins(List<String> allowedOrigins) {
this.allowedOrigins = allowedOrigins;
}
/**
* @see #setAllowedOrigins(java.util.List)
*/
public void addAllowedOrigin(String allowedOrigin) {
if (this.allowedOrigins == null) {
this.allowedOrigins = new ArrayList<String>();
}
this.allowedOrigins.add(allowedOrigin);
}
/**
* @see #setAllowedMethods(java.util.List)
*/
public List<String> getAllowedMethods() {
return this.allowedMethods == null ? null : Collections.unmodifiableList(this.allowedMethods);
}
/**
* Set allow methods that will define Access-Control-Allow-Methods response header
* values. For example "GET", "POST", "PUT" ... "*" means that all methods requested
* by the client are allowed. If not set, allowed method is set to "GET".
*
*/
public void setAllowedMethods(List<String> allowedMethods) {
this.allowedMethods = allowedMethods;
}
/**
* @see #setAllowedMethods(java.util.List)
*/
public void addAllowedMethod(String allowedMethod) {
if (this.allowedMethods == null) {
this.allowedMethods = new ArrayList<String>();
}
this.allowedMethods.add(allowedMethod);
}
/**
* @see #setAllowedHeaders(java.util.List)
*/
public List<String> getAllowedHeaders() {
return this.allowedHeaders == null ? null : Collections.unmodifiableList(this.allowedHeaders);
}
/**
* Set a list of request headers that will define Access-Control-Allow-Methods response
* header values. If a header field name is one of the following, it is not required
* to be listed: Cache-Control, Content-Language, Expires, Last-Modified, Pragma.
* "*" means that all headers asked by the client will be allowed.
*/
public void setAllowedHeaders(List<String> allowedHeaders) {
this.allowedHeaders = allowedHeaders;
}
/**
* @see #setAllowedHeaders(java.util.List)
*/
public void addAllowedHeader(String allowedHeader) {
if (this.allowedHeaders == null) {
this.allowedHeaders = new ArrayList<String>();
}
this.allowedHeaders.add(allowedHeader);
}
/**
* @see #setExposedHeaders(java.util.List)
*/
public List<String> getExposedHeaders() {
return this.exposedHeaders == null ? null : Collections.unmodifiableList(this.exposedHeaders);
}
/**
* Set a list of response headers other than simple headers that the resource might use
* and can be exposed. Simple response headers are: Cache-Control, Content-Language,
* Content-Type, Expires, Last-Modified, Pragma.
*/
public void setExposedHeaders(List<String> exposedHeaders) {
this.exposedHeaders = exposedHeaders;
}
/**
* @see #setExposedHeaders(java.util.List)
*/
public void addExposedHeader(String exposedHeader) {
if (this.exposedHeaders == null) {
this.exposedHeaders = new ArrayList<String>();
}
this.exposedHeaders.add(exposedHeader);
}
/**
* @see #setAllowCredentials(Boolean)
*/
public Boolean isAllowCredentials() {
return this.allowCredentials;
}
/**
* Indicates whether the resource supports user credentials.
* Set the value of Access-Control-Allow-Credentials response header.
*/
public void setAllowCredentials(Boolean allowCredentials) {
this.allowCredentials = allowCredentials;
}
/**
* @see #setMaxAge(Long)
*/
public Long getMaxAge() {
return maxAge;
}
/**
* Indicates how long (seconds) the results of a preflight request can be cached
* in a preflight result cache.
*/
public void setMaxAge(Long maxAge) {
this.maxAge = maxAge;
}
}

View File

@ -0,0 +1,35 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.cors;
import javax.servlet.http.HttpServletRequest;
/**
* Interface to be implemented by classes (usually HTTP request handlers) that provides
* a {@link CorsConfiguration} instance based on the provided request.
*
* @author Sebastien Deleuze
* @since 4.2
*/
public interface CorsConfigurationSource {
/**
* Return a {@link CorsConfiguration} based on the incoming request.
*/
CorsConfiguration getCorsConfiguration(HttpServletRequest request);
}

View File

@ -0,0 +1,50 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.cors;
import java.io.IOException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Interface to be implemented by classes that process CORS preflight and actual requests.
*
* @author Sebastien Deleuze
* @since 4.2
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>
*/
public interface CorsProcessor {
/**
* Process a pre-flight CORS request based on the provided {@link CorsConfiguration}.
* If the request is not a valid CORS pre-flight request or if it does not comply with
* the configuration, it should be rejected.
* If the request is valid and comply with the configuration, this method adds the related
* CORS headers to the response.
*/
boolean processPreFlightRequest(CorsConfiguration conf, HttpServletRequest request, HttpServletResponse response) throws IOException;
/**
* Process a simple or actual CORS request based on the provided {@link CorsConfiguration}.
* If the request is not a valid CORS simple or actual request or if it does not comply
* with the configuration, it should be rejected.
* If the request is valid and comply with the configuration, this method adds the related
* CORS headers to the response.
*/
boolean processActualRequest(CorsConfiguration conf, HttpServletRequest request, HttpServletResponse response) throws IOException;
}

View File

@ -0,0 +1,110 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.cors;
import java.util.Arrays;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.util.CollectionUtils;
/**
* Utility class for CORS request handling based on the
* <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>.
*
* @author Sebastien Deleuze
* @since 4.2
*/
public class CorsUtils {
/**
* The CORS {@code Access-Control-Request-Headers} request header field name.
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>
*/
public static final String ACCESS_CONTROL_REQUEST_HEADERS = "Access-Control-Request-Headers";
/**
* The CORS {@code Access-Control-Request-Method} request header field name.
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>
*/
public static final String ACCESS_CONTROL_REQUEST_METHOD = "Access-Control-Request-Method";
/**
* The CORS {@code Access-Control-Allow-Origin} response header field name.
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>
*/
public static final String ACCESS_CONTROL_ALLOW_ORIGIN = "Access-Control-Allow-Origin";
/**
* The CORS {@code Access-Control-Allow-Headers} response header field name.
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>
*/
public static final String ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers";
/**
* The CORS {@code Access-Control-Allow-Methods} response header field name.
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>
*/
public static final String ACCESS_CONTROL_ALLOW_METHODS = "Access-Control-Allow-Methods";
/**
* The CORS {@code Access-Control-Max-Age} response header field name.
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>
*/
public static final String ACCESS_CONTROL_MAX_AGE = "Access-Control-Max-Age";
/**
* The CORS {@code Access-Control-Allow-Credentials} response header field name.
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>
*/
public static final String ACCESS_CONTROL_ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials";
/**
* The CORS {@code Access-Control-Expose-Headers} response header field name.
* @see <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>
*/
public static final String ACCESS_CONTROL_EXPOSE_HEADERS = "Access-Control-Expose-Headers";
/**
* Returns {@code true} if the request is a valid CORS one.
*/
public static boolean isCorsRequest(HttpServletRequest request) {
return request.getHeader(HttpHeaders.ORIGIN) != null;
}
/**
* Returns {@code true} if the request is a valid CORS pre-flight one.
*/
public static boolean isPreFlightRequest(HttpServletRequest request) {
if (!isCorsRequest(request)) {
return false;
}
return request.getMethod().equals(HttpMethod.OPTIONS.name());
}
/**
* Returns {@code true} if the response already contains CORS headers.
*/
public static boolean isCorsResponse(HttpServletResponse response) {
boolean hasCorsResponseHeaders = false;
try {
// Perhaps a CORS Filter has already added this?
hasCorsResponseHeaders = !CollectionUtils.isEmpty(response.getHeaders(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
}
catch (NullPointerException npe) {
// See SPR-11919 and https://issues.jboss.org/browse/WFLY-3474
}
return hasCorsResponseHeaders;
}
}

View File

@ -0,0 +1,217 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.cors;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.util.StringUtils;
/**
* Default implementation of {@link CorsProcessor}, as defined by the
* <a href="http://www.w3.org/TR/cors/">CORS W3C recommandation</a>.
*
* @author Sebastien Deleuze
* @since 4.2
*/
public class DefaultCorsProcessor implements CorsProcessor {
protected final Log logger = LogFactory.getLog(getClass());
@Override
public boolean processPreFlightRequest(CorsConfiguration config, HttpServletRequest request, HttpServletResponse response) throws IOException {
if (!CorsUtils.isPreFlightRequest(request)) {
rejectCorsRequest(response);
return false;
}
if (check(request, response, config)) {
setOriginHeader(request, response, config.getAllowedOrigins(), config.isAllowCredentials());
setAllowCredentialsHeader(response, config.isAllowCredentials());
setAllowMethodsHeader(request, response, config.getAllowedMethods());
setAllowHeadersHeader(request, response, config.getAllowedHeaders());
setMaxAgeHeader(response, config.getMaxAge());
}
return true;
}
@Override
public boolean processActualRequest(CorsConfiguration config, HttpServletRequest request, HttpServletResponse response) throws IOException {
if (CorsUtils.isPreFlightRequest(request) || !CorsUtils.isCorsRequest(request)) {
rejectCorsRequest(response);
return false;
}
if (check(request, response, config)) {
setOriginHeader(request, response, config.getAllowedOrigins(), config.isAllowCredentials());
setAllowCredentialsHeader(response, config.isAllowCredentials());
setExposeHeadersHeader(response, config.getExposedHeaders());
}
return true;
}
private void rejectCorsRequest(HttpServletResponse response) throws IOException {
response.sendError(HttpServletResponse.SC_FORBIDDEN, "Invalid CORS request");
}
private boolean check(HttpServletRequest request, HttpServletResponse response, CorsConfiguration config) throws IOException {
if (CorsUtils.isCorsResponse(response)) {
logger.debug("Skip adding CORS headers, response already contains \"Access-Control-Allow-Origin\"");
return false;
}
if (!(checkOrigin(request, config.getAllowedOrigins()) &&
checkRequestMethod(request, config.getAllowedMethods()) &&
checkRequestHeaders(request, config.getAllowedHeaders()))) {
rejectCorsRequest(response);
return false;
}
return true;
}
private boolean checkOrigin(HttpServletRequest request, List<String> allowedOrigins) {
String origin = request.getHeader(HttpHeaders.ORIGIN);
if ((origin == null) || (allowedOrigins == null)) {
return false;
}
if (allowedOrigins.contains("*")) {
return true;
}
for (String allowedOrigin : allowedOrigins) {
if (origin.equalsIgnoreCase(allowedOrigin)) {
return true;
}
}
return false;
}
private boolean checkRequestMethod(HttpServletRequest request, List<String> allowedMethods) {
String requestMethod = CorsUtils.isPreFlightRequest(request) ?
request.getHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD) : request.getMethod();
if (requestMethod == null) {
return false;
}
if (allowedMethods == null) {
allowedMethods = Arrays.asList(HttpMethod.GET.name());
}
if (allowedMethods.contains("*")) {
return true;
}
for (String allowedMethod : allowedMethods) {
if (requestMethod.equalsIgnoreCase(allowedMethod)) {
return true;
}
}
return false;
}
private boolean checkRequestHeaders(HttpServletRequest request, List<String> allowedHeaders) {
String[] requestHeaders = CorsUtils.isPreFlightRequest(request) ?
StringUtils.commaDelimitedListToStringArray(request.getHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS)) :
Collections.list(request.getHeaderNames()).toArray(new String [0]);
if ((allowedHeaders != null) && allowedHeaders.contains("*")) {
return true;
}
for (String requestHeader : requestHeaders) {
if (!HttpHeaders.ORIGIN.equals(requestHeader)) {
requestHeader = requestHeader.trim();
boolean found = false;
if (allowedHeaders != null) {
for (String header : allowedHeaders) {
if (requestHeader.equalsIgnoreCase(header)) {
found = true;
break;
}
}
}
if (!found) {
return false;
}
}
}
return true;
}
private void setOriginHeader(HttpServletRequest request, HttpServletResponse response, List<String> allowedOrigins, Boolean allowCredentials) {
String origin = request.getHeader(HttpHeaders.ORIGIN);
if (allowedOrigins.contains("*") && (allowCredentials == null || !allowCredentials)) {
response.addHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN, "*");
return;
}
response.addHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
response.addHeader(HttpHeaders.VARY, HttpHeaders.ORIGIN);
}
private void setAllowCredentialsHeader(HttpServletResponse response, Boolean allowCredentials) {
if ((allowCredentials != null) && allowCredentials) {
response.addHeader(CorsUtils.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}
}
private void setAllowMethodsHeader(HttpServletRequest request, HttpServletResponse response, List<String> allowedMethods) {
if (allowedMethods == null) {
allowedMethods = Arrays.asList(HttpMethod.GET.name());
}
if (allowedMethods.contains("*")) {
response.addHeader(CorsUtils.ACCESS_CONTROL_ALLOW_METHODS, request.getHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD));
}
else {
response.addHeader(CorsUtils.ACCESS_CONTROL_ALLOW_METHODS, StringUtils.collectionToCommaDelimitedString(allowedMethods));
}
}
private void setAllowHeadersHeader(HttpServletRequest request, HttpServletResponse response, List<String> allowedHeaders) {
if ((allowedHeaders != null) && !allowedHeaders.isEmpty()) {
String[] requestHeaders = StringUtils.commaDelimitedListToStringArray(request.getHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS));
boolean matchAll = allowedHeaders.contains("*");
List<String> matchingHeaders = new ArrayList<String>();
for (String requestHeader : requestHeaders) {
for (String header : allowedHeaders) {
requestHeader = requestHeader.trim();
if (matchAll || requestHeader.equalsIgnoreCase(header)) {
matchingHeaders.add(requestHeader);
break;
}
}
}
if (!matchingHeaders.isEmpty()) {
response.addHeader(CorsUtils.ACCESS_CONTROL_ALLOW_HEADERS, StringUtils.collectionToCommaDelimitedString(matchingHeaders));
}
}
}
private void setExposeHeadersHeader(HttpServletResponse response, List<String> exposedHeaders) {
if ((exposedHeaders != null) && !exposedHeaders.isEmpty()) {
response.addHeader(CorsUtils.ACCESS_CONTROL_EXPOSE_HEADERS, StringUtils.collectionToCommaDelimitedString(exposedHeaders));
}
}
private void setMaxAgeHeader(HttpServletResponse response, Long maxAge) {
if (maxAge != null) {
response.addHeader(CorsUtils.ACCESS_CONTROL_MAX_AGE, maxAge.toString());
}
}
}

View File

@ -0,0 +1,82 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.cors;
import static org.junit.Assert.*;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse;
/**
* Test case for {@link CorsUtils}.
*
* @author Sebastien Deleuze
*/
public class CorsUtilsTests {
@Test
public void isCorsRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader(HttpHeaders.ORIGIN, "http://domain.com");
assertTrue(CorsUtils.isCorsRequest(request));
}
@Test
public void isNotCorsRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
assertFalse(CorsUtils.isCorsRequest(request));
}
@Test
public void isPreFlightRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("OPTIONS");
request.addHeader(HttpHeaders.ORIGIN, "http://domain.com");
request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
assertTrue(CorsUtils.isPreFlightRequest(request));
}
@Test
public void isNotPreFlightRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
assertFalse(CorsUtils.isPreFlightRequest(request));
request = new MockHttpServletRequest();
request.addHeader(HttpHeaders.ORIGIN, "http://domain.com");
assertFalse(CorsUtils.isPreFlightRequest(request));
request = new MockHttpServletRequest();
request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
assertFalse(CorsUtils.isPreFlightRequest(request));
}
@Test
public void isCorsResponse() {
MockHttpServletResponse response = new MockHttpServletResponse();
response.addHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN, "*");
assertTrue(CorsUtils.isCorsResponse(response));
}
@Test
public void isNotCorsResponse() {
MockHttpServletResponse response = new MockHttpServletResponse();
assertFalse(CorsUtils.isCorsResponse(response));
}
}

View File

@ -0,0 +1,302 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.cors;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse;
import javax.servlet.http.HttpServletResponse;
import static org.junit.Assert.*;
/**
* Test {@link DefaultCorsProcessor} with simple or preflight CORS request.
*
* @author Sebastien Deleuze
*/
public class DefaultCorsProcessorTests {
private MockHttpServletRequest request;
private MockHttpServletResponse response;
private DefaultCorsProcessor processor;
private CorsConfiguration conf;
@Before
public void setup() {
this.request = new MockHttpServletRequest();
this.request.setRequestURI("/test.html");
this.request.setRemoteHost("domain1.com");
this.conf = new CorsConfiguration();
this.response = new MockHttpServletResponse();
this.response.setStatus(HttpServletResponse.SC_OK);
this.processor = new DefaultCorsProcessor();
}
@Test
public void actualRequestWithoutOriginHeader() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.processor.processActualRequest(this.conf, request, response);
assertFalse(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@Test
public void actualRequestWithOriginHeader() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.processor.processActualRequest(this.conf, request, response);
assertFalse(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@Test
public void actualRequestwithOriginHeaderAndAllowedOrigin() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.conf.addAllowedOrigin("*");
this.processor.processActualRequest(this.conf, request, response);
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("*", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertFalse(response.containsHeader(CorsUtils.ACCESS_CONTROL_MAX_AGE));
assertFalse(response.containsHeader(CorsUtils.ACCESS_CONTROL_EXPOSE_HEADERS));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@Test
public void actualRequestCrendentials() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.conf.addAllowedOrigin("http://domain2.com/home.html");
this.conf.addAllowedOrigin("http://domain2.com/test.html");
this.conf.addAllowedOrigin("http://domain2.com/logout.html");
this.conf.setAllowCredentials(true);
this.processor.processActualRequest(this.conf, request, response);
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com/test.html", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@Test
public void actualRequestCredentialsWithOriginWildcard() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.conf.addAllowedOrigin("*");
this.conf.setAllowCredentials(true);
this.processor.processActualRequest(this.conf, request, response);
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com/test.html", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@Test
public void actualRequestCaseInsensitiveOriginMatch() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.conf.addAllowedOrigin("http://domain2.com/TEST.html");
this.processor.processActualRequest(this.conf, request, response);
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@Test
public void actualRequestExposedHeaders() throws Exception {
this.request.setMethod(HttpMethod.GET.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.conf.addExposedHeader("header1");
this.conf.addExposedHeader("header2");
this.conf.addAllowedOrigin("http://domain2.com/test.html");
this.processor.processActualRequest(this.conf, request, response);
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com/test.html", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_EXPOSE_HEADERS));
assertTrue(response.getHeader(CorsUtils.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1"));
assertTrue(response.getHeader(CorsUtils.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2"));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@Test
public void preflightRequestAllOriginsAllowed() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.conf.addAllowedOrigin("*");
this.processor.processPreFlightRequest(this.conf, request, response);
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@Test
public void preflightRequestWrongAllowedMethod() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "DELETE");
this.conf.addAllowedOrigin("*");
this.processor.processPreFlightRequest(this.conf, request, response);
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@Test
public void preflightRequestMatchedAllowedMethod() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.conf.addAllowedOrigin("*");
this.processor.processPreFlightRequest(this.conf, request, response);
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
assertEquals("GET", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_METHODS));
}
@Test
public void preflightRequestWithoutOriginHeader() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.processor.processPreFlightRequest(this.conf, request, response);
assertFalse(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@Test
public void preflightRequestTestWithOriginButWithoutOtherHeaders() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.processor.processPreFlightRequest(this.conf, request, response);
assertFalse(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@Test
public void preflightRequestWithoutRequestMethod() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
this.processor.processPreFlightRequest(this.conf, request, response);
assertFalse(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@Test
public void preflightRequestWithRequestAndMethodHeaderButNoConfig() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.processor.processPreFlightRequest(this.conf, request, response);
assertFalse(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_FORBIDDEN, response.getStatus());
}
@Test
public void preflightRequestValidRequestAndConfig() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.conf.addAllowedOrigin("*");
this.conf.addAllowedMethod("GET");
this.conf.addAllowedMethod("PUT");
this.conf.addAllowedHeader("header1");
this.conf.addAllowedHeader("header2");
this.processor.processPreFlightRequest(this.conf, request, response);
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("*", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_METHODS));
assertEquals("GET,PUT", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_METHODS));
assertFalse(response.containsHeader(CorsUtils.ACCESS_CONTROL_MAX_AGE));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@Test
public void preflightRequestCrendentials() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.conf.addAllowedOrigin("http://domain2.com/home.html");
this.conf.addAllowedOrigin("http://domain2.com/test.html");
this.conf.addAllowedOrigin("http://domain2.com/logout.html");
this.conf.addAllowedHeader("Header1");
this.conf.setAllowCredentials(true);
this.processor.processPreFlightRequest(this.conf, request, response);
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com/test.html", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals("true", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@Test
public void preflightRequestCrendentialsWithOriginWildcard() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.conf.addAllowedOrigin("http://domain2.com/home.html");
this.conf.addAllowedOrigin("*");
this.conf.addAllowedOrigin("http://domain2.com/logout.html");
this.conf.addAllowedHeader("Header1");
this.conf.setAllowCredentials(true);
this.processor.processPreFlightRequest(this.conf, request, response);
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals("http://domain2.com/test.html", response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@Test
public void preflightRequestAllowedHeaders() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "Header1, Header2");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.conf.addAllowedHeader("Header1");
this.conf.addAllowedHeader("Header2");
this.conf.addAllowedHeader("Header3");
this.conf.addAllowedOrigin("http://domain2.com/test.html");
this.processor.processPreFlightRequest(this.conf, request, response);
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_HEADERS));
assertTrue(response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
assertTrue(response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
assertFalse(response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3"));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
@Test
public void preflightRequestAllowsAllHeaders() throws Exception {
this.request.setMethod(HttpMethod.OPTIONS.name());
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "Header1, Header2");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.conf.addAllowedHeader("*");
this.conf.addAllowedOrigin("http://domain2.com/test.html");
this.processor.processPreFlightRequest(this.conf, request, response);
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertTrue(response.containsHeader(CorsUtils.ACCESS_CONTROL_ALLOW_HEADERS));
assertTrue(response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
assertTrue(response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
assertFalse(response.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_HEADERS).contains("*"));
assertEquals(HttpServletResponse.SC_OK, response.getStatus());
}
}

View File

@ -903,7 +903,7 @@ public abstract class FrameworkServlet extends HttpServletBean implements Applic
protected void doOptions(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
if (this.dispatchOptionsRequest) {
if (this.dispatchOptionsRequest || request.getHeader("Origin") != null) {
processRequest(request, response);
if (response.containsHeader("Allow")) {
// Proper OPTIONS response coming from a handler - we're done.

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,14 +16,20 @@
package org.springframework.web.servlet.handler;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.core.Ordered;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.cors.CorsProcessor;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.Assert;
import org.springframework.util.PathMatcher;
@ -32,6 +38,8 @@ import org.springframework.web.context.support.WebApplicationObjectSupport;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.cors.DefaultCorsProcessor;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.util.UrlPathHelper;
/**
@ -71,6 +79,8 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
private final List<MappedInterceptor> mappedInterceptors = new ArrayList<MappedInterceptor>();
private CorsProcessor corsProcessor = new DefaultCorsProcessor();
/**
* Specify the order value for this HandlerMapping bean.
@ -184,6 +194,13 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
this.interceptors.addAll(Arrays.asList(interceptors));
}
/**
* @since 4.2
*/
public void setCorsProcessor(CorsProcessor corsProcessor) {
Assert.notNull(corsProcessor, "CorsProcessor must not be null");
this.corsProcessor = corsProcessor;
}
/**
* Initializes the interceptors.
@ -308,16 +325,29 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
String handlerName = (String) handler;
handler = getApplicationContext().getBean(handlerName);
}
return getHandlerExecutionChain(handler, request);
HandlerExecutionChain executionChain = getHandlerExecutionChain(handler, request);
if (CorsUtils.isCorsRequest(request)) {
CorsConfiguration config = getCorsConfiguration(handler, request);
executionChain = getCorsHandlerExecutionChain(request, executionChain, config);
}
return executionChain;
}
/**
* Look up a handler for the given request, returning {@code null} if no
* specific one is found. This method is called by {@link #getHandler};
* a {@code null} return value will lead to the default handler, if one is set.
*
* <p>On CORS pre-flight requests this method should return a match not for
* the pre-flight request but for the expected actual request based on the URL
* path, the HTTP methods from the "Access-Control-Request-Method" header, and
* the headers from the "Access-Control-Request-Headers" header thus allowing
* the CORS configuration to be obtained via {@link #getCorsConfiguration},
*
* <p>Note: This method may also return a pre-built {@link HandlerExecutionChain},
* combining a handler object with dynamically determined interceptors.
* Statically specified interceptors will get merged into such an existing chain.
*
* @param request current HTTP request
* @return the corresponding handler instance, or {@code null} if none found
* @throws Exception if there is an internal error
@ -358,4 +388,72 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport
return chain;
}
/**
* Retrieve the CORS configuration for the given handler.
*/
protected CorsConfiguration getCorsConfiguration(Object handler, HttpServletRequest request) {
handler = (handler instanceof HandlerExecutionChain) ? ((HandlerExecutionChain) handler).getHandler() : handler;
if (handler != null && handler instanceof CorsConfigurationSource) {
return ((CorsConfigurationSource) handler).getCorsConfiguration(request);
}
return null;
}
/**
* Update the HandlerExecutionChain for CORS-related handling.
*
* <p>For pre-flight requests, the default implementation replaces the selected
* handler with a simple HttpRequestHandler that invokes the configured
* {@link #setCorsProcessor}.
*
* <p>For actual requests, the default implementation inserts a
* HandlerInterceptor that makes CORS-related checks and adds CORS headers.
*/
protected HandlerExecutionChain getCorsHandlerExecutionChain(HttpServletRequest request,
HandlerExecutionChain chain, CorsConfiguration config) {
if (config != null) {
if (CorsUtils.isPreFlightRequest(request)) {
HandlerInterceptor[] interceptors = chain.getInterceptors();
chain = new HandlerExecutionChain(new PreFlightHandler(config), interceptors);
}
else {
chain.addInterceptor(new CorsInterceptor(config));
}
}
return chain;
}
private class PreFlightHandler implements HttpRequestHandler {
private final CorsConfiguration config;
public PreFlightHandler(CorsConfiguration config) {
this.config = config;
}
@Override
public void handleRequest(HttpServletRequest request, HttpServletResponse response) throws IOException {
corsProcessor.processPreFlightRequest(this.config, request, response);
}
}
private class CorsInterceptor extends HandlerInterceptorAdapter {
private final CorsConfiguration config;
public CorsInterceptor(CorsConfiguration config) {
this.config = config;
}
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
return corsProcessor.processActualRequest(this.config, request, response);
}
}
}

View File

@ -35,6 +35,8 @@ import org.springframework.util.ClassUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ReflectionUtils.MethodFilter;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.method.HandlerMethodSelector;
import org.springframework.web.servlet.HandlerMapping;
@ -67,6 +69,9 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
*/
private static final String SCOPED_TARGET_NAME_PREFIX = "scopedTarget.";
private static final HandlerMethod PREFLIGHT_MULTI_MATCH_HANDLER_METHOD =
new HandlerMethod(new EmptyHandler(), ClassUtils.getMethod(EmptyHandler.class, "handle"));
private boolean detectHandlerMethodsInAncestorContexts = false;
@ -78,6 +83,8 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
private final MultiValueMap<String, HandlerMethod> nameMap = new LinkedMultiValueMap<String, HandlerMethod>();
private final Map<Method, CorsConfiguration> corsConfigurations = new LinkedHashMap<Method, CorsConfiguration>();
/**
* Whether to detect handler methods in beans in ancestor ApplicationContexts.
@ -106,6 +113,20 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
return Collections.unmodifiableMap(this.handlerMethods);
}
protected Map<Method, CorsConfiguration> getCorsConfigurations() {
return corsConfigurations;
}
@Override
protected CorsConfiguration getCorsConfiguration(Object handler, HttpServletRequest request) {
CorsConfiguration config = super.getCorsConfiguration(handler, request);
if (config == null && handler instanceof HandlerMethod) {
HandlerMethod handlerMethod = (HandlerMethod)handler;
config = this.getCorsConfigurations().get(handlerMethod.getMethod());
}
return config;
}
/**
* Return the handler methods mapped to the mapping with the given name.
* @param mappingName the mapping name
@ -144,9 +165,19 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
detectHandlerMethods(beanName);
}
}
registerMultiMatchCorsConfiguration();
handlerMethodsInitialized(getHandlerMethods());
}
private void registerMultiMatchCorsConfiguration() {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOrigin("*");
config.addAllowedMethod("*");
config.addAllowedHeader("*");
config.setAllowCredentials(true);
this.corsConfigurations.put(PREFLIGHT_MULTI_MATCH_HANDLER_METHOD.getMethod(), config);
}
/**
* Whether the given type is a handler with handler methods.
* @param beanType the type of the bean being checked
@ -228,6 +259,15 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
String name = this.namingStrategy.getName(newHandlerMethod, mapping);
updateNameMap(name, newHandlerMethod);
}
CorsConfiguration config = initCorsConfiguration(handler, method, mapping);
if (config != null) {
this.corsConfigurations.put(method, config);
}
}
protected CorsConfiguration initCorsConfiguration(Object handler, Method method, T mappingInfo) {
return null;
}
private void updateNameMap(String name, HandlerMethod newHandlerMethod) {
@ -333,6 +373,9 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
}
Match bestMatch = matches.get(0);
if (matches.size() > 1) {
if (CorsUtils.isPreFlightRequest(request)) {
return PREFLIGHT_MULTI_MATCH_HANDLER_METHOD;
}
Match secondBestMatch = matches.get(1);
if (comparator.compare(bestMatch, secondBestMatch) == 0) {
Method m1 = bestMatch.handlerMethod.getMethod();
@ -436,4 +479,13 @@ public abstract class AbstractHandlerMethodMapping<T> extends AbstractHandlerMap
}
}
private static class EmptyHandler {
public void handle() {
throw new UnsupportedOperationException("not implemented");
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,8 @@ package org.springframework.web.servlet.mvc.method;
import javax.servlet.http.HttpServletRequest;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.servlet.mvc.condition.ConsumesRequestCondition;
import org.springframework.web.servlet.mvc.condition.HeadersRequestCondition;
import org.springframework.web.servlet.mvc.condition.ParamsRequestCondition;
@ -208,7 +210,15 @@ public final class RequestMappingInfo implements RequestCondition<RequestMapping
ProducesRequestCondition produces = this.producesCondition.getMatchingCondition(request);
if (methods == null || params == null || headers == null || consumes == null || produces == null) {
return null;
if (CorsUtils.isPreFlightRequest(request)) {
methods = getAccessControlRequestMethodCondition(request);
if (methods == null || params == null) {
return null;
}
}
else {
return null;
}
}
PatternsRequestCondition patterns = this.patternsCondition.getMatchingCondition(request);
@ -225,6 +235,21 @@ public final class RequestMappingInfo implements RequestCondition<RequestMapping
methods, params, headers, consumes, produces, custom.getCondition());
}
/**
* Return a matching RequestMethodsRequestCondition based on the expected
* HTTP method specified in a CORS pre-flight request.
*/
private RequestMethodsRequestCondition getAccessControlRequestMethodCondition(HttpServletRequest request) {
String expectedMethod = request.getHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD);
if (StringUtils.hasText(expectedMethod)) {
for (RequestMethod method : getMethodsCondition().getMethods()) {
if (expectedMethod.equalsIgnoreCase(method.name())) {
return new RequestMethodsRequestCondition(method);
}
}
}
return null;
}
/**
* Compares "this" info (i.e. the current instance) with another info in the context of a request.

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,13 +24,19 @@ import org.springframework.context.EmbeddedValueResolverAware;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringValueResolver;
import org.springframework.web.accept.ContentNegotiationManager;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.mvc.condition.AbstractRequestCondition;
import org.springframework.web.servlet.mvc.condition.CompositeRequestCondition;
import org.springframework.web.servlet.mvc.condition.ConsumesRequestCondition;
import org.springframework.web.servlet.mvc.condition.HeadersRequestCondition;
import org.springframework.web.servlet.mvc.condition.NameValueExpression;
import org.springframework.web.servlet.mvc.condition.ParamsRequestCondition;
import org.springframework.web.servlet.mvc.condition.PatternsRequestCondition;
import org.springframework.web.servlet.mvc.condition.ProducesRequestCondition;
@ -262,4 +268,61 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi
}
}
@Override
protected CorsConfiguration initCorsConfiguration(Object handler, Method method, RequestMappingInfo mappingInfo) {
HandlerMethod handlerMethod = createHandlerMethod(handler, method);
CorsConfiguration config = new CorsConfiguration();
CrossOrigin typeAnnotation = AnnotationUtils.findAnnotation(handlerMethod.getBeanType(), CrossOrigin.class);
applyAnnotation(config, typeAnnotation);
CrossOrigin methodAnnotation = AnnotationUtils.findAnnotation(method, CrossOrigin.class);
applyAnnotation(config, methodAnnotation);
if (CollectionUtils.isEmpty(config.getAllowedMethods())) {
for (RequestMethod allowedMethod : mappingInfo.getMethodsCondition().getMethods()) {
config.addAllowedMethod(allowedMethod.name());
}
}
if (CollectionUtils.isEmpty(config.getAllowedHeaders())) {
for (NameValueExpression<String> headerExpression : mappingInfo.getHeadersCondition().getExpressions()) {
if (!headerExpression.isNegated()) {
config.addAllowedHeader(headerExpression.getName());
}
}
}
return config;
}
private void applyAnnotation(CorsConfiguration config, CrossOrigin annotation) {
if (annotation == null) {
return;
}
for (String origin : annotation.origin()) {
config.addAllowedOrigin(origin);
}
for (RequestMethod method : annotation.method()) {
config.addAllowedMethod(method.name());
}
for (String header : annotation.allowedHeaders()) {
config.addAllowedHeader(header);
}
for (String header : annotation.exposedHeaders()) {
config.addExposedHeader(header);
}
if (annotation.allowCredentials().equalsIgnoreCase("true")) {
config.setAllowCredentials(true);
}
else if (annotation.allowCredentials().equalsIgnoreCase("false")) {
config.setAllowCredentials(false);
}
else if (!annotation.allowCredentials().isEmpty()) {
throw new IllegalStateException("AllowCredentials value must be \"true\", \"false\" or \"\" (empty string), current value is " + annotation.allowCredentials());
}
if (annotation.maxAge() != -1 && config.getMaxAge() == null) {
config.setMaxAge(annotation.maxAge());
}
}
}

View File

@ -38,7 +38,9 @@ import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpRange;
import org.springframework.http.MediaType;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
@ -88,7 +90,7 @@ import org.springframework.web.servlet.support.WebContentGenerator;
* @author Arjen Poutsma
* @since 3.0.4
*/
public class ResourceHttpRequestHandler extends WebContentGenerator implements HttpRequestHandler, InitializingBean {
public class ResourceHttpRequestHandler extends WebContentGenerator implements HttpRequestHandler, InitializingBean, CorsConfigurationSource {
private static final String CONTENT_ENCODING = "Content-Encoding";
@ -104,6 +106,8 @@ public class ResourceHttpRequestHandler extends WebContentGenerator implements H
private final List<ResourceTransformer> resourceTransformers = new ArrayList<ResourceTransformer>(4);
private CorsConfiguration corsConfiguration;
public ResourceHttpRequestHandler() {
super(METHOD_GET, METHOD_HEAD);
@ -162,6 +166,9 @@ public class ResourceHttpRequestHandler extends WebContentGenerator implements H
return this.resourceTransformers;
}
public void setCorsConfiguration(CorsConfiguration corsConfiguration) {
this.corsConfiguration = corsConfiguration;
}
@Override
public void afterPropertiesSet() throws Exception {
@ -172,6 +179,11 @@ public class ResourceHttpRequestHandler extends WebContentGenerator implements H
initAllowedLocations();
}
@Override
public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
return corsConfiguration;
}
/**
* Look for a {@link org.springframework.web.servlet.resource.PathResourceResolver}
* among the {@link #getResourceResolvers() resource resolvers} and configure

View File

@ -0,0 +1,168 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.servlet.handler;
import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import static org.junit.Assert.*;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.support.WebContentGenerator;
/**
* @author Sebastien Deleuze
*/
public class CorsAbstractHandlerMappingTests {
private MockHttpServletRequest request;
private MockHttpServletResponse response;
private AbstractHandlerMapping handlerMapping;
private StaticWebApplicationContext context;
@Before
public void setup() {
this.context = new StaticWebApplicationContext();
this.handlerMapping = new TestHandlerMapping();
this.handlerMapping.setApplicationContext(this.context);
this.request = new MockHttpServletRequest();
this.request.setRemoteHost("domain1.com");
this.response = new MockHttpServletResponse();
}
@Test
public void actualRequestWithoutCorsConfigurationProvider() throws Exception {
this.request.setMethod(RequestMethod.GET.name());
this.request.setRequestURI("/notcors");
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
assertTrue(chain.getHandler() instanceof SimpleHandler);
}
@Test
public void preflightRequestWithoutCorsConfigurationProvider() throws Exception {
this.request.setMethod(RequestMethod.OPTIONS.name());
this.request.setRequestURI("/notcors");
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
assertTrue(chain.getHandler() instanceof SimpleHandler);
}
@Test
public void actualRequestWithCorsConfigurationProvider() throws Exception {
this.request.setMethod(RequestMethod.GET.name());
this.request.setRequestURI("/cors");
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
assertTrue(chain.getHandler() instanceof CorsAwareHandler);
CorsConfiguration config = getCorsConfiguration(chain, false);
assertNotNull(config);
assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"});
}
@Test
public void preflightRequestWithCorsConfigurationProvider() throws Exception {
this.request.setMethod(RequestMethod.OPTIONS.name());
this.request.setRequestURI("/cors");
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com/test.html");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
HandlerExecutionChain chain = handlerMapping.getHandler(this.request);
assertNotNull(chain.getHandler());
assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler"));
CorsConfiguration config = getCorsConfiguration(chain, true);
assertNotNull(config);
assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"});
}
private CorsConfiguration getCorsConfiguration(HandlerExecutionChain chain, boolean isPreFlightRequest) {
if (isPreFlightRequest) {
Object handler = chain.getHandler();
assertTrue(handler.getClass().getSimpleName().equals("PreFlightHandler"));
DirectFieldAccessor accessor = new DirectFieldAccessor(handler);
return (CorsConfiguration)accessor.getPropertyValue("config");
}
else {
HandlerInterceptor[] interceptors = chain.getInterceptors();
if (interceptors != null) {
for (HandlerInterceptor interceptor : interceptors) {
if (interceptor.getClass().getSimpleName().equals("CorsInterceptor")) {
DirectFieldAccessor accessor = new DirectFieldAccessor(interceptor);
return (CorsConfiguration) accessor.getPropertyValue("config");
}
}
}
}
return null;
}
public class TestHandlerMapping extends AbstractHandlerMapping {
@Override
protected Object getHandlerInternal(HttpServletRequest request) throws Exception {
if (request.getRequestURI().equals("/cors")) {
return new CorsAwareHandler();
}
return new SimpleHandler();
}
}
public class SimpleHandler extends WebContentGenerator implements HttpRequestHandler {
public SimpleHandler() {
super(METHOD_GET);
}
@Override
public void handleRequest(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
response.setStatus(HttpStatus.OK.value());
}
}
public class CorsAwareHandler extends SimpleHandler implements CorsConfigurationSource {
@Override
public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOrigin("*");
return config;
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,8 +22,10 @@ import java.util.List;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.servlet.mvc.condition.ConsumesRequestCondition;
import org.springframework.web.servlet.mvc.condition.HeadersRequestCondition;
import org.springframework.web.servlet.mvc.condition.ParamsRequestCondition;
@ -315,4 +317,23 @@ public class RequestMappingInfoTests {
assertNotEquals(info1.hashCode(), info2.hashCode());
}
@Test
public void preFlightRequest() {
MockHttpServletRequest request = new MockHttpServletRequest("OPTIONS", "/foo");
request.addHeader(HttpHeaders.ORIGIN, "http://domain.com");
request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "POST");
RequestMappingInfo info = new RequestMappingInfo(
new PatternsRequestCondition("/foo"), new RequestMethodsRequestCondition(RequestMethod.POST), null,
null, null, null, null);
RequestMappingInfo match = info.getMatchingCondition(request);
assertNotNull(match);
info = new RequestMappingInfo(
new PatternsRequestCondition("/foo"), new RequestMethodsRequestCondition(RequestMethod.OPTIONS), null,
null, null, null, null);
match = info.getMatchingCondition(request);
assertNotNull(match);
}
}

View File

@ -0,0 +1,284 @@
/*
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.servlet.mvc.method.annotation;
import java.lang.reflect.Method;
import static org.junit.Assert.*;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.mvc.condition.ConsumesRequestCondition;
import org.springframework.web.servlet.mvc.condition.HeadersRequestCondition;
import org.springframework.web.servlet.mvc.condition.ParamsRequestCondition;
import org.springframework.web.servlet.mvc.condition.PatternsRequestCondition;
import org.springframework.web.servlet.mvc.condition.ProducesRequestCondition;
import org.springframework.web.servlet.mvc.condition.RequestMethodsRequestCondition;
import org.springframework.web.servlet.mvc.method.RequestMappingInfo;
/**
* Test fixture for {@link CrossOrigin @CrossOrigin} annotated methods.
*
* @author Sebastien Deleuze
*/
@SuppressWarnings("unchecked")
public class CrossOriginTests {
private TestRequestMappingInfoHandlerMapping handlerMapping;
private MockHttpServletRequest request;
@Before
public void setUp() {
this.handlerMapping = new TestRequestMappingInfoHandlerMapping();
this.handlerMapping.setRemoveSemicolonContent(false);
this.handlerMapping.setApplicationContext(new StaticWebApplicationContext());
this.handlerMapping.afterPropertiesSet();
this.request = new MockHttpServletRequest();
this.request.setMethod("GET");
this.request.addHeader(HttpHeaders.ORIGIN, "http://domain.com/");
}
@Test
public void noAnnotation() throws Exception {
this.handlerMapping.registerHandler(new MethodLevelController());
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/no");
HandlerExecutionChain chain = this.handlerMapping.getHandler(request);
CorsConfiguration config = getCorsConfiguration(chain, false);
assertNull(config);
}
@Test
public void defaultAnnotation() throws Exception {
this.handlerMapping.registerHandler(new MethodLevelController());
this.request.setRequestURI("/default");
HandlerExecutionChain chain = this.handlerMapping.getHandler(request);
CorsConfiguration config = getCorsConfiguration(chain, false);
assertNotNull(config);
assertArrayEquals(new String[]{"GET"}, config.getAllowedMethods().toArray());
assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray());
assertTrue(config.isAllowCredentials());
assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray());
assertNull(config.getExposedHeaders());
assertEquals(new Long(1800), config.getMaxAge());
}
@Test
public void customized() throws Exception {
this.handlerMapping.registerHandler(new MethodLevelController());
this.request.setRequestURI("/customized");
HandlerExecutionChain chain = this.handlerMapping.getHandler(request);
CorsConfiguration config = getCorsConfiguration(chain, false);
assertNotNull(config);
assertArrayEquals(new String[]{"DELETE"}, config.getAllowedMethods().toArray());
assertArrayEquals(new String[]{"http://site1.com", "http://site2.com"}, config.getAllowedOrigins().toArray());
assertArrayEquals(new String[]{"header1", "header2"}, config.getAllowedHeaders().toArray());
assertArrayEquals(new String[]{"header3", "header4"}, config.getExposedHeaders().toArray());
assertEquals(new Long(123), config.getMaxAge());
assertEquals(false, config.isAllowCredentials());
}
@Test
public void classLevel() throws Exception {
this.handlerMapping.registerHandler(new ClassLevelController());
this.request.setRequestURI("/foo");
HandlerExecutionChain chain = this.handlerMapping.getHandler(request);
CorsConfiguration config = getCorsConfiguration(chain, false);
assertNotNull(config);
assertArrayEquals(new String[]{"GET"}, config.getAllowedMethods().toArray());
assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray());
}
@Test
public void preFlightRequest() throws Exception {
this.handlerMapping.registerHandler(new MethodLevelController());
this.request.setMethod("OPTIONS");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.request.setRequestURI("/default");
HandlerExecutionChain chain = this.handlerMapping.getHandler(request);
CorsConfiguration config = getCorsConfiguration(chain, true);
assertNotNull(config);
assertArrayEquals(new String[]{"GET"}, config.getAllowedMethods().toArray());
assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray());
assertTrue(config.isAllowCredentials());
assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray());
assertNull(config.getExposedHeaders());
assertEquals(new Long(1800), config.getMaxAge());
}
@Test
public void ambiguousHeaderPreFlightRequest() throws Exception {
this.handlerMapping.registerHandler(new MethodLevelController());
this.request.setMethod("OPTIONS");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "header1");
this.request.setRequestURI("/ambiguous-header");
HandlerExecutionChain chain = this.handlerMapping.getHandler(request);
CorsConfiguration config = getCorsConfiguration(chain, true);
assertNotNull(config);
assertArrayEquals(new String[]{"*"}, config.getAllowedMethods().toArray());
assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray());
assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray());
assertTrue(config.isAllowCredentials());
assertNull(config.getExposedHeaders());
assertNull(config.getMaxAge());
}
@Test
public void ambiguousProducesPreFlightRequest() throws Exception {
this.handlerMapping.registerHandler(new MethodLevelController());
this.request.setMethod("OPTIONS");
this.request.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.request.setRequestURI("/ambiguous-produces");
HandlerExecutionChain chain = this.handlerMapping.getHandler(request);
CorsConfiguration config = getCorsConfiguration(chain, true);
assertNotNull(config);
assertArrayEquals(new String[]{"*"}, config.getAllowedMethods().toArray());
assertArrayEquals(new String[]{"*"}, config.getAllowedOrigins().toArray());
assertArrayEquals(new String[]{"*"}, config.getAllowedHeaders().toArray());
assertTrue(config.isAllowCredentials());
assertNull(config.getExposedHeaders());
assertNull(config.getMaxAge());
}
@Test
public void preFlightRequestWithoutRequestMethodHeader() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("OPTIONS", "/default");
request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
assertNull(this.handlerMapping.getHandler(request));
}
private CorsConfiguration getCorsConfiguration(HandlerExecutionChain chain, boolean isPreFlightRequest) {
if (isPreFlightRequest) {
Object handler = chain.getHandler();
assertTrue(handler.getClass().getSimpleName().equals("PreFlightHandler"));
DirectFieldAccessor accessor = new DirectFieldAccessor(handler);
return (CorsConfiguration)accessor.getPropertyValue("config");
}
else {
HandlerInterceptor[] interceptors = chain.getInterceptors();
if (interceptors != null) {
for (HandlerInterceptor interceptor : interceptors) {
if (interceptor.getClass().getSimpleName().equals("CorsInterceptor")) {
DirectFieldAccessor accessor = new DirectFieldAccessor(interceptor);
return (CorsConfiguration) accessor.getPropertyValue("config");
}
}
}
}
return null;
}
@Controller
private static class MethodLevelController {
@RequestMapping(value = "/no", method = RequestMethod.GET)
public void noAnnotation() {
}
@CrossOrigin
@RequestMapping(value = "/default", method = RequestMethod.GET)
public void defaultAnnotation() {
}
@CrossOrigin
@RequestMapping(value = "/default", method = RequestMethod.GET, params = "q")
public void defaultAnnotationWithParams() {
}
@CrossOrigin
@RequestMapping(value = "/ambiguous-header", method = RequestMethod.GET, headers = "header1=a")
public void ambigousHeader1a() {
}
@CrossOrigin
@RequestMapping(value = "/ambiguous-header", method = RequestMethod.GET, headers = "header1=b")
public void ambigousHeader1b() {
}
@CrossOrigin
@RequestMapping(value = "/ambiguous-produces", method = RequestMethod.GET, produces = "application/xml")
public String ambigousProducesXml() {
return "<a></a>";
}
@CrossOrigin
@RequestMapping(value = "/ambiguous-produces", method = RequestMethod.GET, produces = "application/json")
public String ambigousProducesJson() {
return "{}";
}
@CrossOrigin(origin = { "http://site1.com", "http://site2.com" }, allowedHeaders = { "header1", "header2" },
exposedHeaders = { "header3", "header4" }, method = RequestMethod.DELETE, maxAge = 123, allowCredentials = "false")
@RequestMapping(value = "/customized", method = { RequestMethod.GET, RequestMethod.POST } )
public void customized() {
}
}
@Controller
@CrossOrigin
private static class ClassLevelController {
@RequestMapping(value = "/foo", method = RequestMethod.GET)
public void foo() {
}
}
private static class TestRequestMappingInfoHandlerMapping extends RequestMappingHandlerMapping {
public void registerHandler(Object handler) {
super.detectHandlerMethods(handler);
}
@Override
protected boolean isHandler(Class<?> beanType) {
return AnnotationUtils.findAnnotation(beanType, Controller.class) != null;
}
@Override
protected RequestMappingInfo getMappingForMethod(Method method, Class<?> handlerType) {
RequestMapping annotation = AnnotationUtils.findAnnotation(method, RequestMapping.class);
if (annotation != null) {
return new RequestMappingInfo(
new PatternsRequestCondition(annotation.value(), getUrlPathHelper(), getPathMatcher(), true, true),
new RequestMethodsRequestCondition(annotation.method()),
new ParamsRequestCondition(annotation.params()),
new HeadersRequestCondition(annotation.headers()),
new ConsumesRequestCondition(annotation.consumes(), annotation.headers()),
new ProducesRequestCondition(annotation.produces(), annotation.headers()), null);
}
else {
return null;
}
}
}
}

View File

@ -27,10 +27,11 @@ import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.InvalidMediaTypeException;
@ -43,6 +44,9 @@ import org.springframework.util.CollectionUtils;
import org.springframework.util.DigestUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.SockJsException;
import org.springframework.web.socket.sockjs.SockJsService;
@ -60,7 +64,7 @@ import org.springframework.web.util.WebUtils;
* @author Sebastien Deleuze
* @since 4.0
*/
public abstract class AbstractSockJsService implements SockJsService {
public abstract class AbstractSockJsService implements SockJsService, CorsConfigurationSource {
private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
@ -447,16 +451,8 @@ public abstract class AbstractSockJsService implements SockJsService {
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
/**
* Check the {@code Origin} header value and eventually call {@link #addCorsHeaders(ServerHttpRequest, ServerHttpResponse, HttpMethod...)}.
* If the request origin is not allowed, the request is rejected.
* @return false if the request is rejected, else true
* @since 4.1.2
*/
protected boolean checkAndAddCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) {
HttpHeaders requestHeaders = request.getHeaders();
HttpHeaders responseHeaders = response.getHeaders();
String origin = requestHeaders.getOrigin();
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) throws IOException {
String origin = request.getHeaders().getOrigin();
if (origin == null) {
return true;
@ -468,46 +464,26 @@ public abstract class AbstractSockJsService implements SockJsService {
return false;
}
boolean hasCorsResponseHeaders = false;
try {
// Perhaps a CORS Filter has already added this?
hasCorsResponseHeaders = !CollectionUtils.isEmpty(responseHeaders.get("Access-Control-Allow-Origin"));
}
catch (NullPointerException npe) {
// See SPR-11919 and https://issues.jboss.org/browse/WFLY-3474
}
if (!this.suppressCors && !hasCorsResponseHeaders) {
addCorsHeaders(request, response, httpMethods);
}
return true;
}
protected void addCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) {
HttpHeaders requestHeaders = request.getHeaders();
HttpHeaders responseHeaders = response.getHeaders();
responseHeaders.add("Access-Control-Allow-Origin", requestHeaders.getFirst("Origin"));
responseHeaders.add("Access-Control-Allow-Credentials", "true");
List<String> accessControllerHeaders = requestHeaders.get("Access-Control-Request-Headers");
if (accessControllerHeaders != null) {
for (String header : accessControllerHeaders) {
responseHeaders.add("Access-Control-Allow-Headers", header);
}
@Override
public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
if (!this.suppressCors && CorsUtils.isCorsRequest(request)) {
CorsConfiguration config = new CorsConfiguration();
config.addAllowedOrigin("*");
config.addAllowedMethod("*");
config.setAllowCredentials(true);
config.setMaxAge(ONE_YEAR);
config.addAllowedHeader("*");
return config;
}
if (!ObjectUtils.isEmpty(httpMethods)) {
responseHeaders.add("Access-Control-Allow-Methods", StringUtils.arrayToDelimitedString(httpMethods, ", "));
responseHeaders.add("Access-Control-Max-Age", String.valueOf(ONE_YEAR));
}
responseHeaders.add(HttpHeaders.VARY, HttpHeaders.ORIGIN);
return null;
}
protected void addCacheHeaders(ServerHttpResponse response) {
response.getHeaders().setCacheControl("public, max-age=" + ONE_YEAR);
response.getHeaders().setExpires(new Date().getTime() + ONE_YEAR * 1000);
response.getHeaders().add(HttpHeaders.VARY, HttpHeaders.ORIGIN);
}
protected void addNoCacheHeaders(ServerHttpResponse response) {
@ -536,15 +512,15 @@ public abstract class AbstractSockJsService implements SockJsService {
public void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException {
if (HttpMethod.GET.equals(request.getMethod())) {
addNoCacheHeaders(response);
if (checkAndAddCorsHeaders(request, response)) {
if (checkOrigin(request, response)) {
response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET));
String content = String.format(INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled());
response.getBody().write(content.getBytes());
}
}
else if (HttpMethod.OPTIONS.equals(request.getMethod())) {
if (checkAndAddCorsHeaders(request, response, HttpMethod.OPTIONS,
HttpMethod.GET)) {
if (checkOrigin(request, response)) {
addCacheHeaders(response);
response.setStatusCode(HttpStatus.NO_CONTENT);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2014 the original author or authors.
* Copyright 2002-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,6 +27,8 @@ import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.HttpRequestHandler;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfigurationSource;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator;
@ -39,9 +41,10 @@ import org.springframework.web.socket.sockjs.SockJsService;
* in a Servlet container.
*
* @author Rossen Stoyanchev
* @author Sebastien Deleuze
* @since 4.0
*/
public class SockJsHttpRequestHandler implements HttpRequestHandler {
public class SockJsHttpRequestHandler implements HttpRequestHandler, CorsConfigurationSource {
// No logging: HTTP transports too verbose and we don't know enough to log anything of value
@ -100,4 +103,12 @@ public class SockJsHttpRequestHandler implements HttpRequestHandler {
return ((path.length() > 0) && (path.charAt(0) != '/')) ? "/" + path : path;
}
@Override
public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
if (sockJsService instanceof CorsConfigurationSource) {
return ((CorsConfigurationSource)sockJsService).getCorsConfiguration(request);
}
return null;
}
}

View File

@ -56,6 +56,7 @@ import org.springframework.web.socket.sockjs.support.AbstractSockJsService;
*
* @author Rossen Stoyanchev
* @author Juergen Hoeller
* @author Sebastien Deleuze
* @since 4.0
*/
public class TransportHandlingSockJsService extends AbstractSockJsService implements SockJsServiceConfig {
@ -208,27 +209,27 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem
return;
}
HttpMethod supportedMethod = transportType.getHttpMethod();
if (!supportedMethod.equals(request.getMethod())) {
if (HttpMethod.OPTIONS.equals(request.getMethod()) && transportType.supportsCors()) {
if (checkAndAddCorsHeaders(request, response, HttpMethod.OPTIONS, supportedMethod)) {
response.setStatusCode(HttpStatus.NO_CONTENT);
addCacheHeaders(response);
}
}
else if (transportType.supportsCors()) {
sendMethodNotAllowed(response, supportedMethod, HttpMethod.OPTIONS);
}
else {
sendMethodNotAllowed(response, supportedMethod);
}
return;
}
HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler);
SockJsException failure = null;
HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler);
try {
HttpMethod supportedMethod = transportType.getHttpMethod();
if (!supportedMethod.equals(request.getMethod())) {
if (HttpMethod.OPTIONS.equals(request.getMethod()) && transportType.supportsCors()) {
if (checkOrigin(request, response, HttpMethod.OPTIONS, supportedMethod)) {
response.setStatusCode(HttpStatus.NO_CONTENT);
addCacheHeaders(response);
}
}
else if (transportType.supportsCors()) {
sendMethodNotAllowed(response, supportedMethod, HttpMethod.OPTIONS);
}
else {
sendMethodNotAllowed(response, supportedMethod);
}
return;
}
SockJsSession session = this.sessions.get(sessionId);
if (session == null) {
if (transportHandler instanceof SockJsSessionFactory) {
@ -264,7 +265,7 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem
}
if (transportType.supportsCors()) {
if (!checkAndAddCorsHeaders(request, response)) {
if (!checkOrigin(request, response)) {
return;
}
}

View File

@ -18,7 +18,6 @@ package org.springframework.web.socket;
import org.junit.Before;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpAsyncRequestControl;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
@ -55,10 +54,6 @@ public abstract class AbstractHttpRequestTests {
this.servletRequest.setRequestURI(requestUri);
}
protected void setOrigin(String origin) {
this.request.getHeaders().add(HttpHeaders.ORIGIN, origin);
}
protected void resetRequestAndResponse() {
resetRequest();
resetResponse();

View File

@ -26,6 +26,7 @@ import static org.junit.Assert.*;
import org.junit.Test;
import org.mockito.Mockito;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
@ -61,7 +62,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
public void originValueMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
@ -71,7 +72,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
public void originValueNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
@ -81,7 +82,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
public void originListMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain2.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
@ -91,7 +92,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
public void originListNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain4.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain4.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
@ -101,7 +102,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
public void originNoMatchWithNullHostileCollection() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain4.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain4.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
Set<String> allowedOrigins = new ConcurrentSkipListSet<String>();
allowedOrigins.add("http://mydomain1.com");
@ -114,7 +115,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
public void originMatchAll() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("*"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
@ -125,7 +126,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
public void sameOriginMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain2.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
@ -136,7 +137,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
public void sameOriginNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain3.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain3.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));

View File

@ -26,12 +26,14 @@ import static org.junit.Assert.assertEquals;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.SockJsException;
@ -84,10 +86,10 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType());
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.servletResponse.getHeader("Cache-Control"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.servletResponse.getHeader(HttpHeaders.CACHE_CONTROL));
assertNull(this.servletResponse.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertNull(this.servletResponse.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertNull(this.servletResponse.getHeader(HttpHeaders.VARY));
String body = this.servletResponse.getContentAsString();
assertEquals("{\"entropy\"", body.substring(0, body.indexOf(':')));
@ -104,56 +106,42 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
assertNull(this.servletResponse.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertNull(this.servletResponse.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertNull(this.servletResponse.getHeader(HttpHeaders.VARY));
}
@Test // SPR-12226 and SPR-12660
public void handleInfoGetWithOrigin() throws Exception {
this.servletRequest.setServerName("mydomain2.com");
setOrigin("http://mydomain2.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType());
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.servletResponse.getHeader("Cache-Control"));
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.servletResponse.getHeader(HttpHeaders.CACHE_CONTROL));
String body = this.servletResponse.getContentAsString();
assertEquals("{\"entropy\"", body.substring(0, body.indexOf(':')));
assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":true}",
body.substring(body.indexOf(',')));
assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":true}", body.substring(body.indexOf(',')));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("*"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
}
@Test // SPR-11443
public void handleInfoGetCorsFilter() throws Exception {
// Simulate scenario where Filter would have already set CORS headers
this.servletResponse.setHeader("Access-Control-Allow-Origin", "foobar:123");
this.servletResponse.setHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN, "foobar:123");
handleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("foobar:123", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("foobar:123", this.servletResponse.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
}
@Test // SPR-11919
@ -161,7 +149,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
public void handleInfoGetWildflyNPE() throws Exception {
HttpServletResponse mockResponse = mock(HttpServletResponse.class);
ServletOutputStream ous = mock(ServletOutputStream.class);
given(mockResponse.getHeaders("Access-Control-Allow-Origin")).willThrow(NullPointerException.class);
given(mockResponse.getHeaders(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN)).willThrow(NullPointerException.class);
given(mockResponse.getOutputStream()).willReturn(ous);
this.response = new ServletServerHttpResponse(mockResponse);
@ -172,107 +160,53 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
@Test // SPR-12660
public void handleInfoOptions() throws Exception {
this.servletRequest.addHeader("Access-Control-Request-Headers", "Last-Modified");
this.servletRequest.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "Last-Modified");
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
assertNull(this.service.getCorsConfiguration(this.servletRequest));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
assertNull(this.service.getCorsConfiguration(this.servletRequest));
}
@Test // SPR-12226 and SPR-12660
public void handleInfoOptionsWithOrigin() throws Exception {
this.servletRequest.setServerName("mydomain2.com");
setOrigin("http://mydomain2.com");
this.request.getHeaders().add("Access-Control-Request-Headers", "Last-Modified");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
this.servletRequest.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_METHOD, "GET");
this.servletRequest.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "Last-Modified");
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
assertNotNull(this.service.getCorsConfiguration(this.servletRequest));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN);
this.response.flush();
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertNull(this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
assertNotNull(this.service.getCorsConfiguration(this.servletRequest));
this.service.setAllowedOrigins(Arrays.asList("*"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
assertNotNull(this.service.getCorsConfiguration(this.servletRequest));
}
@Test // SPR-12283
public void handleInfoOptionsWithOriginAndCorsHeadersDisabled() throws Exception {
setOrigin("http://mydomain2.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
this.service.setAllowedOrigins(Arrays.asList("*"));
this.service.setSuppressCors(true);
this.servletRequest.addHeader("Access-Control-Request-Headers", "Last-Modified");
this.servletRequest.addHeader(CorsUtils.ACCESS_CONTROL_REQUEST_HEADERS, "Last-Modified");
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
assertNull(this.service.getCorsConfiguration(this.servletRequest));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN);
this.response.flush();
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertNull(this.servletResponse.getHeader("Vary"));
assertNull(this.service.getCorsConfiguration(this.servletRequest));
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
assertNull(this.service.getCorsConfiguration(this.servletRequest));
}
@Test

View File

@ -26,7 +26,9 @@ import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.http.HttpHeaders;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.web.cors.CorsUtils;
import org.springframework.web.socket.AbstractHttpRequestTests;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.handler.TestPrincipal;
@ -163,8 +165,8 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
verify(taskScheduler).scheduleAtFixedRate(any(Runnable.class), eq(service.getDisconnectDelay()));
assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.response.getHeaders().getCacheControl());
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_ORIGIN));
assertNull(this.servletResponse.getHeader(CorsUtils.ACCESS_CONTROL_ALLOW_CREDENTIALS));
}
@Test // SPR-12226
@ -172,12 +174,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
String sockJsPath = sessionUrlPrefix + "xhr";
setRequest("POST", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"));
setOrigin("http://mydomain1.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(200, this.servletResponse.getStatus());
assertEquals("http://mydomain1.com", this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertEquals("true", this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
}
@Test // SPR-12226
@ -185,12 +185,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
String sockJsPath = sessionUrlPrefix + "xhr";
setRequest("POST", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"));
setOrigin("http://mydomain3.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain3.com");
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(403, this.servletResponse.getStatus());
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
}
@Test
@ -200,9 +198,9 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(204, this.servletResponse.getStatus());
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials"));
assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Methods"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods"));
}
@Test
@ -294,13 +292,13 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
wsService.setHandshakeInterceptors(Arrays.asList(interceptor));
setRequest("GET", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain1.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain2.com");
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(403, this.servletResponse.getStatus());
}