Efficient and consistent setAllowedOrigins collection type

Issue: SPR-13761
This commit is contained in:
Juergen Hoeller 2015-12-04 16:21:53 +01:00
parent cd4ce8727e
commit 3d1ae9c604
7 changed files with 99 additions and 106 deletions

View File

@ -16,11 +16,11 @@
package org.springframework.web.socket.server.support; package org.springframework.web.socket.server.support;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.LinkedHashSet;
import java.util.Map; import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
@ -34,8 +34,8 @@ import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.WebUtils; import org.springframework.web.util.WebUtils;
/** /**
* An interceptor to check request {@code Origin} header value against a collection of * An interceptor to check request {@code Origin} header value against a
* allowed origins. * collection of allowed origins.
* *
* @author Sebastien Deleuze * @author Sebastien Deleuze
* @since 4.1.2 * @since 4.1.2
@ -44,60 +44,57 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
protected Log logger = LogFactory.getLog(getClass()); protected Log logger = LogFactory.getLog(getClass());
private final List<String> allowedOrigins; private final Set<String> allowedOrigins = new LinkedHashSet<String>();
/** /**
* Default constructor with only same origin requests allowed. * Default constructor with only same origin requests allowed.
*/ */
public OriginHandshakeInterceptor() { public OriginHandshakeInterceptor() {
this.allowedOrigins = new ArrayList<String>();
} }
/** /**
* Constructor using the specified allowed origin values. * Constructor using the specified allowed origin values.
*
* @see #setAllowedOrigins(Collection) * @see #setAllowedOrigins(Collection)
*/ */
public OriginHandshakeInterceptor(Collection<String> allowedOrigins) { public OriginHandshakeInterceptor(Collection<String> allowedOrigins) {
this();
setAllowedOrigins(allowedOrigins); setAllowedOrigins(allowedOrigins);
} }
/** /**
* Configure allowed {@code Origin} header values. This check is mostly * Configure allowed {@code Origin} header values. This check is mostly
* designed for browsers. There is nothing preventing other types of client * designed for browsers. There is nothing preventing other types of client
* to modify the {@code Origin} header value. * to modify the {@code Origin} header value.
*
* <p>Each provided allowed origin must have a scheme, and optionally a port * <p>Each provided allowed origin must have a scheme, and optionally a port
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin * (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
* string may also be "*" in which case all origins are allowed. * string may also be "*" in which case all origins are allowed.
*
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a> * @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
*/ */
public void setAllowedOrigins(Collection<String> allowedOrigins) { public void setAllowedOrigins(Collection<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null"); Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
this.allowedOrigins.clear(); this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins); this.allowedOrigins.addAll(allowedOrigins);
} }
/** /**
* @see #setAllowedOrigins(Collection)
* @since 4.1.5 * @since 4.1.5
* @see #setAllowedOrigins
*/ */
public Collection<String> getAllowedOrigins() { public Collection<String> getAllowedOrigins() {
return Collections.unmodifiableList(this.allowedOrigins); return Collections.unmodifiableSet(this.allowedOrigins);
} }
@Override @Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception { WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) { if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) {
response.setStatusCode(HttpStatus.FORBIDDEN); response.setStatusCode(HttpStatus.FORBIDDEN);
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Handshake request rejected, Origin header value " logger.debug("Handshake request rejected, Origin header value " +
+ request.getHeaders().getOrigin() + " not allowed"); request.getHeaders().getOrigin() + " not allowed");
} }
return false; return false;
} }

View File

