Change SockJS and Websocket default allowedOrigins to same origin

This commit adds support for a same origin check that compares
Origin header to Host header. It also changes the default setting
from all origins allowed to only same origin allowed.

Issues: SPR-12697, SPR-12685
This commit is contained in:
Sebastien Deleuze 2015-02-13 16:56:09 +01:00
parent 42af33034d
commit 6062e15572
20 changed files with 363 additions and 123 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2014 the original author or authors. * Copyright 2002-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@ package org.springframework.web.util;
import java.io.File; import java.io.File;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.StringTokenizer; import java.util.StringTokenizer;
import java.util.TreeMap; import java.util.TreeMap;
@ -32,6 +33,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
@ -43,6 +45,7 @@ import org.springframework.util.StringUtils;
* *
* @author Rod Johnson * @author Rod Johnson
* @author Juergen Hoeller * @author Juergen Hoeller
* @author Sebastien Deleuze
*/ */
public abstract class WebUtils { public abstract class WebUtils {
@ -765,4 +768,47 @@ public abstract class WebUtils {
} }
return result; return result;
} }
/**
* Check the given request origin against a list of allowed origins.
* A list containing "*" means that all origins are allowed.
* An empty list means only same origin is allowed.
*
* @return true if the request origin is valid, false otherwise
* @since 4.1.5
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
*/
public static boolean isValidOrigin(ServerHttpRequest request, List<String> allowedOrigins) {
Assert.notNull(request, "Request must not be null");
Assert.notNull(allowedOrigins, "Allowed origins must not be null");
String origin = request.getHeaders().getOrigin();
if (origin == null || allowedOrigins.contains("*")) {
return true;
}
else if (allowedOrigins.isEmpty()) {
UriComponents originComponents = UriComponentsBuilder.fromHttpUrl(origin).build();
UriComponents requestComponents = UriComponentsBuilder.fromHttpRequest(request).build();
int originPort = getPort(originComponents);
int requestPort = getPort(requestComponents);
return originComponents.getHost().equals(requestComponents.getHost()) && (originPort == requestPort);
}
else {
return allowedOrigins.contains(origin);
}
}
private static int getPort(UriComponents component) {
int port = component.getPort();
if (port == -1) {
if ("http".equals(component.getScheme())) {
port = 80;
}
else if ("https".equals(component.getScheme())) {
port = 443;
}
}
return port;
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2008 the original author or authors. * Copyright 2002-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,12 +16,18 @@
package org.springframework.web.util; package org.springframework.web.util;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import org.junit.Test; import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -30,6 +36,7 @@ import static org.junit.Assert.*;
* @author Juergen Hoeller * @author Juergen Hoeller
* @author Arjen Poutsma * @author Arjen Poutsma
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @author Sebastien Deleuze
*/ */
public class WebUtilsTests { public class WebUtilsTests {
@ -98,4 +105,57 @@ public class WebUtilsTests {
assertEquals(Arrays.asList("red", "blue", "green"), variables.get("colors")); assertEquals(Arrays.asList("red", "blue", "green"), variables.get("colors"));
} }
@Test
public void isValidOrigin() {
List<String> allowedOrigins = new ArrayList<>();
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:80");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(443);
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(443);
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com:443");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(123);
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:123");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("*");
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("http://mydomain1.com");
servletRequest.setServerName("mydomain2.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("http://mydomain1.com");
servletRequest.setServerName("mydomain2.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain3.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
}
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2014 the original author or authors. * Copyright 2002-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -83,11 +83,7 @@ class HandlersBeanDefinitionParser implements BeanDefinitionParser {
ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
String allowedOriginsAttribute = element.getAttribute("allowed-origins"); String allowedOriginsAttribute = element.getAttribute("allowed-origins");
List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ",")); List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ","));
if (!allowedOrigins.isEmpty()) { interceptors.add(new OriginHandshakeInterceptor(allowedOrigins));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
interceptors.add(interceptor);
}
strategy = new WebSocketHandlerMappingStrategy(handshakeHandler, interceptors); strategy = new WebSocketHandlerMappingStrategy(handshakeHandler, interceptors);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2014 the original author or authors. * Copyright 2002-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -288,11 +288,7 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser {
ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
String allowedOriginsAttribute = element.getAttribute("allowed-origins"); String allowedOriginsAttribute = element.getAttribute("allowed-origins");
List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ",")); List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ","));
if (!allowedOrigins.isEmpty()) { interceptors.add(new OriginHandshakeInterceptor(allowedOrigins));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
interceptors.add(interceptor);
}
ConstructorArgumentValues cavs = new ConstructorArgumentValues(); ConstructorArgumentValues cavs = new ConstructorArgumentValues();
cavs.addIndexedArgumentValue(0, subProtoHandler); cavs.addIndexedArgumentValue(0, subProtoHandler);
if (handshakeHandler != null) { if (handshakeHandler != null) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2014 the original author or authors. * Copyright 2002-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -105,12 +105,8 @@ class WebSocketNamespaceUtils {
ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); ManagedList<? super Object> interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context);
String allowedOriginsAttribute = element.getAttribute("allowed-origins"); String allowedOriginsAttribute = element.getAttribute("allowed-origins");
List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ",")); List<String> allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ","));
if (!allowedOrigins.isEmpty()) { sockJsServiceDef.getPropertyValues().add("allowedOrigins", allowedOrigins);
sockJsServiceDef.getPropertyValues().add("allowedOrigins", allowedOrigins); interceptors.add(new OriginHandshakeInterceptor(allowedOrigins));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(allowedOrigins);
interceptors.add(interceptor);
}
sockJsServiceDef.getPropertyValues().add("handshakeInterceptors", interceptors); sockJsServiceDef.getPropertyValues().add("handshakeInterceptors", interceptors);
String attrValue = sockJsElement.getAttribute("name"); String attrValue = sockJsElement.getAttribute("name");

