diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java index dd8dbd9e3c7..7a1bdd44c28 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/HandlersBeanDefinitionParser.java @@ -34,6 +34,7 @@ import org.springframework.beans.factory.xml.ParserContext; import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; @@ -79,7 +80,14 @@ class HandlersBeanDefinitionParser implements BeanDefinitionParser { else { RuntimeBeanReference handshakeHandler = WebSocketNamespaceUtils.registerHandshakeHandler(element, context, source); Element interceptorsElement = DomUtils.getChildElementByTagName(element, "handshake-interceptors"); - ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); + ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); + String allowedOriginsAttribute = element.getAttribute("allowed-origins"); + List allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ",")); + if(!allowedOrigins.isEmpty()) { + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(allowedOrigins); + interceptors.add(interceptor); + } strategy = new WebSocketHandlerMappingStrategy(handshakeHandler, interceptors); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java index 0c06e0be850..e28bd073a81 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java @@ -62,6 +62,7 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.handler.WebSocketHandlerDecoratorFactory; import org.springframework.web.socket.messaging.StompSubProtocolHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; @@ -282,7 +283,14 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { else { RuntimeBeanReference handshakeHandler = WebSocketNamespaceUtils.registerHandshakeHandler(element, context, source); Element interceptorsElement = DomUtils.getChildElementByTagName(element, "handshake-interceptors"); - ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); + ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); + String allowedOriginsAttribute = element.getAttribute("allowed-origins"); + List allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ",")); + if(!allowedOrigins.isEmpty()) { + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(allowedOrigins); + interceptors.add(interceptor); + } ConstructorArgumentValues cavs = new ConstructorArgumentValues(); cavs.addIndexedArgumentValue(0, subProtoHandler); if (handshakeHandler != null) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java index a4761a6969f..36891644856 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java @@ -16,6 +16,9 @@ package org.springframework.web.socket.config; +import java.util.Arrays; +import java.util.List; + import org.w3c.dom.Element; import org.springframework.beans.factory.config.BeanDefinition; @@ -25,8 +28,10 @@ import org.springframework.beans.factory.support.ManagedList; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.xml.ParserContext; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; +import org.springframework.util.StringUtils; import org.springframework.util.xml.DomUtils; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; @@ -97,7 +102,15 @@ class WebSocketNamespaceUtils { } Element interceptorsElement = DomUtils.getChildElementByTagName(element, "handshake-interceptors"); - ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); + ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); + String allowedOriginsAttribute = element.getAttribute("allowed-origins"); + List allowedOrigins = Arrays.asList(StringUtils.tokenizeToStringArray(allowedOriginsAttribute, ",")); + if(!allowedOrigins.isEmpty()) { + sockJsServiceDef.getPropertyValues().add("allowedOrigins", allowedOrigins); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(allowedOrigins); + interceptors.add(interceptor); + } sockJsServiceDef.getPropertyValues().add("handshakeInterceptors", interceptors); String attrValue = sockJsElement.getAttribute("name"); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java index 89b244ee162..1bdbd9c44d1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/AbstractWebSocketHandlerRegistration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,15 +16,19 @@ package org.springframework.web.socket.config.annotation; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import org.springframework.scheduling.TaskScheduler; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; @@ -34,6 +38,7 @@ import org.springframework.web.socket.sockjs.transport.handler.WebSocketTranspor * options but allows sub-classes to put together the actual HTTP request mappings. * * @author Rossen Stoyanchev + * @author Sebastien Deleuze * @since 4.0 */ public abstract class AbstractWebSocketHandlerRegistration implements WebSocketHandlerRegistration { @@ -44,7 +49,9 @@ public abstract class AbstractWebSocketHandlerRegistration implements WebSock private HandshakeHandler handshakeHandler; - private HandshakeInterceptor[] interceptors; + private final List interceptors = new ArrayList(); + + private final List allowedOrigins = new ArrayList(); private SockJsServiceRegistration sockJsServiceRegistration; @@ -74,27 +81,49 @@ public abstract class AbstractWebSocketHandlerRegistration implements WebSock @Override public WebSocketHandlerRegistration addInterceptors(HandshakeInterceptor... interceptors) { - this.interceptors = interceptors; + if (!ObjectUtils.isEmpty(interceptors)) { + this.interceptors.addAll(Arrays.asList(interceptors)); + } return this; } - protected HandshakeInterceptor[] getInterceptors() { - return this.interceptors; + @Override + public WebSocketHandlerRegistration setAllowedOrigins(String... origins) { + this.allowedOrigins.clear(); + if (!ObjectUtils.isEmpty(origins)) { + this.allowedOrigins.addAll(Arrays.asList(origins)); + } + return this; } @Override public SockJsServiceRegistration withSockJS() { this.sockJsServiceRegistration = new SockJsServiceRegistration(this.sockJsTaskScheduler); - if (this.interceptors != null) { - this.sockJsServiceRegistration.setInterceptors(this.interceptors); + HandshakeInterceptor[] interceptors = getInterceptors(); + if (interceptors.length > 0) { + this.sockJsServiceRegistration.setInterceptors(interceptors); } if (this.handshakeHandler != null) { WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler); this.sockJsServiceRegistration.setTransportHandlerOverrides(transportHandler); } + if (!this.allowedOrigins.isEmpty()) { + this.sockJsServiceRegistration.setAllowedOrigins(this.allowedOrigins.toArray(new String[this.allowedOrigins.size()])); + } return this.sockJsServiceRegistration; } + protected HandshakeInterceptor[] getInterceptors() { + List interceptors = new ArrayList(); + interceptors.addAll(this.interceptors); + if(!this.allowedOrigins.isEmpty()) { + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(this.allowedOrigins); + interceptors.add(interceptor); + } + return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]); + } + protected final M getMappings() { M mappings = createMappings(); if (this.sockJsServiceRegistration != null) { @@ -108,9 +137,10 @@ public abstract class AbstractWebSocketHandlerRegistration implements WebSock } else { HandshakeHandler handshakeHandler = getOrCreateHandshakeHandler(); + HandshakeInterceptor[] interceptors = getInterceptors(); for (WebSocketHandler wsHandler : this.handlerMap.keySet()) { for (String path : this.handlerMap.get(wsHandler)) { - addWebSocketHandlerMapping(mappings, wsHandler, handshakeHandler, this.interceptors, path); + addWebSocketHandlerMapping(mappings, wsHandler, handshakeHandler, interceptors, path); } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java index 17b5597afe8..0b7b028c686 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java @@ -62,6 +62,8 @@ public class SockJsServiceRegistration { private final List interceptors = new ArrayList(); + private final List allowedOrigins = new ArrayList(); + private SockJsMessageCodec messageCodec; @@ -195,6 +197,7 @@ public class SockJsServiceRegistration { } public SockJsServiceRegistration setInterceptors(HandshakeInterceptor... interceptors) { + this.interceptors.clear(); if (!ObjectUtils.isEmpty(interceptors)) { this.interceptors.addAll(Arrays.asList(interceptors)); } @@ -213,6 +216,17 @@ public class SockJsServiceRegistration { return this; } + /** + * @since 4.1.2 + */ + protected SockJsServiceRegistration setAllowedOrigins(String... origins) { + this.allowedOrigins.clear(); + if (!ObjectUtils.isEmpty(origins)) { + this.allowedOrigins.addAll(Arrays.asList(origins)); + } + return this; + } + protected SockJsService getSockJsService() { TransportHandlingSockJsService service = createSockJsService(); service.setHandshakeInterceptors(this.interceptors); @@ -237,6 +251,9 @@ public class SockJsServiceRegistration { if (this.webSocketEnabled != null) { service.setWebSocketEnabled(this.webSocketEnabled); } + if (!this.allowedOrigins.isEmpty()) { + service.setAllowedOrigins(this.allowedOrigins); + } if (this.messageCodec != null) { service.setMessageCodec(this.messageCodec); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java index 4260cca4171..c0bea5565e7 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,4 +42,19 @@ public interface StompWebSocketEndpointRegistration { */ StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors); + /** + * Configure allowed {@code Origin} header values. This check is mostly designed for browser + * clients. There is noting preventing other types of client to modify the Origin header value. + * + *

When SockJS is enabled and allowed origins are restricted, transport types that do not + * use {@code Origin} headers for cross origin requests (jsonp-polling, iframe-xhr-polling, + * iframe-eventsource and iframe-htmlfile) are disabled. As a consequence, IE6/IE7 won't be + * supported anymore and IE8/IE9 will only be supported without cookies. + * + *

By default, all origins are allowed. + * @since 4.1.2 + * @see SockJS supported transports by browser + */ + StompWebSocketEndpointRegistration setAllowedOrigins(String... origins); + } \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java index 6d1f42f078d..46975710ca6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,15 +22,19 @@ import org.springframework.scheduling.TaskScheduler; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.util.ObjectUtils; import org.springframework.web.HttpRequestHandler; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; +import java.util.ArrayList; +import java.util.List; /** * An abstract base class class for configuring STOMP over WebSocket/SockJS endpoints. * @@ -47,7 +51,9 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE private HandshakeHandler handshakeHandler; - private HandshakeInterceptor[] interceptors; + private final List interceptors = new ArrayList(); + + private final List allowedOrigins = new ArrayList(); private StompSockJsServiceRegistration registration; @@ -72,27 +78,49 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE @Override public StompWebSocketEndpointRegistration addInterceptors(HandshakeInterceptor... interceptors) { - this.interceptors = interceptors; + if (!ObjectUtils.isEmpty(interceptors)) { + this.interceptors.addAll(Arrays.asList(interceptors)); + } return this; } - protected HandshakeInterceptor[] getInterceptors() { - return this.interceptors; + @Override + public StompWebSocketEndpointRegistration setAllowedOrigins(String... origins) { + this.allowedOrigins.clear(); + if (!ObjectUtils.isEmpty(origins)) { + this.allowedOrigins.addAll(Arrays.asList(origins)); + } + return this; } @Override public SockJsServiceRegistration withSockJS() { this.registration = new StompSockJsServiceRegistration(this.sockJsTaskScheduler); - if (this.interceptors != null) { - this.registration.setInterceptors(this.interceptors); + HandshakeInterceptor[] interceptors = getInterceptors(); + if (interceptors.length > 0) { + this.registration.setInterceptors(interceptors); } if (this.handshakeHandler != null) { WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler); this.registration.setTransportHandlerOverrides(transportHandler); } + if (!this.allowedOrigins.isEmpty()) { + this.registration.setAllowedOrigins(this.allowedOrigins.toArray(new String[this.allowedOrigins.size()])); + } return this.registration; } + protected HandshakeInterceptor[] getInterceptors() { + List interceptors = new ArrayList(); + interceptors.addAll(this.interceptors); + if(!this.allowedOrigins.isEmpty()) { + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(this.allowedOrigins); + interceptors.add(interceptor); + } + return interceptors.toArray(new HandshakeInterceptor[interceptors.size()]); + } + public final MultiValueMap getMappings() { MultiValueMap mappings = new LinkedMultiValueMap(); if (this.registration != null) { @@ -112,8 +140,9 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE else { handler = new WebSocketHttpRequestHandler(this.webSocketHandler); } - if (this.interceptors != null) { - handler.setHandshakeInterceptors(Arrays.asList(this.interceptors)); + HandshakeInterceptor[] interceptors = getInterceptors(); + if (interceptors.length > 0) { + handler.setHandshakeInterceptors(Arrays.asList(interceptors)); } mappings.add(handler, path); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java index b960c0f2a0d..4deb3f7c104 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistration.java @@ -44,9 +44,27 @@ public interface WebSocketHandlerRegistration { */ WebSocketHandlerRegistration addInterceptors(HandshakeInterceptor... interceptors); + /** + * Configure allowed {@code Origin} header values. This check is mostly designed for browser + * clients. There is noting preventing other types of client to modify the Origin header value. + * + *

When SockJS is enabled and allowed origins are restricted, transport types that do not + * use {@code Origin} headers for cross origin requests (jsonp-polling, iframe-xhr-polling, + * iframe-eventsource and iframe-htmlfile) are disabled. As a consequence, IE6/IE7 won't be + * supported anymore and IE8/IE9 will only be supported without cookies. + * + *

By default, all origins are allowed. + * + * @since 4.1.2 + * @see SockJS supported transports by browser + */ + WebSocketHandlerRegistration setAllowedOrigins(String... origins); + /** * Enable SockJS fallback options. */ SockJsServiceRegistration withSockJS(); + + } \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java index 8a022dec9f0..442436887d0 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java @@ -260,6 +260,11 @@ public class DefaultHandshakeHandler implements HandshakeHandler { Arrays.asList(StringUtils.arrayToCommaDelimitedString(getSupportedVersions()))); } + /** + * Return whether the request {@code Origin} header value is valid or not. + * By default, all origins as considered as valid. Consider using an + * {@link OriginHandshakeInterceptor} for filtering origins if needed. + */ protected boolean isValidOrigin(ServerHttpRequest request) { return true; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java new file mode 100644 index 00000000000..505195ca333 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.HttpStatus; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; +import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; + +/** + * An interceptor to check request {@code Origin} header value against a collection of + * allowed origins. + * + * @author Sebastien Deleuze + * @since 4.1.2 + */ +public class OriginHandshakeInterceptor implements HandshakeInterceptor { + + protected Log logger = LogFactory.getLog(getClass()); + + private final List allowedOrigins; + + + /** + * Default constructor with no origin allowed. + */ + public OriginHandshakeInterceptor() { + this.allowedOrigins = new ArrayList(); + } + + /** + * Use this property to define a collection of allowed origins. + */ + public void setAllowedOrigins(Collection allowedOrigins) { + this.allowedOrigins.clear(); + if (allowedOrigins != null) { + this.allowedOrigins.addAll(allowedOrigins); + } + } + + @Override + public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Map attributes) throws Exception { + if(!isValidOrigin(request)) { + response.setStatusCode(HttpStatus.FORBIDDEN); + if (logger.isDebugEnabled()) { + logger.debug("Handshake request rejected, Origin header value " + + request.getHeaders().getOrigin() + " not allowed"); + } + return false; + } + return true; + } + + protected boolean isValidOrigin(ServerHttpRequest request) { + return this.allowedOrigins.contains(request.getHeaders().getOrigin()); + } + + @Override + public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler wsHandler, Exception exception) { + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index 97e0fa83244..6b006bdf1d4 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -18,7 +18,9 @@ package org.springframework.web.socket.sockjs.support; import java.io.IOException; import java.nio.charset.Charset; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Date; import java.util.HashSet; import java.util.List; @@ -44,6 +46,7 @@ import org.springframework.util.StringUtils; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsService; +import org.springframework.web.socket.sockjs.transport.TransportType; /** * An abstract base class for {@link SockJsService} implementations that provides SockJS @@ -51,6 +54,7 @@ import org.springframework.web.socket.sockjs.SockJsService; * etc). Sub-classes must handle session URLs (i.e. transport-specific requests). * * @author Rossen Stoyanchev + * @author Sebastien Deleuze * @since 4.0 */ public abstract class AbstractSockJsService implements SockJsService { @@ -82,6 +86,8 @@ public abstract class AbstractSockJsService implements SockJsService { private boolean webSocketEnabled = true; + private final List allowedOrigins = new ArrayList(Arrays.asList("*")); + public AbstractSockJsService(TaskScheduler scheduler) { Assert.notNull(scheduler, "TaskScheduler must not be null"); @@ -258,6 +264,34 @@ public abstract class AbstractSockJsService implements SockJsService { return this.webSocketEnabled; } + /** + * Configure allowed {@code Origin} header values. This check is mostly designed for browser + * clients. There is noting preventing other types of client to modify the Origin header value. + * + *

When SockJS is enabled and allowed origins are restricted, transport types that do not + * use {@code Origin} headers for cross origin requests (jsonp-polling, iframe-xhr-polling, + * iframe-eventsource and iframe-htmlfile) are disabled. As a consequence, IE6/IE7 won't be + * supported anymore and IE8/IE9 will only be supported without cookies. + * + *

By default, all origins are allowed. + * + * @since 4.1.2 + * @see SockJS supported transports by browser + */ + public void setAllowedOrigins(List allowedOrigins) { + this.allowedOrigins.clear(); + if (allowedOrigins != null) { + this.allowedOrigins.addAll(allowedOrigins); + } + } + + /** + * @since 4.1.2 + * @see #setAllowedOrigins(List) + */ + public List getAllowedOrigins() { + return Collections.unmodifiableList(allowedOrigins); + } /** * This method determines the SockJS path and handles SockJS static URLs. @@ -325,6 +359,12 @@ public abstract class AbstractSockJsService implements SockJsService { response.setStatusCode(HttpStatus.NOT_FOUND); return; } + else if(!this.allowedOrigins.contains("*") && !TransportType.fromValue(transport).supportsOrigin()) { + logger.debug("Origin check has been enabled, but this transport does not support it, ignoring " + + requestInfo); + response.setStatusCode(HttpStatus.NOT_FOUND); + return; + } handleTransportRequest(request, response, wsHandler, sessionId, transport); } response.close(); @@ -360,23 +400,43 @@ public abstract class AbstractSockJsService implements SockJsService { protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException; - - protected void addCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) { + /** + * Check the {@code Origin} header value and eventually call {@link #addCorsHeaders(ServerHttpRequest, ServerHttpResponse, HttpMethod...)}. + * If the request origin is not allowed, the request is rejected. + * @return false if the request is rejected, else true + * @since 4.1.2 + */ + protected boolean checkAndAddCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) { HttpHeaders requestHeaders = request.getHeaders(); HttpHeaders responseHeaders = response.getHeaders(); + String origin = requestHeaders.getOrigin(); + + if(!this.allowedOrigins.contains("*") && (origin == null || !this.allowedOrigins.contains(origin))) { + logger.debug("Request rejected, Origin header value " + origin + " not allowed"); + response.setStatusCode(HttpStatus.FORBIDDEN); + return false; + } + + boolean hasCorsResponseHeaders = false; try { // Perhaps a CORS Filter has already added this? - if (!CollectionUtils.isEmpty(responseHeaders.get("Access-Control-Allow-Origin"))) { - return; - } + hasCorsResponseHeaders = !CollectionUtils.isEmpty(responseHeaders.get("Access-Control-Allow-Origin")); } catch (NullPointerException npe) { // See SPR-11919 and https://issues.jboss.org/browse/WFLY-3474 } - String origin = requestHeaders.getFirst("origin"); - origin = (origin == null || origin.equals("null") ? "*" : origin); - responseHeaders.add("Access-Control-Allow-Origin", origin); + if(origin != null && !hasCorsResponseHeaders) { + addCorsHeaders(request, response, httpMethods); + } + return true; + } + + protected void addCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) { + HttpHeaders requestHeaders = request.getHeaders(); + HttpHeaders responseHeaders = response.getHeaders(); + + responseHeaders.add("Access-Control-Allow-Origin", requestHeaders.getFirst("Origin")); responseHeaders.add("Access-Control-Allow-Credentials", "true"); List accessControllerHeaders = requestHeaders.get("Access-Control-Request-Headers"); @@ -424,16 +484,19 @@ public abstract class AbstractSockJsService implements SockJsService { @Override public void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException { if (HttpMethod.GET.equals(request.getMethod())) { - response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET)); - addCorsHeaders(request, response); addNoCacheHeaders(response); - String content = String.format(INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled()); - response.getBody().write(content.getBytes()); + if(checkAndAddCorsHeaders(request, response)) { + response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET)); + String content = String.format(INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled()); + response.getBody().write(content.getBytes()); + } } else if (HttpMethod.OPTIONS.equals(request.getMethod())) { - response.setStatusCode(HttpStatus.NO_CONTENT); - addCorsHeaders(request, response, HttpMethod.OPTIONS, HttpMethod.GET); - addCacheHeaders(response); + if(checkAndAddCorsHeaders(request, response, HttpMethod.OPTIONS, + HttpMethod.GET)) { + addCacheHeaders(response); + response.setStatusCode(HttpStatus.NO_CONTENT); + } } else { sendMethodNotAllowed(response, HttpMethod.OPTIONS, HttpMethod.GET); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index 66d504e73d8..279cbc86a77 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -207,9 +207,10 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem HttpMethod supportedMethod = transportType.getHttpMethod(); if (!supportedMethod.equals(request.getMethod())) { if (HttpMethod.OPTIONS.equals(request.getMethod()) && transportType.supportsCors()) { - response.setStatusCode(HttpStatus.NO_CONTENT); - addCorsHeaders(request, response, HttpMethod.OPTIONS, supportedMethod); - addCacheHeaders(response); + if(checkAndAddCorsHeaders(request, response, HttpMethod.OPTIONS, supportedMethod)) { + response.setStatusCode(HttpStatus.NO_CONTENT); + addCacheHeaders(response); + } } else if (transportType.supportsCors()) { sendMethodNotAllowed(response, supportedMethod, HttpMethod.OPTIONS); @@ -250,7 +251,9 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem } if (transportType.supportsCors()) { - addCorsHeaders(request, response); + if(!checkAndAddCorsHeaders(request, response)) { + return; + } } transportHandler.handleRequest(request, response, handler, session); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java index adc30b3e9c6..04e7f30f6bb 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportType.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,11 +28,12 @@ import org.springframework.http.HttpMethod; * SockJS transport types. * * @author Rossen Stoyanchev + * @author Sebastien Deleuze * @since 4.0 */ public enum TransportType { - WEBSOCKET("websocket", HttpMethod.GET), + WEBSOCKET("websocket", HttpMethod.GET, "origin"), XHR("xhr", HttpMethod.POST, "cors", "jsessionid", "no_cache"), @@ -91,6 +92,10 @@ public enum TransportType { return this.headerHints.contains("cors"); } + public boolean supportsOrigin() { + return this.headerHints.contains("cors") || this.headerHints.contains("origin"); + } + public boolean sendsSessionCookie() { return this.headerHints.contains("jsessionid"); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java index 5d9382117c8..28fae126eeb 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/handler/HtmlFileTransportHandler.java @@ -37,7 +37,7 @@ import org.springframework.web.socket.sockjs.transport.session.StreamingSockJsSe import org.springframework.web.util.JavaScriptUtils; /** - * An HTTP {@link TransportHandler} that uses a famous browsder document.domain technique: + * An HTTP {@link TransportHandler} that uses a famous browser document.domain technique: * * http://stackoverflow.com/questions/1481251/what-does-document-domain-document-domain-do * diff --git a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd index 25cc3a15b1b..e58631e5996 100644 --- a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd +++ b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd @@ -474,6 +474,24 @@ ]]> + + + + + @@ -641,6 +659,24 @@ ]]> + + + + + diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/AbstractHttpRequestTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/AbstractHttpRequestTests.java index 94b6b36d922..82aa0dbf902 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/AbstractHttpRequestTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/AbstractHttpRequestTests.java @@ -54,6 +54,10 @@ public abstract class AbstractHttpRequestTests { this.servletRequest.setRequestURI(requestUri); } + protected void setOrigin(String origin) { + this.servletRequest.addHeader("Origin", origin); + } + protected void resetRequestAndResponse() { resetRequest(); resetResponse(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java index 966850fd5ec..3f79367f83f 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java @@ -18,11 +18,13 @@ package org.springframework.web.socket.config; import java.io.IOException; import java.io.InputStream; +import java.util.Arrays; import java.util.Date; import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledFuture; +import static org.junit.Assert.assertEquals; import org.junit.Before; import org.junit.Test; @@ -45,6 +47,7 @@ import org.springframework.web.socket.handler.WebSocketHandlerDecorator; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.frame.SockJsMessageCodec; @@ -103,6 +106,7 @@ public class HandlersBeanDefinitionParserTests { HandshakeHandler handshakeHandler = handler.getHandshakeHandler(); assertNotNull(handshakeHandler); assertTrue(handshakeHandler instanceof DefaultHandshakeHandler); + assertTrue(handler.getHandshakeInterceptors().isEmpty()); } else { assertThat(shm.getUrlMap().keySet(), contains("/test")); @@ -112,6 +116,7 @@ public class HandlersBeanDefinitionParserTests { HandshakeHandler handshakeHandler = handler.getHandshakeHandler(); assertNotNull(handshakeHandler); assertTrue(handshakeHandler instanceof DefaultHandshakeHandler); + assertTrue(handler.getHandshakeInterceptors().isEmpty()); } } } @@ -135,7 +140,8 @@ public class HandlersBeanDefinitionParserTests { assertNotNull(handshakeHandler); assertTrue(handshakeHandler instanceof TestHandshakeHandler); List interceptors = handler.getHandshakeInterceptors(); - assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); + assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), + instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class))); handler = (WebSocketHttpRequestHandler) urlHandlerMapping.getUrlMap().get("/test"); assertNotNull(handler); @@ -144,8 +150,8 @@ public class HandlersBeanDefinitionParserTests { assertNotNull(handshakeHandler); assertTrue(handshakeHandler instanceof TestHandshakeHandler); interceptors = handler.getHandshakeInterceptors(); - assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); - + assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), + instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class))); } @Test @@ -222,6 +228,10 @@ public class HandlersBeanDefinitionParserTests { assertEquals(1024, transportService.getHttpMessageCacheSize()); assertEquals(20, transportService.getHeartbeatTime()); assertEquals(TestMessageCodec.class, transportService.getMessageCodec().getClass()); + + List interceptors = transportService.getHandshakeInterceptors(); + assertThat(interceptors, contains(instanceOf(OriginHandshakeInterceptor.class))); + assertEquals(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"), transportService.getAllowedOrigins()); } private void loadBeanDefinitions(String fileName) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index d8fa79e8d08..ab9fcb6a763 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -68,6 +68,7 @@ import org.springframework.web.socket.messaging.StompSubProtocolHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; import org.springframework.web.socket.sockjs.transport.TransportType; @@ -115,7 +116,8 @@ public class MessageBrokerBeanDefinitionParserTests { assertNotNull(handshakeHandler); assertTrue(handshakeHandler instanceof TestHandshakeHandler); List interceptors = wsHttpRequestHandler.getHandshakeInterceptors(); - assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); + assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), + instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class))); WebSocketSession session = new TestWebSocketSession("id"); wsHttpRequestHandler.getWebSocketHandler().afterConnectionEstablished(session); @@ -158,7 +160,9 @@ public class MessageBrokerBeanDefinitionParserTests { assertTrue(scheduler.getScheduledThreadPoolExecutor().getRemoveOnCancelPolicy()); interceptors = defaultSockJsService.getHandshakeInterceptors(); - assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); + assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), + instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class))); + assertEquals(Arrays.asList("http://mydomain3.com", "http://mydomain4.com"), defaultSockJsService.getAllowedOrigins()); UserSessionRegistry userSessionRegistry = this.appContext.getBean(UserSessionRegistry.class); assertNotNull(userSessionRegistry); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java index 90905e47721..048050150a9 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java @@ -29,6 +29,7 @@ import org.springframework.scheduling.TaskScheduler; import org.springframework.util.MultiValueMap; import org.springframework.web.HttpRequestHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; @@ -70,19 +71,60 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { Map.Entry> entry = mappings.entrySet().iterator().next(); assertNotNull(((WebSocketHttpRequestHandler) entry.getKey()).getWebSocketHandler()); + assertTrue(((WebSocketHttpRequestHandler) entry.getKey()).getHandshakeInterceptors().isEmpty()); assertEquals(Arrays.asList("/foo"), entry.getValue()); } @Test - public void handshakeHandlerAndInterceptors() { + public void allowedOrigins() { + WebMvcStompWebSocketEndpointRegistration registration = + new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); + + registration.setAllowedOrigins("http://mydomain.com"); + + MultiValueMap mappings = registration.getMappings(); + assertEquals(1, mappings.size()); + WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); + assertNotNull(requestHandler.getWebSocketHandler()); + assertEquals(1, requestHandler.getHandshakeInterceptors().size()); + assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(0).getClass()); + } + + @Test + public void allowedOriginsWithSockJsService() { + WebMvcStompWebSocketEndpointRegistration registration = + new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); + + String origin = "http://mydomain.com"; + registration.setAllowedOrigins(origin).withSockJS(); + + MultiValueMap mappings = registration.getMappings(); + assertEquals(1, mappings.size()); + SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); + assertNotNull(requestHandler.getSockJsService()); + DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); + assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); + + registration = + new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); + registration.withSockJS().setAllowedOrigins(origin); + mappings = registration.getMappings(); + assertEquals(1, mappings.size()); + requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); + assertNotNull(requestHandler.getSockJsService()); + sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); + assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); + } + + @Test + public void handshakeHandlerAndInterceptor() { WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler(); HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); - registration.setHandshakeHandler(handshakeHandler); - registration.addInterceptors(interceptor); + registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor); MultiValueMap mappings = registration.getMappings(); assertEquals(1, mappings.size()); @@ -97,16 +139,38 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { } @Test - public void handshakeHandlerAndInterceptorsWithSockJsService() { + public void handshakeHandlerAndInterceptorWithAllowedOrigins() { + WebMvcStompWebSocketEndpointRegistration registration = + new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); + + DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler(); + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + String origin = "http://mydomain.com"; + registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin); + + MultiValueMap mappings = registration.getMappings(); + assertEquals(1, mappings.size()); + + Map.Entry> entry = mappings.entrySet().iterator().next(); + assertEquals(Arrays.asList("/foo"), entry.getValue()); + + WebSocketHttpRequestHandler requestHandler = (WebSocketHttpRequestHandler) entry.getKey(); + assertNotNull(requestHandler.getWebSocketHandler()); + assertSame(handshakeHandler, requestHandler.getHandshakeHandler()); + assertEquals(2, requestHandler.getHandshakeInterceptors().size()); + assertEquals(interceptor, requestHandler.getHandshakeInterceptors().get(0)); + assertEquals(OriginHandshakeInterceptor.class, requestHandler.getHandshakeInterceptors().get(1).getClass()); + } + + @Test + public void handshakeHandlerInterceptorWithSockJsService() { WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler(); HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); - registration.setHandshakeHandler(handshakeHandler); - registration.addInterceptors(interceptor); - registration.withSockJS(); + registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).withSockJS(); MultiValueMap mappings = registration.getMappings(); assertEquals(1, mappings.size()); @@ -126,4 +190,37 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { assertEquals(Arrays.asList(interceptor), sockJsService.getHandshakeInterceptors()); } + @Test + public void handshakeHandlerInterceptorWithSockJsServiceAndAllowedOrigins() { + WebMvcStompWebSocketEndpointRegistration registration = + new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler); + + DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler(); + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + String origin = "http://mydomain.com"; + + registration.setHandshakeHandler(handshakeHandler).addInterceptors(interceptor).setAllowedOrigins(origin).withSockJS(); + + MultiValueMap mappings = registration.getMappings(); + assertEquals(1, mappings.size()); + + Map.Entry> entry = mappings.entrySet().iterator().next(); + assertEquals(Arrays.asList("/foo/**"), entry.getValue()); + + SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler) entry.getKey(); + assertNotNull(requestHandler.getWebSocketHandler()); + + DefaultSockJsService sockJsService = (DefaultSockJsService) requestHandler.getSockJsService(); + assertNotNull(sockJsService); + + Map handlers = sockJsService.getTransportHandlers(); + WebSocketTransportHandler transportHandler = (WebSocketTransportHandler) handlers.get(TransportType.WEBSOCKET); + assertSame(handshakeHandler, transportHandler.getHandshakeHandler()); + assertEquals(2, sockJsService.getHandshakeInterceptors().size()); + assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0)); + assertEquals(OriginHandshakeInterceptor.class, + sockJsService.getHandshakeInterceptors().get(1).getClass()); + assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java index b05382582e7..9b2d175c15d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java @@ -29,6 +29,7 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.handler.TextWebSocketHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; import org.springframework.web.socket.sockjs.SockJsService; @@ -68,10 +69,12 @@ public class WebSocketHandlerRegistrationTests { Mapping m1 = mappings.get(0); assertEquals(handler, m1.webSocketHandler); assertEquals("/foo", m1.path); + assertEquals(0, m1.interceptors.length); Mapping m2 = mappings.get(1); assertEquals(handler, m2.webSocketHandler); assertEquals("/bar", m2.path); + assertEquals(0, m2.interceptors.length); } @Test @@ -90,12 +93,31 @@ public class WebSocketHandlerRegistrationTests { assertArrayEquals(new HandshakeInterceptor[] {interceptor}, mapping.interceptors); } + @Test + public void interceptorsWithAllowedOrigins() { + WebSocketHandler handler = new TextWebSocketHandler(); + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + + this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).setAllowedOrigins("http://mydomain1.com"); + + List mappings = this.registration.getMappings(); + assertEquals(1, mappings.size()); + + Mapping mapping = mappings.get(0); + assertEquals(handler, mapping.webSocketHandler); + assertEquals("/foo", mapping.path); + assertEquals(2, mapping.interceptors.length); + assertEquals(interceptor, mapping.interceptors[0]); + assertEquals(OriginHandshakeInterceptor.class, mapping.interceptors[1].getClass()); + } + @Test public void interceptorsPassedToSockJsRegistration() { WebSocketHandler handler = new TextWebSocketHandler(); HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); - this.registration.addHandler(handler, "/foo").addInterceptors(interceptor).withSockJS(); + this.registration.addHandler(handler, "/foo").addInterceptors(interceptor) + .setAllowedOrigins("http://mydomain1.com").withSockJS(); List mappings = this.registration.getMappings(); assertEquals(1, mappings.size()); @@ -104,7 +126,11 @@ public class WebSocketHandlerRegistrationTests { assertEquals(handler, mapping.webSocketHandler); assertEquals("/foo/**", mapping.path); assertNotNull(mapping.sockJsService); - assertEquals(Arrays.asList(interceptor), mapping.sockJsService.getHandshakeInterceptors()); + assertEquals(Arrays.asList("http://mydomain1.com"), + mapping.sockJsService.getAllowedOrigins()); + List interceptors = mapping.sockJsService.getHandshakeInterceptors(); + assertEquals(interceptor, interceptors.get(0)); + assertEquals(OriginHandshakeInterceptor.class, interceptors.get(1).getClass()); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/AllowedOriginsInterceptorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/AllowedOriginsInterceptorTests.java new file mode 100644 index 00000000000..da44c13668a --- /dev/null +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/AllowedOriginsInterceptorTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.server.support; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentSkipListSet; + +import static org.junit.Assert.*; +import org.junit.Test; +import org.mockito.Mockito; + +import org.springframework.http.HttpStatus; +import org.springframework.web.socket.AbstractHttpRequestTests; +import org.springframework.web.socket.WebSocketHandler; + +/** + * Test fixture for {@link OriginHandshakeInterceptor}. + * + * @author Sebastien Deleuze + */ +public class AllowedOriginsInterceptorTests extends AbstractHttpRequestTests { + + @Test + public void originValueMatch() throws Exception { + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + setOrigin("http://mydomain1.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); + assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); + } + + @Test + public void originValueNoMatch() throws Exception { + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + setOrigin("http://mydomain1.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(Arrays.asList("http://mydomain2.com")); + assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); + assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); + } + + @Test + public void originListMatch() throws Exception { + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + setOrigin("http://mydomain2.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); + assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); + assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); + } + + @Test + public void originListNoMatch() throws Exception { + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + setOrigin("http://mydomain4.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); + assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); + assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); + } + + @Test + public void noOriginNoMatchWithNullHostileCollection() throws Exception { + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + Set allowedOrigins = new ConcurrentSkipListSet(); + allowedOrigins.add("http://mydomain1.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(allowedOrigins); + assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); + assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); + } + + @Test + public void noOriginNoMatch() throws Exception { + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); + assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); + } + +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java index 0e99147aa12..ebad39d5aea 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java @@ -17,9 +17,12 @@ package org.springframework.web.socket.sockjs.support; import java.io.IOException; +import java.util.Arrays; +import java.util.List; import javax.servlet.ServletOutputStream; import javax.servlet.http.HttpServletResponse; +import static org.junit.Assert.assertEquals; import org.junit.Before; import org.junit.Test; @@ -40,9 +43,12 @@ import static org.mockito.BDDMockito.*; * Test fixture for {@link AbstractSockJsService}. * * @author Rossen Stoyanchev + * @author Sebastien Deleuze */ public class SockJsServiceTests extends AbstractHttpRequestTests { + private static final List origins = Arrays.asList("http://mydomain1.com", "http://mydomain2.com"); + private TestSockJsService service; private WebSocketHandler handler; @@ -80,10 +86,10 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType()); - assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin")); - assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials")); assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.servletResponse.getHeader("Cache-Control")); - assertEquals("Origin", this.servletResponse.getHeader("Vary")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertNull(this.servletResponse.getHeader("Vary")); String body = this.servletResponse.getContentAsString(); assertEquals("{\"entropy\"", body.substring(0, body.indexOf(':'))); @@ -97,6 +103,47 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { body = this.servletResponse.getContentAsString(); assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":false,\"websocket\":false}", body.substring(body.indexOf(','))); + + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertNull(this.servletResponse.getHeader("Vary")); + } + + @Test // SPR-12226 + public void handleInfoGetWithOrigin() throws Exception { + setOrigin("http://mydomain2.com"); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); + + assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType()); + assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.servletResponse.getHeader("Cache-Control")); + assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertEquals("Origin", this.servletResponse.getHeader("Vary")); + + String body = this.servletResponse.getContentAsString(); + assertEquals("{\"entropy\"", body.substring(0, body.indexOf(':'))); + assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":true}", + body.substring(body.indexOf(','))); + + this.service.setAllowedOrigins(null); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertNull(this.servletResponse.getHeader("Vary")); + + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertNull(this.servletResponse.getHeader("Vary")); + + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); + assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertEquals("Origin", this.servletResponse.getHeader("Vary")); } @Test // SPR-11443 @@ -129,7 +176,60 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); this.response.flush(); - assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods")); + assertNull(this.servletResponse.getHeader("Access-Control-Max-Age")); + assertEquals("Origin", this.servletResponse.getHeader("Vary")); + + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods")); + assertNull(this.servletResponse.getHeader("Access-Control-Max-Age")); + assertNull(this.servletResponse.getHeader("Vary")); + } + + @Test // SPR-12226 + public void handleInfoOptionsWithOrigin() throws Exception { + setOrigin("http://mydomain2.com"); + this.servletRequest.addHeader("Access-Control-Request-Headers", "Last-Modified"); + resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); + this.response.flush(); + assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers")); + assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods")); + assertEquals("31536000", this.servletResponse.getHeader("Access-Control-Max-Age")); + assertEquals("Origin", this.servletResponse.getHeader("Vary")); + + this.service.setAllowedOrigins(null); + resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN); + this.response.flush(); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods")); + assertNull(this.servletResponse.getHeader("Access-Control-Max-Age")); + assertNull(this.servletResponse.getHeader("Vary")); + + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN); + this.response.flush(); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Origin")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Credentials")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Headers")); + assertNull(this.servletResponse.getHeader("Access-Control-Allow-Methods")); + assertNull(this.servletResponse.getHeader("Access-Control-Max-Age")); + assertNull(this.servletResponse.getHeader("Vary")); + + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); + resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); + this.response.flush(); + assertEquals("http://mydomain2.com", this.servletResponse.getHeader("Access-Control-Allow-Origin")); assertEquals("true", this.servletResponse.getHeader("Access-Control-Allow-Credentials")); assertEquals("Last-Modified", this.servletResponse.getHeader("Access-Control-Allow-Headers")); assertEquals("OPTIONS, GET", this.servletResponse.getHeader("Access-Control-Allow-Methods")); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java index 1e1251805a1..8805b925eed 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,9 @@ package org.springframework.web.socket.sockjs.transport.handler; +import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; import org.junit.Before; @@ -27,6 +29,8 @@ import org.mockito.MockitoAnnotations; import org.springframework.scheduling.TaskScheduler; import org.springframework.web.socket.AbstractHttpRequestTests; import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.sockjs.transport.SockJsSessionFactory; import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService; @@ -41,6 +45,7 @@ import static org.mockito.BDDMockito.*; * Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}. * * @author Rossen Stoyanchev + * @author Sebastien Deleuze */ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { @@ -50,11 +55,19 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { private static final String sessionUrlPrefix = "/server1/" + sessionId + "/"; + private static final List origins = Arrays.asList("http://mydomain1.com", "http://mydomain2.com"); + @Mock private SessionCreatingTransportHandler xhrHandler; @Mock private TransportHandler xhrSendHandler; + @Mock private SessionCreatingTransportHandler jsonpHandler; + + @Mock private TransportHandler jsonpSendHandler; + + @Mock private HandshakeTransportHandler wsTransportHandler; + @Mock private WebSocketHandler wsHandler; @Mock private TaskScheduler taskScheduler; @@ -75,6 +88,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { given(this.xhrHandler.getTransportType()).willReturn(TransportType.XHR); given(this.xhrHandler.createSession(sessionId, this.wsHandler, attributes)).willReturn(this.session); given(this.xhrSendHandler.getTransportType()).willReturn(TransportType.XHR_SEND); + given(this.jsonpHandler.getTransportType()).willReturn(TransportType.JSONP); + given(this.jsonpHandler.createSession(sessionId, this.wsHandler, attributes)).willReturn(this.session); + given(this.jsonpSendHandler.getTransportType()).willReturn(TransportType.JSONP_SEND); + given(this.wsTransportHandler.getTransportType()).willReturn(TransportType.WEBSOCKET); this.service = new TransportHandlingSockJsService(this.taskScheduler, this.xhrHandler, this.xhrSendHandler); } @@ -126,10 +143,47 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { verify(taskScheduler).scheduleAtFixedRate(any(Runnable.class), eq(service.getDisconnectDelay())); assertEquals("no-store, no-cache, must-revalidate, max-age=0", this.response.getHeaders().getCacheControl()); - assertEquals("*", this.response.getHeaders().getFirst("Access-Control-Allow-Origin")); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin")); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials")); + } + + @Test // SPR-12226 + public void handleTransportRequestXhrAllowNullOrigin() throws Exception { + String sockJsPath = sessionUrlPrefix + "xhr"; + setRequest("POST", sockJsPrefix + sockJsPath); + this.service.setAllowedOrigins(null); + this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin")); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials")); + } + + @Test // SPR-12226 + public void handleTransportRequestXhrAllowedOriginsMatch() throws Exception { + String sockJsPath = sessionUrlPrefix + "xhr"; + setRequest("POST", sockJsPrefix + sockJsPath); + setOrigin(origins.get(0)); + this.service.setAllowedOrigins(origins); + this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + + assertEquals(200, this.servletResponse.getStatus()); + assertEquals(origins.get(0), this.response.getHeaders().getFirst("Access-Control-Allow-Origin")); assertEquals("true", this.response.getHeaders().getFirst("Access-Control-Allow-Credentials")); } + @Test // SPR-12226 + public void handleTransportRequestXhrAllowedOriginsNoMatch() throws Exception { + String sockJsPath = sessionUrlPrefix + "xhr"; + setRequest("POST", sockJsPrefix + sockJsPath); + setOrigin("http://mydomain3.com"); + this.service.setAllowedOrigins(origins); + this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + + assertEquals(403, this.servletResponse.getStatus()); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin")); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials")); + } + @Test public void handleTransportRequestXhrOptions() throws Exception { String sockJsPath = sessionUrlPrefix + "xhr"; @@ -137,9 +191,22 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); assertEquals(204, this.servletResponse.getStatus()); - assertEquals("*", this.response.getHeaders().getFirst("Access-Control-Allow-Origin")); - assertEquals("true", this.response.getHeaders().getFirst("Access-Control-Allow-Credentials")); - assertEquals("OPTIONS, POST", this.response.getHeaders().getFirst("Access-Control-Allow-Methods")); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin")); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials")); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Methods")); + } + + @Test // SPR-12226 + public void handleTransportRequestXhrOptionsAllowNullOrigin() throws Exception { + String sockJsPath = sessionUrlPrefix + "xhr"; + setRequest("OPTIONS", sockJsPrefix + sockJsPath); + this.service.setAllowedOrigins(null); + this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + + assertEquals(403, this.servletResponse.getStatus()); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Origin")); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Credentials")); + assertNull(this.response.getHeaders().getFirst("Access-Control-Allow-Methods")); } @Test @@ -176,8 +243,56 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { verify(this.xhrSendHandler).handleRequest(this.request, this.response, this.wsHandler, this.session); } + @Test + public void handleTransportRequestJsonp() throws Exception { + TransportHandlingSockJsService jsonpService = new TransportHandlingSockJsService(this.taskScheduler, this.jsonpHandler, this.jsonpSendHandler); + String sockJsPath = sessionUrlPrefix+ "jsonp"; + setRequest("GET", sockJsPrefix + sockJsPath); + jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + assertNotEquals(404, this.servletResponse.getStatus()); + + resetRequestAndResponse(); + jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + setRequest("GET", sockJsPrefix + sockJsPath); + jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + assertEquals(404, this.servletResponse.getStatus()); + + resetRequestAndResponse(); + jsonpService.setAllowedOrigins(null); + setRequest("GET", sockJsPrefix + sockJsPath); + jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + assertEquals(404, this.servletResponse.getStatus()); + } + + @Test + public void handleTransportRequestWebsocket() throws Exception { + TransportHandlingSockJsService wsService = new TransportHandlingSockJsService(this.taskScheduler, this.wsTransportHandler); + String sockJsPath = "/websocket"; + setRequest("GET", sockJsPrefix + sockJsPath); + wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + assertNotEquals(403, this.servletResponse.getStatus()); + + resetRequestAndResponse(); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); + interceptor.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + wsService.setHandshakeInterceptors(Arrays.asList(interceptor)); + setRequest("GET", sockJsPrefix + sockJsPath); + setOrigin("http://mydomain1.com"); + wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + assertNotEquals(403, this.servletResponse.getStatus()); + + resetRequestAndResponse(); + setRequest("GET", sockJsPrefix + sockJsPath); + setOrigin("http://mydomain2.com"); + wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + assertEquals(403, this.servletResponse.getStatus()); + } + interface SessionCreatingTransportHandler extends TransportHandler, SockJsSessionFactory { } + interface HandshakeTransportHandler extends TransportHandler, HandshakeHandler { + } + } diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml index 174c9503e07..4405fe1864d 100644 --- a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml @@ -15,7 +15,7 @@ - + @@ -23,7 +23,7 @@ - + diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-attributes.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-attributes.xml index 9cc65a3426a..a386e8e7dfc 100644 --- a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-attributes.xml +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-attributes.xml @@ -5,7 +5,7 @@ http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd http://www.springframework.org/schema/websocket http://www.springframework.org/schema/websocket/spring-websocket.xsd"> - + diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs-attributes.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs-attributes.xml index e937d82c3e4..e43513a5797 100644 --- a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs-attributes.xml +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs-attributes.xml @@ -5,7 +5,7 @@ http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd http://www.springframework.org/schema/websocket http://www.springframework.org/schema/websocket/spring-websocket.xsd"> - +