@ -18,13 +18,15 @@ package org.springframework.web.socket.sockjs.support;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
@ -56,7 +58,7 @@ import org.springframework.web.util.WebUtils;
* path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html", * path resolution and handling of static SockJS requests (e.g. "/info", "/iframe.html",
* etc). Sub-classes must handle session URLs (i.e. transport-specific requests). * etc). Sub-classes must handle session URLs (i.e. transport-specific requests).
* *
* By default, only same origin requests are allowed. Use {@link #setAllowedOrigins(List)} * By default, only same origin requests are allowed. Use {@link #setAllowedOrigins}
* to specify a list of allowed origins (a list containing "*" will allow all origins). * to specify a list of allowed origins (a list containing "*" will allow all origins).
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
@ -94,10 +96,10 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
private boolean webSocketEnabled = true; private boolean webSocketEnabled = true;
private final List<String> allowedOrigins = new ArrayList<String>();
private boolean suppressCors = false; private boolean suppressCors = false;
protected final Set<String> allowedOrigins = new LinkedHashSet<String>();
public AbstractSockJsService(TaskScheduler scheduler) { public AbstractSockJsService(TaskScheduler scheduler) {
Assert.notNull(scheduler, "TaskScheduler must not be null"); Assert.notNull(scheduler, "TaskScheduler must not be null");
@ -274,35 +276,6 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
return this.webSocketEnabled; return this.webSocketEnabled;
} }
/**
* Configure allowed {@code Origin} header values. This check is mostly
* designed for browsers. There is nothing preventing other types of client
* to modify the {@code Origin} header value.
* <p>When SockJS is enabled and origins are restricted, transport types
* that do not allow to check request origin (JSONP and Iframe based
* transports) are disabled. As a consequence, IE 6 to 9 are not supported
* when origins are restricted.
* <p>Each provided allowed origin must have a scheme, and optionally a port
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
* string may also be "*" in which case all origins are allowed.
* @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
*/
public void setAllowedOrigins(List<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origin List must not be null");
this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins);
}
/**
* @since 4.1.2
* @see #setAllowedOrigins(List)
*/
public List<String> getAllowedOrigins() {
return Collections.unmodifiableList(this.allowedOrigins);
}
/** /**
* This option can be used to disable automatic addition of CORS headers for * This option can be used to disable automatic addition of CORS headers for
* SockJS requests. * SockJS requests.
@ -321,6 +294,35 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
return this.suppressCors; return this.suppressCors;
} }
/**
* Configure allowed {@code Origin} header values. This check is mostly
* designed for browsers. There is nothing preventing other types of client
* to modify the {@code Origin} header value.
* <p>When SockJS is enabled and origins are restricted, transport types
* that do not allow to check request origin (JSONP and Iframe based
* transports) are disabled. As a consequence, IE 6 to 9 are not supported
* when origins are restricted.
* <p>Each provided allowed origin must have a scheme, and optionally a port
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
* string may also be "*" in which case all origins are allowed.
* @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
* @see <a href="https://github.com/sockjs/sockjs-client#supported-transports-by-browser-html-served-from-http-or-https">SockJS supported transports by browser</a>
*/
public void setAllowedOrigins(Collection<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins);
}
/**
* @since 4.1.2
* @see #setAllowedOrigins
*/
public Collection<String> getAllowedOrigins() {
return Collections.unmodifiableSet(this.allowedOrigins);
}
/** /**
* This method determines the SockJS path and handles SockJS static URLs. * This method determines the SockJS path and handles SockJS static URLs.
@ -465,24 +467,11 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
String path = request.getURI().getPath(); String path = request.getURI().getPath();
int index = path.lastIndexOf('/') + 1; int index = path.lastIndexOf('/') + 1;
String filename = path.substring(index); String filename = path.substring(index);
return filename.indexOf(';') == -1; return (filename.indexOf(';') == -1);
} }
/** protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods)
* Handle request for raw WebSocket communication, i.e. without any SockJS message framing. throws IOException {
*/
protected abstract void handleRawWebSocketRequest(ServerHttpRequest request,
ServerHttpResponse response, WebSocketHandler webSocketHandler) throws IOException;
/**
* Handle a SockJS session URL (i.e. transport-specific request).
*/
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response,
HttpMethod... httpMethods) throws IOException {
if (WebUtils.isSameOrigin(request)) { if (WebUtils.isSameOrigin(request)) {
return true; return true;
@ -529,6 +518,19 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
} }
/**
* Handle request for raw WebSocket communication, i.e. without any SockJS message framing.
*/
protected abstract void handleRawWebSocketRequest(ServerHttpRequest request,
ServerHttpResponse response, WebSocketHandler webSocketHandler) throws IOException;
/**
* Handle a SockJS session URL (i.e. transport-specific request).
*/
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
private interface SockJsRequestHandler { private interface SockJsRequestHandler {
void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException; void handle(ServerHttpRequest request, ServerHttpResponse response) throws IOException;
@ -546,8 +548,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
addNoCacheHeaders(response); addNoCacheHeaders(response);
if (checkOrigin(request, response)) { if (checkOrigin(request, response)) {
response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET)); response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET));
String content = String.format(INFO_CONTENT, random.nextInt(), String content = String.format(
isSessionCookieNeeded(), isWebSocketEnabled()); INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled());
response.getBody().write(content.getBytes()); response.getBody().write(content.getBytes());
} }