View File

@ -88,11 +88,10 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
} }
@Override @Override
public WebSocketHandlerRegistration setAllowedOrigins(String... origins) { public WebSocketHandlerRegistration setAllowedOrigins(String... allowedOrigins) {
Assert.notEmpty(origins, "No allowed origin specified");
this.allowedOrigins.clear(); this.allowedOrigins.clear();
if (!ObjectUtils.isEmpty(origins)) { if (!ObjectUtils.isEmpty(allowedOrigins)) {
this.allowedOrigins.addAll(Arrays.asList(origins)); this.allowedOrigins.addAll(Arrays.asList(allowedOrigins));
} }
return this; return this;
} }
@ -117,11 +116,7 @@ public abstract class AbstractWebSocketHandlerRegistration<M> implements WebSock
protected HandshakeInterceptor[] getInterceptors() { protected HandshakeInterceptor[] getInterceptors() {
List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>(); List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
interceptors.addAll(this.interceptors); interceptors.addAll(this.interceptors);
if (!this.allowedOrigins.isEmpty()) { interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(this.allowedOrigins);
interceptors.add(interceptor);
}
return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]); return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]);
} }

View File

@ -206,6 +206,17 @@ public class SockJsServiceRegistration {
return this; return this;
} }
/**
* @since 4.1.2
*/
protected SockJsServiceRegistration setAllowedOrigins(String... allowedOrigins) {
this.allowedOrigins.clear();
if (!ObjectUtils.isEmpty(allowedOrigins)) {
this.allowedOrigins.addAll(Arrays.asList(allowedOrigins));
}
return this;
}
/** /**
* This option can be used to disable automatic addition of CORS headers for * This option can be used to disable automatic addition of CORS headers for
* SockJS requests. * SockJS requests.
@ -229,17 +240,6 @@ public class SockJsServiceRegistration {
return this; return this;
} }
/**
* @since 4.1.2
*/
protected SockJsServiceRegistration setAllowedOrigins(String... origins) {
this.allowedOrigins.clear();
if (!ObjectUtils.isEmpty(origins)) {
this.allowedOrigins.addAll(Arrays.asList(origins));
}
return this;
}
protected SockJsService getSockJsService() { protected SockJsService getSockJsService() {
TransportHandlingSockJsService service = createSockJsService(); TransportHandlingSockJsService service = createSockJsService();
service.setHandshakeInterceptors(this.interceptors); service.setHandshakeInterceptors(this.interceptors);
@ -264,12 +264,12 @@ public class SockJsServiceRegistration {
if (this.webSocketEnabled != null) { if (this.webSocketEnabled != null) {
service.setWebSocketEnabled(this.webSocketEnabled); service.setWebSocketEnabled(this.webSocketEnabled);
} }
if (this.allowedOrigins != null) {
service.setAllowedOrigins(this.allowedOrigins);
}
if (this.suppressCors != null) { if (this.suppressCors != null) {
service.setSuppressCors(this.suppressCors); service.setSuppressCors(this.suppressCors);
} }
if (!this.allowedOrigins.isEmpty()) {
service.setAllowedOrigins(this.allowedOrigins);
}
if (this.messageCodec != null) { if (this.messageCodec != null) {
service.setMessageCodec(this.messageCodec); service.setMessageCodec(this.messageCodec);
} }

View File

@ -52,8 +52,8 @@ public interface StompWebSocketEndpointRegistration {
* As a consequence, IE 6 to 9 are not supported when origins are restricted. * As a consequence, IE 6 to 9 are not supported when origins are restricted.
* *
* <p>Each provided allowed origin must start by "http://", "https://" or be "*" * <p>Each provided allowed origin must start by "http://", "https://" or be "*"
* (means that all origins are allowed). Empty allowed origin list is not supported. * (means that all origins are allowed). By default, only same origin requests are
* By default, all origins are allowed. * allowed (empty list).
* *
* @since 4.1.2 * @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a> * @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>

View File

@ -85,10 +85,11 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
} }
@Override @Override
public StompWebSocketEndpointRegistration setAllowedOrigins(String... origins) { public StompWebSocketEndpointRegistration setAllowedOrigins(String... allowedOrigins) {
Assert.notEmpty(origins, "No allowed origin specified");
this.allowedOrigins.clear(); this.allowedOrigins.clear();
this.allowedOrigins.addAll(Arrays.asList(origins)); if (!ObjectUtils.isEmpty(allowedOrigins)) {
this.allowedOrigins.addAll(Arrays.asList(allowedOrigins));
}
return this; return this;
} }
@ -112,11 +113,7 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
protected HandshakeInterceptor[] getInterceptors() { protected HandshakeInterceptor[] getInterceptors() {
List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>(); List<HandshakeInterceptor> interceptors = new ArrayList<HandshakeInterceptor>();
interceptors.addAll(this.interceptors); interceptors.addAll(this.interceptors);
if (!this.allowedOrigins.isEmpty()) { interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins));
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(this.allowedOrigins);
interceptors.add(interceptor);
}
return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]); return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]);
} }

