diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java index 0008222f14..0401a80c3b 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java @@ -16,11 +16,11 @@ package org.springframework.web.socket.server.support; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.List; +import java.util.LinkedHashSet; import java.util.Map; +import java.util.Set; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -34,8 +34,8 @@ import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.util.WebUtils; /** - * An interceptor to check request {@code Origin} header value against a collection of - * allowed origins. + * An interceptor to check request {@code Origin} header value against a + * collection of allowed origins. * * @author Sebastien Deleuze * @since 4.1.2 @@ -44,60 +44,57 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { protected Log logger = LogFactory.getLog(getClass()); - private final List allowedOrigins; + private final Set allowedOrigins = new LinkedHashSet(); /** * Default constructor with only same origin requests allowed. */ public OriginHandshakeInterceptor() { - this.allowedOrigins = new ArrayList(); } /** * Constructor using the specified allowed origin values. - * * @see #setAllowedOrigins(Collection) */ public OriginHandshakeInterceptor(Collection allowedOrigins) { - this(); setAllowedOrigins(allowedOrigins); } + /** * Configure allowed {@code Origin} header values. This check is mostly * designed for browsers. There is nothing preventing other types of client * to modify the {@code Origin} header value. - * *

Each provided allowed origin must have a scheme, and optionally a port * (e.g. "http://example.org", "http://example.org:9090"). An allowed origin * string may also be "*" in which case all origins are allowed. - * * @see RFC 6454: The Web Origin Concept */ public void setAllowedOrigins(Collection allowedOrigins) { - Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null"); + Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null"); this.allowedOrigins.clear(); this.allowedOrigins.addAll(allowedOrigins); } /** - * @see #setAllowedOrigins(Collection) * @since 4.1.5 + * @see #setAllowedOrigins */ public Collection getAllowedOrigins() { - return Collections.unmodifiableList(this.allowedOrigins); + return Collections.unmodifiableSet(this.allowedOrigins); } @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception { + if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) { response.setStatusCode(HttpStatus.FORBIDDEN); if (logger.isDebugEnabled()) { - logger.debug("Handshake request rejected, Origin header value " - + request.getHeaders().getOrigin() + " not allowed"); + logger.debug("Handshake request rejected, Origin header value " + + request.getHeaders().getOrigin() + " not allowed"); } return false; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index a282facaf0..9c00716740 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -18,13 +18,15 @@ package org.springframework.web.socket.sockjs.support; import java.io.IOException; import java.nio.charset.Charset; -import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Date; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Random; +import java.util.Set; import java.util.concurrent.TimeUnit; import javax.servlet.http.HttpServletRequest; @@ -56,7 +58,7 @@ import org.springframework.web.util.WebUtils; * path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html", * etc). Sub-classes must handle session URLs (i.e. transport-specific requests). * - * By default, only same origin requests are allowed. Use {@link #setAllowedOrigins(List)} + * By default, only same origin requests are allowed. Use {@link #setAllowedOrigins} * to specify a list of allowed origins (a list containing "*" will allow all origins). * * @author Rossen Stoyanchev @@ -94,10 +96,10 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig private boolean webSocketEnabled = true; - private final List allowedOrigins = new ArrayList(); - private boolean suppressCors = false; + protected final Set allowedOrigins = new LinkedHashSet(); + public AbstractSockJsService(TaskScheduler scheduler) { Assert.notNull(scheduler, "TaskScheduler must not be null"); @@ -274,35 +276,6 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig return this.webSocketEnabled; } - /** - * Configure allowed {@code Origin} header values. This check is mostly - * designed for browsers. There is nothing preventing other types of client - * to modify the {@code Origin} header value. - *

When SockJS is enabled and origins are restricted, transport types - * that do not allow to check request origin (JSONP and Iframe based - * transports) are disabled. As a consequence, IE 6 to 9 are not supported - * when origins are restricted. - *

Each provided allowed origin must have a scheme, and optionally a port - * (e.g. "http://example.org", "http://example.org:9090"). An allowed origin - * string may also be "*" in which case all origins are allowed. - * @since 4.1.2 - * @see RFC 6454: The Web Origin Concept - * @see SockJS supported transports by browser - */ - public void setAllowedOrigins(List allowedOrigins) { - Assert.notNull(allowedOrigins, "Allowed origin List must not be null"); - this.allowedOrigins.clear(); - this.allowedOrigins.addAll(allowedOrigins); - } - - /** - * @since 4.1.2 - * @see #setAllowedOrigins(List) - */ - public List getAllowedOrigins() { - return Collections.unmodifiableList(this.allowedOrigins); - } - /** * This option can be used to disable automatic addition of CORS headers for * SockJS requests. @@ -321,6 +294,35 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig return this.suppressCors; } + /** + * Configure allowed {@code Origin} header values. This check is mostly + * designed for browsers. There is nothing preventing other types of client + * to modify the {@code Origin} header value. + *

When SockJS is enabled and origins are restricted, transport types + * that do not allow to check request origin (JSONP and Iframe based + * transports) are disabled. As a consequence, IE 6 to 9 are not supported + * when origins are restricted. + *

Each provided allowed origin must have a scheme, and optionally a port + * (e.g. "http://example.org", "http://example.org:9090"). An allowed origin + * string may also be "*" in which case all origins are allowed. + * @since 4.1.2 + * @see RFC 6454: The Web Origin Concept + * @see SockJS supported transports by browser + */ + public void setAllowedOrigins(Collection allowedOrigins) { + Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null"); + this.allowedOrigins.clear(); + this.allowedOrigins.addAll(allowedOrigins); + } + + /** + * @since 4.1.2 + * @see #setAllowedOrigins + */ + public Collection getAllowedOrigins() { + return Collections.unmodifiableSet(this.allowedOrigins); + } + /** * This method determines the SockJS path and handles SockJS static URLs. @@ -465,24 +467,11 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig String path = request.getURI().getPath(); int index = path.lastIndexOf('/') + 1; String filename = path.substring(index); - return filename.indexOf(';') == -1; + return (filename.indexOf(';') == -1); } - /** - * Handle request for raw WebSocket communication, i.e. without any SockJS message framing. - */ - protected abstract void handleRawWebSocketRequest(ServerHttpRequest request, - ServerHttpResponse response, WebSocketHandler webSocketHandler) throws IOException; - - /** - * Handle a SockJS session URL (i.e. transport-specific request). - */ - protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, - WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException; - - - protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, - HttpMethod... httpMethods) throws IOException { + protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) + throws IOException { if (WebUtils.isSameOrigin(request)) { return true; @@ -529,6 +518,19 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig } + /** + * Handle request for raw WebSocket communication, i.e. without any SockJS message framing. + */ + protected abstract void handleRawWebSocketRequest(ServerHttpRequest request, + ServerHttpResponse response, WebSocketHandler webSocketHandler) throws IOException; + + /** + * Handle a SockJS session URL (i.e. transport-specific request). + */ + protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, + WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException; + + private interface SockJsRequestHandler { void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException; @@ -546,8 +548,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig addNoCacheHeaders(response); if (checkOrigin(request, response)) { response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET)); - String content = String.format(INFO_CONTENT, random.nextInt(), - isSessionCookieNeeded(), isWebSocketEnabled()); + String content = String.format( + INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled()); response.getBody().write(content.getBytes()); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index 6854bf5db4..811c0861b6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -326,7 +326,7 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem return false; } - if (!getAllowedOrigins().contains("*")) { + if (!this.allowedOrigins.contains("*")) { TransportType transportType = TransportType.fromValue(transport); if (transportType == null || !transportType.supportsOrigin()) { if (logger.isWarnEnabled()) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java index 5c7c949104..462307b1bd 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java @@ -16,18 +16,13 @@ package org.springframework.web.socket.config; -import static org.hamcrest.Matchers.*; -import static org.junit.Assert.*; - import java.io.IOException; import java.io.InputStream; -import java.util.Arrays; import java.util.Date; import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledFuture; -import org.junit.Before; import org.junit.Test; import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; @@ -67,6 +62,9 @@ import org.springframework.web.socket.sockjs.transport.handler.XhrPollingTranspo import org.springframework.web.socket.sockjs.transport.handler.XhrReceivingTransportHandler; import org.springframework.web.socket.sockjs.transport.handler.XhrStreamingTransportHandler; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; + /** * Test fixture for HandlersBeanDefinitionParser. * See test configuration files websocket-config-handlers-*.xml. @@ -76,13 +74,7 @@ import org.springframework.web.socket.sockjs.transport.handler.XhrStreamingTrans */ public class HandlersBeanDefinitionParserTests { - private GenericWebApplicationContext appContext; - - - @Before - public void setup() { - this.appContext = new GenericWebApplicationContext(); - } + private GenericWebApplicationContext appContext = new GenericWebApplicationContext(); @Test @@ -234,10 +226,12 @@ public class HandlersBeanDefinitionParserTests { List interceptors = transportService.getHandshakeInterceptors(); assertThat(interceptors, contains(instanceOf(OriginHandshakeInterceptor.class))); - assertEquals(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"), transportService.getAllowedOrigins()); assertTrue(transportService.shouldSuppressCors()); + assertTrue(transportService.getAllowedOrigins().contains("http://mydomain1.com")); + assertTrue(transportService.getAllowedOrigins().contains("http://mydomain2.com")); } + private void loadBeanDefinitions(String fileName) { XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.appContext); ClassPathResource resource = new ClassPathResource(fileName, HandlersBeanDefinitionParserTests.class); @@ -278,9 +272,11 @@ class TestWebSocketHandler implements WebSocketHandler { } } + class FooWebSocketHandler extends TestWebSocketHandler { } + class TestHandshakeHandler implements HandshakeHandler { @Override @@ -291,9 +287,11 @@ class TestHandshakeHandler implements HandshakeHandler { } } + class TestChannelInterceptor extends ChannelInterceptorAdapter { } + class FooTestInterceptor implements HandshakeInterceptor { @Override @@ -309,9 +307,11 @@ class FooTestInterceptor implements HandshakeInterceptor { } } + class BarTestInterceptor extends FooTestInterceptor { } + @SuppressWarnings({ "unchecked", "rawtypes" }) class TestTaskScheduler implements TaskScheduler { @@ -344,9 +344,9 @@ class TestTaskScheduler implements TaskScheduler { public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) { return null; } - } + class TestMessageCodec implements SockJsMessageCodec { @Override @@ -363,4 +363,4 @@ class TestMessageCodec implements SockJsMessageCodec { public String[] decodeInputStream(InputStream content) throws IOException { return new String[0]; } -} \ No newline at end of file +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index ca0076312c..a7a10465f2 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -86,16 +86,8 @@ import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.instanceOf; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; /** * Test fixture for MessageBrokerBeanDefinitionParser. @@ -192,7 +184,8 @@ public class MessageBrokerBeanDefinitionParserTests { interceptors = defaultSockJsService.getHandshakeInterceptors(); assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class))); - assertEquals(Arrays.asList("http://mydomain3.com", "http://mydomain4.com"), defaultSockJsService.getAllowedOrigins()); + assertTrue(defaultSockJsService.getAllowedOrigins().contains("http://mydomain3.com")); + assertTrue(defaultSockJsService.getAllowedOrigins().contains("http://mydomain4.com")); SimpUserRegistry userRegistry = this.appContext.getBean(SimpUserRegistry.class); assertNotNull(userRegistry); @@ -478,9 +471,9 @@ public class MessageBrokerBeanDefinitionParserTests { return (handler instanceof WebSocketHandlerDecorator) ? ((WebSocketHandlerDecorator) handler).getLastHandler() : handler; } - } + class CustomArgumentResolver implements HandlerMethodArgumentResolver { @Override @@ -494,6 +487,7 @@ class CustomArgumentResolver implements HandlerMethodArgumentResolver { } } + class CustomReturnValueHandler implements HandlerMethodReturnValueHandler { @Override @@ -507,6 +501,7 @@ class CustomReturnValueHandler implements HandlerMethodReturnValueHandler { } } + class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorFactory { @Override @@ -515,6 +510,7 @@ class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorF } } + class TestWebSocketHandlerDecorator extends WebSocketHandlerDecorator { public TestWebSocketHandlerDecorator(WebSocketHandler delegate) { @@ -528,6 +524,6 @@ class TestWebSocketHandlerDecorator extends WebSocketHandlerDecorator { } } -class TestStompErrorHandler extends StompSubProtocolErrorHandler { -} \ No newline at end of file +class TestStompErrorHandler extends StompSubProtocolErrorHandler { +} diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java index d323e2b843..435e515850 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,9 +29,9 @@ import org.springframework.scheduling.TaskScheduler; import org.springframework.util.MultiValueMap; import org.springframework.web.HttpRequestHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; -import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; import org.springframework.web.socket.sockjs.transport.TransportHandler; @@ -117,7 +117,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); assertNotNull(requestHandler.getSockJsService()); DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); - assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); + assertTrue(sockJsService.getAllowedOrigins().contains(origin)); assertFalse(sockJsService.shouldSuppressCors()); registration = @@ -128,7 +128,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); assertNotNull(requestHandler.getSockJsService()); sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); - assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); + assertTrue(sockJsService.getAllowedOrigins().contains(origin)); assertFalse(sockJsService.shouldSuppressCors()); } @@ -255,7 +255,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests { assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0)); assertEquals(OriginHandshakeInterceptor.class, sockJsService.getHandshakeInterceptors().get(1).getClass()); - assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); + assertTrue(sockJsService.getAllowedOrigins().contains(origin)); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java index 7591558e10..3a18fb7d13 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketHandlerRegistrationTests.java @@ -17,7 +17,6 @@ package org.springframework.web.socket.config.annotation; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import org.junit.Before; @@ -29,9 +28,9 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.handler.TextWebSocketHandler; import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeInterceptor; -import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; +import org.springframework.web.socket.server.support.OriginHandshakeInterceptor; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; @@ -148,8 +147,7 @@ public class WebSocketHandlerRegistrationTests { assertEquals(handler, mapping.webSocketHandler); assertEquals("/foo/**", mapping.path); assertNotNull(mapping.sockJsService); - assertEquals(Arrays.asList("http://mydomain1.com"), - mapping.sockJsService.getAllowedOrigins()); + assertTrue(mapping.sockJsService.getAllowedOrigins().contains("http://mydomain1.com")); List interceptors = mapping.sockJsService.getHandshakeInterceptors(); assertEquals(interceptor, interceptors.get(0)); assertEquals(OriginHandshakeInterceptor.class, interceptors.get(1).getClass()); @@ -218,6 +216,7 @@ public class WebSocketHandlerRegistrationTests { } } + private static class Mapping { private final WebSocketHandler webSocketHandler; @@ -230,7 +229,6 @@ public class WebSocketHandlerRegistrationTests { private final DefaultSockJsService sockJsService; - public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) { this.webSocketHandler = handler; this.path = path;