View File

@ -326,7 +326,7 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem
return false; return false;
} }
if (!getAllowedOrigins().contains("*")) { if (!this.allowedOrigins.contains("*")) {
TransportType transportType = TransportType.fromValue(transport); TransportType transportType = TransportType.fromValue(transport);
if (transportType == null || !transportType.supportsOrigin()) { if (transportType == null || !transportType.supportsOrigin()) {
if (logger.isWarnEnabled()) { if (logger.isWarnEnabled()) {

View File

@ -16,18 +16,13 @@
package org.springframework.web.socket.config; package org.springframework.web.socket.config;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.Arrays;
import java.util.Date; import java.util.Date;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
@ -67,6 +62,9 @@ import org.springframework.web.socket.sockjs.transport.handler.XhrPollingTranspo
import org.springframework.web.socket.sockjs.transport.handler.XhrReceivingTransportHandler; import org.springframework.web.socket.sockjs.transport.handler.XhrReceivingTransportHandler;
import org.springframework.web.socket.sockjs.transport.handler.XhrStreamingTransportHandler; import org.springframework.web.socket.sockjs.transport.handler.XhrStreamingTransportHandler;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
/** /**
* Test fixture for HandlersBeanDefinitionParser. * Test fixture for HandlersBeanDefinitionParser.
* See test configuration files websocket-config-handlers-*.xml. * See test configuration files websocket-config-handlers-*.xml.
@ -76,13 +74,7 @@ import org.springframework.web.socket.sockjs.transport.handler.XhrStreamingTrans
*/ */
public class HandlersBeanDefinitionParserTests { public class HandlersBeanDefinitionParserTests {
private GenericWebApplicationContext appContext; private GenericWebApplicationContext appContext = new GenericWebApplicationContext();
@Before
public void setup() {
this.appContext = new GenericWebApplicationContext();
}
@Test @Test
@ -234,10 +226,12 @@ public class HandlersBeanDefinitionParserTests {
List<HandshakeInterceptor> interceptors = transportService.getHandshakeInterceptors(); List<HandshakeInterceptor> interceptors = transportService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(OriginHandshakeInterceptor.class))); assertThat(interceptors, contains(instanceOf(OriginHandshakeInterceptor.class)));
assertEquals(Arrays.asList("http://mydomain1.com", "http://mydomain2.com"), transportService.getAllowedOrigins());
assertTrue(transportService.shouldSuppressCors()); assertTrue(transportService.shouldSuppressCors());
assertTrue(transportService.getAllowedOrigins().contains("http://mydomain1.com"));
assertTrue(transportService.getAllowedOrigins().contains("http://mydomain2.com"));
} }
private void loadBeanDefinitions(String fileName) { private void loadBeanDefinitions(String fileName) {
XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.appContext); XmlBeanDefinitionReader reader = new XmlBeanDefinitionReader(this.appContext);
ClassPathResource resource = new ClassPathResource(fileName, HandlersBeanDefinitionParserTests.class); ClassPathResource resource = new ClassPathResource(fileName, HandlersBeanDefinitionParserTests.class);
@ -278,9 +272,11 @@ class TestWebSocketHandler implements WebSocketHandler {
} }
} }
class FooWebSocketHandler extends TestWebSocketHandler { class FooWebSocketHandler extends TestWebSocketHandler {
} }
class TestHandshakeHandler implements HandshakeHandler { class TestHandshakeHandler implements HandshakeHandler {
@Override @Override
@ -291,9 +287,11 @@ class TestHandshakeHandler implements HandshakeHandler {
} }
} }
class TestChannelInterceptor extends ChannelInterceptorAdapter { class TestChannelInterceptor extends ChannelInterceptorAdapter {
} }
class FooTestInterceptor implements HandshakeInterceptor { class FooTestInterceptor implements HandshakeInterceptor {
@Override @Override
@ -309,9 +307,11 @@ class FooTestInterceptor implements HandshakeInterceptor {
} }
} }
class BarTestInterceptor extends FooTestInterceptor { class BarTestInterceptor extends FooTestInterceptor {
} }
@SuppressWarnings({ "unchecked", "rawtypes" }) @SuppressWarnings({ "unchecked", "rawtypes" })
class TestTaskScheduler implements TaskScheduler { class TestTaskScheduler implements TaskScheduler {
@ -344,9 +344,9 @@ class TestTaskScheduler implements TaskScheduler {
public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) { public ScheduledFuture scheduleWithFixedDelay(Runnable task, long delay) {
return null; return null;
} }
} }
class TestMessageCodec implements SockJsMessageCodec { class TestMessageCodec implements SockJsMessageCodec {
@Override @Override

View File

@ -86,16 +86,8 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler; import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.*;
import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/** /**
* Test fixture for MessageBrokerBeanDefinitionParser. * Test fixture for MessageBrokerBeanDefinitionParser.
@ -192,7 +184,8 @@ public class MessageBrokerBeanDefinitionParserTests {
interceptors = defaultSockJsService.getHandshakeInterceptors(); interceptors = defaultSockJsService.getHandshakeInterceptors();
assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class),
instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class))); instanceOf(BarTestInterceptor.class), instanceOf(OriginHandshakeInterceptor.class)));
assertEquals(Arrays.asList("http://mydomain3.com", "http://mydomain4.com"), defaultSockJsService.getAllowedOrigins()); assertTrue(defaultSockJsService.getAllowedOrigins().contains("http://mydomain3.com"));
assertTrue(defaultSockJsService.getAllowedOrigins().contains("http://mydomain4.com"));
SimpUserRegistry userRegistry = this.appContext.getBean(SimpUserRegistry.class); SimpUserRegistry userRegistry = this.appContext.getBean(SimpUserRegistry.class);
assertNotNull(userRegistry); assertNotNull(userRegistry);
@ -478,9 +471,9 @@ public class MessageBrokerBeanDefinitionParserTests {
return (handler instanceof WebSocketHandlerDecorator) ? return (handler instanceof WebSocketHandlerDecorator) ?
((WebSocketHandlerDecorator) handler).getLastHandler() : handler; ((WebSocketHandlerDecorator) handler).getLastHandler() : handler;
} }
} }
class CustomArgumentResolver implements HandlerMethodArgumentResolver { class CustomArgumentResolver implements HandlerMethodArgumentResolver {
@Override @Override
@ -494,6 +487,7 @@ class CustomArgumentResolver implements HandlerMethodArgumentResolver {
} }
} }
class CustomReturnValueHandler implements HandlerMethodReturnValueHandler { class CustomReturnValueHandler implements HandlerMethodReturnValueHandler {
@Override @Override
@ -507,6 +501,7 @@ class CustomReturnValueHandler implements HandlerMethodReturnValueHandler {
} }
} }
class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorFactory { class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorFactory {
@Override @Override
@ -515,6 +510,7 @@ class TestWebSocketHandlerDecoratorFactory implements WebSocketHandlerDecoratorF
} }
} }
class TestWebSocketHandlerDecorator extends WebSocketHandlerDecorator { class TestWebSocketHandlerDecorator extends WebSocketHandlerDecorator {
public TestWebSocketHandlerDecorator(WebSocketHandler delegate) { public TestWebSocketHandlerDecorator(WebSocketHandler delegate) {
@ -528,6 +524,6 @@ class TestWebSocketHandlerDecorator extends WebSocketHandlerDecorator {
} }
} }
class TestStompErrorHandler extends StompSubProtocolErrorHandler {
class TestStompErrorHandler extends StompSubProtocolErrorHandler {
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2014 the original author or authors. * Copyright 2002-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -29,9 +29,9 @@ import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.HttpRequestHandler; import org.springframework.web.HttpRequestHandler;
import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler; import org.springframework.web.socket.messaging.SubProtocolWebSocketHandler;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.TransportHandler;
@ -117,7 +117,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getSockJsService()); assertNotNull(requestHandler.getSockJsService());
DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); assertTrue(sockJsService.getAllowedOrigins().contains(origin));
assertFalse(sockJsService.shouldSuppressCors()); assertFalse(sockJsService.shouldSuppressCors());
registration = registration =
@ -128,7 +128,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey(); requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
assertNotNull(requestHandler.getSockJsService()); assertNotNull(requestHandler.getSockJsService());
sockJsService = (DefaultSockJsService)requestHandler.getSockJsService(); sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); assertTrue(sockJsService.getAllowedOrigins().contains(origin));
assertFalse(sockJsService.shouldSuppressCors()); assertFalse(sockJsService.shouldSuppressCors());
} }
@ -255,7 +255,7 @@ public class WebMvcStompWebSocketEndpointRegistrationTests {
assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0)); assertEquals(interceptor, sockJsService.getHandshakeInterceptors().get(0));
assertEquals(OriginHandshakeInterceptor.class, assertEquals(OriginHandshakeInterceptor.class,
sockJsService.getHandshakeInterceptors().get(1).getClass()); sockJsService.getHandshakeInterceptors().get(1).getClass());
assertEquals(Arrays.asList(origin), sockJsService.getAllowedOrigins()); assertTrue(sockJsService.getAllowedOrigins().contains(origin));
} }
} }

View File

@ -17,7 +17,6 @@
package org.springframework.web.socket.config.annotation; package org.springframework.web.socket.config.annotation;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import org.junit.Before; import org.junit.Before;
@ -29,9 +28,9 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.handler.TextWebSocketHandler; import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler; import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
import org.springframework.web.socket.server.support.OriginHandshakeInterceptor;
import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService;
@ -148,8 +147,7 @@ public class WebSocketHandlerRegistrationTests {
assertEquals(handler, mapping.webSocketHandler); assertEquals(handler, mapping.webSocketHandler);
assertEquals("/foo/**", mapping.path); assertEquals("/foo/**", mapping.path);
assertNotNull(mapping.sockJsService); assertNotNull(mapping.sockJsService);
assertEquals(Arrays.asList("http://mydomain1.com"), assertTrue(mapping.sockJsService.getAllowedOrigins().contains("http://mydomain1.com"));
mapping.sockJsService.getAllowedOrigins());
List<HandshakeInterceptor> interceptors = mapping.sockJsService.getHandshakeInterceptors(); List<HandshakeInterceptor> interceptors = mapping.sockJsService.getHandshakeInterceptors();
assertEquals(interceptor, interceptors.get(0)); assertEquals(interceptor, interceptors.get(0));
assertEquals(OriginHandshakeInterceptor.class, interceptors.get(1).getClass()); assertEquals(OriginHandshakeInterceptor.class, interceptors.get(1).getClass());
@ -218,6 +216,7 @@ public class WebSocketHandlerRegistrationTests {
} }
} }
private static class Mapping { private static class Mapping {
private final WebSocketHandler webSocketHandler; private final WebSocketHandler webSocketHandler;
@ -230,7 +229,6 @@ public class WebSocketHandlerRegistrationTests {
private final DefaultSockJsService sockJsService; private final DefaultSockJsService sockJsService;
public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) { public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) {
this.webSocketHandler = handler; this.webSocketHandler = handler;
this.path = path; this.path = path;