View File

@ -54,8 +54,8 @@ public interface WebSocketHandlerRegistration {
* As a consequence, IE 6 to 9 are not supported when origins are restricted. * As a consequence, IE 6 to 9 are not supported when origins are restricted.
* *
* <p>Each provided allowed origin must start by "http://", "https://" or be "*" * <p>Each provided allowed origin must start by "http://", "https://" or be "*"
* (means that all origins are allowed). Empty allowed origin list is not supported. * (means that all origins are allowed). By default, only same origin requests are
* By default, all origins are allowed. * allowed (empty list).
* *
* @since 4.1.2 * @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a> * @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>

View File

@ -31,6 +31,7 @@ import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.WebUtils;
/** /**
* An interceptor to check request {@code Origin} header value against a collection of * An interceptor to check request {@code Origin} header value against a collection of
@ -47,12 +48,22 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
/** /**
* Default constructor with no origin allowed. * Default constructor with only same origin requests allowed.
*/ */
public OriginHandshakeInterceptor() { public OriginHandshakeInterceptor() {
this.allowedOrigins = new ArrayList<String>(); this.allowedOrigins = new ArrayList<String>();
} }
/**
* Constructor using the specified allowed origin values.
*
* @see #setAllowedOrigins(Collection)
*/
public OriginHandshakeInterceptor(Collection<String> allowedOrigins) {
this();
setAllowedOrigins(allowedOrigins);
}
/** /**
* Configure allowed {@code Origin} header values. This check is mostly designed for * Configure allowed {@code Origin} header values. This check is mostly designed for
* browser clients. There is nothing preventing other types of client to modify the * browser clients. There is nothing preventing other types of client to modify the
@ -85,7 +96,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
@Override @Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception { WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
if (!isValidOrigin(request)) { if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) {
response.setStatusCode(HttpStatus.FORBIDDEN); response.setStatusCode(HttpStatus.FORBIDDEN);
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Handshake request rejected, Origin header value " logger.debug("Handshake request rejected, Origin header value "
@ -96,17 +107,6 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
return true; return true;
} }
protected boolean isValidOrigin(ServerHttpRequest request) {
String origin = request.getHeaders().getOrigin();
if (origin == null) {
return true;
}
if (this.allowedOrigins.contains("*")) {
return true;
}
return this.allowedOrigins.contains(origin);
}
@Override @Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Exception exception) { WebSocketHandler wsHandler, Exception exception) {

View File

@ -46,12 +46,16 @@ import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsException;
import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.util.WebUtils;
/** /**
* An abstract base class for {@link SockJsService} implementations that provides SockJS * An abstract base class for {@link SockJsService} implementations that provides SockJS
* path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html", * path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html",
* etc). Sub-classes must handle session URLs (i.e. transport-specific requests). * etc). Sub-classes must handle session URLs (i.e. transport-specific requests).
* *
* By default, only same origin requests are allowed. Use {@link #setAllowedOrigins(List)}
* to specify a list of allowed origins (a list containing "*" will allow all origins).
*
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @author Sebastien Deleuze * @author Sebastien Deleuze
* @since 4.0 * @since 4.0
@ -64,6 +68,8 @@ public abstract class AbstractSockJsService implements SockJsService {
private static final Random random = new Random(); private static final Random random = new Random();
private static final String XFRAME_OPTIONS_HEADER = "X-Frame-Options";
protected final Log logger = LogFactory.getLog(getClass()); protected final Log logger = LogFactory.getLog(getClass());
@ -85,7 +91,7 @@ public abstract class AbstractSockJsService implements SockJsService {
private boolean webSocketEnabled = true; private boolean webSocketEnabled = true;
private final List<String> allowedOrigins = new ArrayList<String>(Arrays.asList("*")); private final List<String> allowedOrigins = new ArrayList<String>();
private boolean suppressCors = false; private boolean suppressCors = false;
@ -275,15 +281,14 @@ public abstract class AbstractSockJsService implements SockJsService {
* As a consequence, IE 6 to 9 are not supported when origins are restricted. * As a consequence, IE 6 to 9 are not supported when origins are restricted.
* *
* <p>Each provided allowed origin must start by "http://", "https://" or be "*" * <p>Each provided allowed origin must start by "http://", "https://" or be "*"
* (means that all origins are allowed). Empty allowed origin list is not supported. * (means that all origins are allowed).
* By default, all origins are allowed.
* *
* @since 4.1.2 * @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a> * @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a> * @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
*/ */
public void setAllowedOrigins(List<String> allowedOrigins) { public void setAllowedOrigins(List<String> allowedOrigins) {
Assert.notEmpty(allowedOrigins, "Allowed origin List must not be empty"); Assert.notNull(allowedOrigins, "Allowed origin List must not be null");
for (String allowedOrigin : allowedOrigins) { for (String allowedOrigin : allowedOrigins) {
Assert.isTrue( Assert.isTrue(
allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") || allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") ||
@ -360,6 +365,9 @@ public abstract class AbstractSockJsService implements SockJsService {
response.setStatusCode(HttpStatus.NOT_FOUND); response.setStatusCode(HttpStatus.NOT_FOUND);
return; return;
} }
if (this.allowedOrigins.isEmpty()) {
response.getHeaders().add(XFRAME_OPTIONS_HEADER, "SAMEORIGIN");
}
logger.debug(requestInfo); logger.debug(requestInfo);
this.iframeHandler.handle(request, response); this.iframeHandler.handle(request, response);
} }
@ -438,13 +446,12 @@ public abstract class AbstractSockJsService implements SockJsService {
HttpHeaders requestHeaders = request.getHeaders(); HttpHeaders requestHeaders = request.getHeaders();
HttpHeaders responseHeaders = response.getHeaders(); HttpHeaders responseHeaders = response.getHeaders();
String origin = requestHeaders.getOrigin(); String origin = requestHeaders.getOrigin();
String host = requestHeaders.getFirst(HttpHeaders.HOST);
if (origin == null) { if (origin == null) {
return true; return true;
} }
if (!this.allowedOrigins.contains("*") && !this.allowedOrigins.contains(origin)) { if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) {
logger.debug("Request rejected, Origin header value " + origin + " not allowed"); logger.debug("Request rejected, Origin header value " + origin + " not allowed");
response.setStatusCode(HttpStatus.FORBIDDEN); response.setStatusCode(HttpStatus.FORBIDDEN);
return false; return false;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2014 the original author or authors. * Copyright 2002-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -45,9 +45,9 @@ public enum TransportType {
XHR_STREAMING("xhr_streaming", HttpMethod.POST, "cors", "jsessionid", "no_cache"), XHR_STREAMING("xhr_streaming", HttpMethod.POST, "cors", "jsessionid", "no_cache"),
EVENT_SOURCE("eventsource", HttpMethod.GET, "jsessionid", "no_cache"), EVENT_SOURCE("eventsource", HttpMethod.GET, "origin", "jsessionid", "no_cache"),
HTML_FILE("htmlfile", HttpMethod.GET, "jsessionid", "no_cache"); HTML_FILE("htmlfile", HttpMethod.GET, "cors", "jsessionid", "no_cache");
private final String value; private final String value;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2014 the original author or authors. * Copyright 2002-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -106,7 +106,8 @@ public class HandlersBeanDefinitionParserTests {
HandshakeHandler handshakeHandler = handler.getHandshakeHandler(); HandshakeHandler handshakeHandler = handler.getHandshakeHandler();
assertNotNull(handshakeHandler); assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof DefaultHandshakeHandler); assertTrue(handshakeHandler instanceof DefaultHandshakeHandler);
assertTrue(handler.getHandshakeInterceptors().isEmpty()); assertFalse(handler.getHandshakeInterceptors().isEmpty());
assertTrue(handler.getHandshakeInterceptors().get(0) instanceof OriginHandshakeInterceptor);
} }
else { else {
assertThat(shm.getUrlMap().keySet(), contains("/test")); assertThat(shm.getUrlMap().keySet(), contains("/test"));
@ -116,7 +117,8 @@ public class HandlersBeanDefinitionParserTests {
HandshakeHandler handshakeHandler = handler.getHandshakeHandler(); HandshakeHandler handshakeHandler = handler.getHandshakeHandler();
assertNotNull(handshakeHandler); assertNotNull(handshakeHandler);
assertTrue(handshakeHandler instanceof DefaultHandshakeHandler); assertTrue(handshakeHandler instanceof DefaultHandshakeHandler);
assertTrue(handler.getHandshakeInterceptors().isEmpty()); assertFalse(handler.getHandshakeInterceptors().isEmpty());
assertTrue(handler.getHandshakeInterceptors().get(0) instanceof OriginHandshakeInterceptor);
} }
} }
} }
@ -196,7 +198,7 @@ public class HandlersBeanDefinitionParserTests {
assertEquals(TestHandshakeHandler.class, handler.getHandshakeHandler().getClass()); assertEquals(TestHandshakeHandler.class, handler.getHandshakeHandler().getClass());
List<HandshakeInterceptor> interceptors = defaultSockJsService.getHandshakeInterceptors(); List<HandshakeInterceptor> interceptors = defaultSockJsService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
} }
@Test @Test

View File

@ -71,7 +71,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next(); Map.Entry<HttpRequestHandler, List<String>> entry = mappings.entrySet().iterator().next();
assertNotNull(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler()); assertNotNull(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler());
assertTrue(((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().isEmpty()); assertEquals(1, ((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().size());
assertEquals(Arrays.asList("/foo"), entry.getValue()); assertEquals(Arrays.asList("/foo"), entry.getValue());
} }
@ -80,7 +80,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
WebMvcStompWebSocketEndpointRegistration registration = WebMvcStompWebSocketEndpointRegistration registration =
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.setAllowedOrigins("http://mydomain.com"); registration.setAllowedOrigins();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings(); MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size()); assertEquals(1, mappings.size());
@ -90,10 +90,18 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass()); assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
} }
@Test(expected = IllegalArgumentException.class) @Test
public void noAllowedOrigin() { public void sameOrigin() {
WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
registration.setAllowedOrigins(); registration.setAllowedOrigins();
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getWebSocketHandler());
assertEquals(1, requestHandler.getHandshakeInterceptors().size());
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass());
} }
@Test @Test
@ -158,7 +166,9 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey(); WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey();
assertNotNull(requestHandler.getWebSocketHandler()); assertNotNull(requestHandler.getWebSocketHandler());
assertSame(handshakeHandler, requestHandler.getHandshakeHandler()); assertSame(handshakeHandler, requestHandler.getHandshakeHandler());
assertEquals(Arrays.asList(interceptor), requestHandler.getHandshakeInterceptors()); assertEquals(2, requestHandler.getHandshakeInterceptors().size());
assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass());
} }
@Test @Test
@ -210,7 +220,9 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers(); Map<TransportType, TransportHandler> handlers = sockJsService.getTransportHandlers();
WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET); WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET);
assertSame(handshakeHandler, transportHandler.getHandshakeHandler()); assertSame(handshakeHandler, transportHandler.getHandshakeHandler());
assertEquals(Arrays.asList(interceptor), sockJsService.getHandshakeInterceptors()); assertEquals(2, sockJsService.getHandshakeInterceptors().size());
assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class, sockJsService.getHandshakeInterceptors().get(1).getClass());
} }
@Test @Test

View File

@ -69,12 +69,14 @@ public class WebSocketHandlerRegistrationTests {
Mapping m1 = mappings.get(0); Mapping m1 = mappings.get(0);
assertEquals(handler, m1.webSocketHandler); assertEquals(handler, m1.webSocketHandler);
assertEquals("/foo", m1.path); assertEquals("/foo", m1.path);
assertEquals(0, m1.interceptors.length); assertEquals(1, m1.interceptors.length);
assertEquals(OriginHandshakeInterceptor.class, m1.interceptors[0].getClass());
Mapping m2 = mappings.get(1); Mapping m2 = mappings.get(1);
assertEquals(handler, m2.webSocketHandler); assertEquals(handler, m2.webSocketHandler);
assertEquals("/bar", m2.path); assertEquals("/bar", m2.path);
assertEquals(0, m2.interceptors.length); assertEquals(1, m2.interceptors.length);
assertEquals(OriginHandshakeInterceptor.class, m2.interceptors[0].getClass());
} }
@Test @Test
@ -90,12 +92,27 @@ public class WebSocketHandlerRegistrationTests {
Mapping mapping = mappings.get(0); Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler); assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path); assertEquals("/foo", mapping.path);
assertArrayEquals(new HandshakeInterceptor[] {interceptor}, mapping.interceptors); assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
} }
@Test(expected = IllegalArgumentException.class) @Test
public void noAllowedOrigin() { public void emptyAllowedOrigin() {
this.registration.addHandler(Mockito.mock(WebSocketHandler.class), "/foo").setAllowedOrigins(); WebSocketHandler handler = new TextWebSocketHandler();
HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor();
this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins();
List<Mapping> mappings = this.registration.getMappings();
assertEquals(1, mappings.size());
Mapping mapping = mappings.get(0);
assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo", mapping.path);
assertEquals(2, mapping.interceptors.length);
assertEquals(interceptor, mapping.interceptors[0]);
assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass());
} }
@Test @Test

View File

@ -39,20 +39,22 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() { public void nullAllowedOriginList() {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); new OriginHandshakeInterceptor(null);
interceptor.setAllowedOrigins(null);
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void invalidAllowedOrigin() { public void invalidAllowedOrigin() {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); new OriginHandshakeInterceptor(Arrays.asList("domain.com"));
interceptor.setAllowedOrigins(Arrays.asList("domain.com")); }
@Test
public void emtpyAllowedOriginList() {
new OriginHandshakeInterceptor(Arrays.asList());
} }
@Test @Test
public void validAllowedOrigins() { public void validAllowedOrigins() {
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); new OriginHandshakeInterceptor(Arrays.asList("http://domain.com", "https://domain.com", "*"));
interceptor.setAllowedOrigins(Arrays.asList("http://domain.com", "https://domain.com", "*"));
} }
@Test @Test
@ -60,8 +62,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>(); Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com"); setOrigin("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@ -71,8 +72,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>(); Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain1.com"); setOrigin("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com"));
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain2.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@ -82,8 +82,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>(); Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain2.com"); setOrigin("http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@ -93,8 +92,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>(); Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain4.com"); setOrigin("http://mydomain4.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@ -123,4 +121,26 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@Test
public void sameOriginMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain2.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
@Test
public void sameOriginNoMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
setOrigin("http://mydomain3.com");
this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
}
} }

View File

@ -110,6 +110,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
@Test // SPR-12226 and SPR-12660 @Test // SPR-12226 and SPR-12660
public void handleInfoGetWithOrigin() throws Exception { public void handleInfoGetWithOrigin() throws Exception {
this.servletRequest.setServerName("mydomain2.com");
setOrigin("http://mydomain2.com"); setOrigin("http://mydomain2.com");
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
@ -135,6 +136,12 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin")); assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials")); assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary")); assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("*"));
resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK);
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
} }
@Test // SPR-11443 @Test // SPR-11443
@ -186,6 +193,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
@Test // SPR-12226 and SPR-12660 @Test // SPR-12226 and SPR-12660
public void handleInfoOptionsWithOrigin() throws Exception { public void handleInfoOptionsWithOrigin() throws Exception {
this.servletRequest.setServerName("mydomain2.com");
setOrigin("http://mydomain2.com"); setOrigin("http://mydomain2.com");
this.request.getHeaders().add("Access-Control-Request-Headers", "Last-Modified"); this.request.getHeaders().add("Access-Control-Request-Headers", "Last-Modified");
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
@ -216,6 +224,16 @@ public class SockJsServiceTests extends AbstractHttpRequestTests {
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods")); assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age")); assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary")); assertEquals("Origin", this.servletResponse.getHeader("Vary"));
this.service.setAllowedOrigins(Arrays.asList("*"));
resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT);
this.response.flush();
assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin"));
assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials"));
assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers"));
assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods"));
assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age"));
assertEquals("Origin", this.servletResponse.getHeader("Vary"));
} }
@Test // SPR-12283 @Test // SPR-12283

View File

@ -122,19 +122,15 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertSame(xhrHandler, handlers.get(xhrHandler.getTransportType())); assertSame(xhrHandler, handlers.get(xhrHandler.getTransportType()));
} }
@Test
public void defaultAllowedOrigin() {
assertThat(this.service.getAllowedOrigins(), Matchers.contains("*"));
}
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() { public void nullAllowedOriginList() {
this.service.setAllowedOrigins(null); this.service.setAllowedOrigins(null);
} }
@Test(expected = IllegalArgumentException.class) @Test
public void emptyAllowedOriginList() { public void emptyAllowedOriginList() {
this.service.setAllowedOrigins(Arrays.asList()); this.service.setAllowedOrigins(Arrays.asList());
assertThat(this.service.getAllowedOrigins(), Matchers.empty());
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
@ -271,13 +267,19 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
String sockJsPath = sessionUrlPrefix+ "jsonp"; String sockJsPath = sessionUrlPrefix+ "jsonp";
setRequest("GET", sockJsPrefix + sockJsPath); setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus()); assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse(); resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
setRequest("GET", sockJsPrefix + sockJsPath); setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus()); assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("*"));
setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
} }
@Test @Test
@ -289,8 +291,7 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertNotEquals(403, this.servletResponse.getStatus()); assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse(); resetRequestAndResponse();
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
wsService.setHandshakeInterceptors(Arrays.asList(interceptor)); wsService.setHandshakeInterceptors(Arrays.asList(interceptor));
setRequest("GET", sockJsPrefix + sockJsPath); setRequest("GET", sockJsPrefix + sockJsPath);
setOrigin("http://mydomain1.com"); setOrigin("http://mydomain1.com");
@ -310,13 +311,21 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
setRequest("GET", sockJsPrefix + sockJsPath); setRequest("GET", sockJsPrefix + sockJsPath);
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus()); assertNotEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options")); assertEquals("SAMEORIGIN", this.servletResponse.getHeader("X-Frame-Options"));
resetRequestAndResponse(); resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath); setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus()); assertEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("*"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
} }

View File

@ -39483,7 +39483,76 @@ or WebSocket XML namespace:
</beans> </beans>
---- ----
[[websocket-server-allowed-origins]]
==== Configuring allowed origins
As of Spring Framework 4.1.5, Websocket and SockJS default behavior is to accept only same
origin requests. It is also possible to allow all or a specified list of origins.
This check is mostly designed for browser clients. There is nothing preventing other types
of client to modify the `Origin` header value (see
https://tools.ietf.org/html/rfc6454[RFC 6454: The Web Origin Concept] for more details).
The 3 possible behaviors are:
* Allow only same origin requests (default): in this mode, when SockJS is enabled, the
Iframe HTTP response header `X-Frame-Options` is set to `SAMEORIGIN`, and JSONP
transport is disabled since it does not allow to check the origin of a request.
As a consequence, IE6 and IE7 are not supported when this mode is enabled.
* Allow a specified list of origins: each provided allowed origin must start by `http://`
or `https://`. In this mode, when SockJS is enabled, both IFrame and JSONP based
transports are disabled. As a consequence, IE6 up to IE9 are not supported when this
mode is enabled.
* Allow all origins: to enable this mode, you should provide `*` as allowed origin. In this
mode, all transports are available.
Websocket and SockJS allowed origins can be configured as shown bellow:
[source,java,indent=0]
[subs="verbatim,quotes"]
----
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(myHandler(), "/myHandler").setAllowedOrigins("http://mydomain.com");
}
@Bean
public WebSocketHandler myHandler() {
return new MyHandler();
}
}
----
XML configuration equivalent:
[source,xml,indent=0]
[subs="verbatim,quotes,attributes"]
----
<beans xmlns="http://www.springframework.org/schema/beans"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:websocket="http://www.springframework.org/schema/websocket"
xsi:schemaLocation="
http://www.springframework.org/schema/beans
http://www.springframework.org/schema/beans/spring-beans.xsd
http://www.springframework.org/schema/websocket
http://www.springframework.org/schema/websocket/spring-websocket.xsd">
<websocket:handlers allowed-origins="http://mydomain.com">
<websocket:mapping path="/myHandler" handler="myHandler" />
</websocket:handlers>
<bean id="myHandler" class="org.springframework.samples.MyHandler"/>
</beans>
----
[[websocket-fallback]] [[websocket-fallback]]
@ -39750,11 +39819,11 @@ log category to TRACE.
[[websocket-fallback-cors]] [[websocket-fallback-cors]]
==== CORS Headers for SockJS ==== CORS Headers for SockJS
The SockJS protocol uses CORS for cross-domain support in the XHR streaming and If you allow cross-origin requests (see <<websocket-server-allowed-origins>>), the SockJS protocol
polling transports. Therefore CORS headers are added automatically unless the uses CORS for cross-domain support in the XHR streaming and polling transports. Therefore
presence of CORS headers in the response is detected. So if an application is CORS headers are added automatically unless the presence of CORS headers in the response
already configured to provide CORS support, e.g. through a Servlet Filter, is detected. So if an application is already configured to provide CORS support, e.g.
Spring's SockJsService will skip this part. through a Servlet Filter, Spring's SockJsService will skip this part.
It is also possible to disable the addition of these CORS headers thanks to the It is also possible to disable the addition of these CORS headers thanks to the
`suppressCors` property in Spring's SockJsService. `suppressCors` property in Spring's SockJsService.