Allow "ws" and "wss" for isValidCorsOrigin checks

Issue: SPR-12956
This commit is contained in:
Rossen Stoyanchev 2015-05-03 10:23:13 +02:00
parent 222f6998e4
commit 68ecb92d1f
7 changed files with 117 additions and 148 deletions

View File

@ -317,6 +317,29 @@ public class UriComponentsBuilder implements Cloneable {
} }
/**
* Create an instance by parsing the "origin" header of an HTTP request.
*/
public static UriComponentsBuilder fromOriginHeader(String origin) {
UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
if (StringUtils.hasText(origin)) {
int schemaIdx = origin.indexOf("://");
String schema = (schemaIdx != -1 ? origin.substring(0, schemaIdx) : "http");
builder.scheme(schema);
String hostString = (schemaIdx != -1 ? origin.substring(schemaIdx + 3) : origin);
if (hostString.contains(":")) {
String[] hostAndPort = StringUtils.split(hostString, ":");
builder.host(hostAndPort[0]);
builder.port(Integer.parseInt(hostAndPort[1]));
}
else {
builder.host(hostString);
}
}
return builder;
}
// build methods // build methods
/** /**

View File

@ -23,6 +23,7 @@ import java.util.Enumeration;
import java.util.Map; import java.util.Map;
import java.util.StringTokenizer; import java.util.StringTokenizer;
import java.util.TreeMap; import java.util.TreeMap;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
import javax.servlet.ServletRequest; import javax.servlet.ServletRequest;
import javax.servlet.ServletRequestWrapper; import javax.servlet.ServletRequestWrapper;
@ -38,6 +39,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpRequest; import org.springframework.http.HttpRequest;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@ -790,21 +792,10 @@ public abstract class WebUtils {
if (origin == null || allowedOrigins.contains("*")) { if (origin == null || allowedOrigins.contains("*")) {
return true; return true;
} }
else if (allowedOrigins.isEmpty()) { else if (CollectionUtils.isEmpty(allowedOrigins)) {
UriComponents originComponents; UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build();
try { UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
originComponents = UriComponentsBuilder.fromHttpUrl(origin).build(); return (actualUrl.getHost().equals(originUrl.getHost()) && getPort(actualUrl) == getPort(originUrl));
}
catch (IllegalArgumentException ex) {
if (logger.isWarnEnabled()) {
logger.warn("Failed to parse Origin header value [" + origin + "]");
}
return false;
}
UriComponents requestComponents = UriComponentsBuilder.fromHttpRequest(request).build();
int originPort = getPort(originComponents);
int requestPort = getPort(requestComponents);
return (originComponents.getHost().equals(requestComponents.getHost()) && originPort == requestPort);
} }
else { else {
return allowedOrigins.contains(origin); return allowedOrigins.contains(origin);
@ -814,10 +805,10 @@ public abstract class WebUtils {
private static int getPort(UriComponents component) { private static int getPort(UriComponents component) {
int port = component.getPort(); int port = component.getPort();
if (port == -1) { if (port == -1) {
if ("http".equals(component.getScheme())) { if ("http".equals(component.getScheme()) || "ws".equals(component.getScheme())) {
port = 80; port = 80;
} }
else if ("https".equals(component.getScheme())) { else if ("https".equals(component.getScheme()) || "wss".equals(component.getScheme())) {
port = 443; port = 443;
} }
} }

View File

@ -16,8 +16,8 @@
package org.springframework.web.util; package org.springframework.web.util;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -106,60 +106,45 @@ public class WebUtilsTests {
} }
@Test @Test
public void isValidOrigin() { public void isValidOriginSuccess() {
List<String> allowedOrigins = new ArrayList<>();
List<String> allowed = Collections.emptyList();
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com", allowed));
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com:80", allowed));
assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com", allowed));
assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com:443", allowed));
assertTrue(checkOrigin("mydomain1.com", 123, "http://mydomain1.com:123", allowed));
assertTrue(checkOrigin("mydomain1.com", -1, "ws://mydomain1.com", allowed));
assertTrue(checkOrigin("mydomain1.com", 443, "wss://mydomain1.com", allowed));
allowed = Collections.singletonList("*");
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed));
allowed = Collections.singletonList("http://mydomain1.com");
assertTrue(checkOrigin("mydomain2.com", -1, "http://mydomain1.com", allowed));
}
@Test
public void isValidOriginFailure() {
List<String> allowed = Collections.emptyList();
assertFalse(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed));
assertFalse(checkOrigin("mydomain1.com", -1, "https://mydomain1.com", allowed));
assertFalse(checkOrigin("mydomain1.com", -1, "invalid-origin", allowed));
allowed = Collections.singletonList("http://mydomain1.com");
assertFalse(checkOrigin("mydomain2.com", -1, "http://mydomain3.com", allowed));
}
private boolean checkOrigin(String serverName, int port, String originHeader, List<String> allowed) {
MockHttpServletRequest servletRequest = new MockHttpServletRequest(); MockHttpServletRequest servletRequest = new MockHttpServletRequest();
ServerHttpRequest request = new ServletServerHttpRequest(servletRequest); ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
servletRequest.setServerName(serverName);
servletRequest.setServerName("mydomain1.com"); if (port != -1) {
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com"); servletRequest.setServerPort(port);
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins)); }
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
servletRequest.setServerName("mydomain1.com"); return WebUtils.isValidOrigin(request, allowed);
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:80");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(443);
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(443);
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com:443");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
servletRequest.setServerPort(123);
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:123");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
servletRequest.setServerName("invalid-origin");
request.getHeaders().set(HttpHeaders.ORIGIN, "invalid-origin");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("*");
servletRequest.setServerName("mydomain1.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("http://mydomain1.com");
servletRequest.setServerName("mydomain2.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
allowedOrigins = Arrays.asList("http://mydomain1.com");
servletRequest.setServerName("mydomain2.com");
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain3.com");
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
} }
} }

View File

@ -65,22 +65,18 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
} }
/** /**
* Configure allowed {@code Origin} header values. This check is mostly designed for * Configure allowed {@code Origin} header values. This check is mostly
* browser clients. There is nothing preventing other types of client to modify the * designed for browsers. There is nothing preventing other types of client
* {@code Origin} header value. * to modify the {@code Origin} header value.
* *
* <p>Each provided allowed origin must start by "http://", "https://" or be "*" * <p>Each provided allowed origin must have a scheme, and optionally a port
* (means that all origins are allowed). * (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
* string may also be "*" in which case all origins are allowed.
* *
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a> * @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
*/ */
public void setAllowedOrigins(Collection<String> allowedOrigins) { public void setAllowedOrigins(Collection<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null"); Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null");
for (String allowedOrigin : allowedOrigins) {
Assert.isTrue(allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") ||
allowedOrigin.startsWith("https://"), "Invalid allowed origin provided: \"" +
allowedOrigin + "\". It must start with \"http://\", \"https://\" or be \"*\"");
}
this.allowedOrigins.clear(); this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins); this.allowedOrigins.addAll(allowedOrigins);
} }
@ -93,6 +89,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
return Collections.unmodifiableList(this.allowedOrigins); return Collections.unmodifiableList(this.allowedOrigins);
} }
@Override @Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception { WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {

View File

@ -276,16 +276,18 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
} }
/** /**
* Configure allowed {@code Origin} header values. This check is mostly designed for * Configure allowed {@code Origin} header values. This check is mostly
* browser clients. There is nothing preventing other types of client to modify the * designed for browsers. There is nothing preventing other types of client
* {@code Origin} header value. * to modify the {@code Origin} header value.
* *
* <p>When SockJS is enabled and origins are restricted, transport types that do not * <p>When SockJS is enabled and origins are restricted, transport types
* allow to check request origin (JSONP and Iframe based transports) are disabled. * that do not allow to check request origin (JSONP and Iframe based
* As a consequence, IE 6 to 9 are not supported when origins are restricted. * transports) are disabled. As a consequence, IE 6 to 9 are not supported
* when origins are restricted.
* *
* <p>Each provided allowed origin must start by "http://", "https://" or be "*" * <p>Each provided allowed origin must have a scheme, and optionally a port
* (means that all origins are allowed). * (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
* string may also be "*" in which case all origins are allowed.
* *
* @since 4.1.2 * @since 4.1.2
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a> * @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
@ -293,14 +295,6 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
*/ */
public void setAllowedOrigins(List<String> allowedOrigins) { public void setAllowedOrigins(List<String> allowedOrigins) {
Assert.notNull(allowedOrigins, "Allowed origin List must not be null"); Assert.notNull(allowedOrigins, "Allowed origin List must not be null");
for (String allowedOrigin : allowedOrigins) {
Assert.isTrue(
allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") ||
allowedOrigin.startsWith("https://"),
"Invalid allowed origin provided: \"" +
allowedOrigin +
"\". It must start with \"http://\", \"https://\" or be \"*\"");
}
this.allowedOrigins.clear(); this.allowedOrigins.clear();
this.allowedOrigins.addAll(allowedOrigins); this.allowedOrigins.addAll(allowedOrigins);
} }
@ -451,7 +445,9 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response, protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException; WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) throws IOException { protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response,
HttpMethod... httpMethods) throws IOException {
String origin = request.getHeaders().getOrigin(); String origin = request.getHeaders().getOrigin();
if (origin == null) { if (origin == null) {
@ -514,7 +510,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
addNoCacheHeaders(response); addNoCacheHeaders(response);
if (checkOrigin(request, response)) { if (checkOrigin(request, response)) {
response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET)); response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET));
String content = String.format(INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled()); String content = String.format(INFO_CONTENT, random.nextInt(),
isSessionCookieNeeded(), isWebSocketEnabled());
response.getBody().write(content.getBytes()); response.getBody().write(content.getBytes());
} }

View File

@ -17,7 +17,9 @@
package org.springframework.web.socket.server.support; package org.springframework.web.socket.server.support;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet; import java.util.concurrent.ConcurrentSkipListSet;
@ -39,31 +41,17 @@ import org.springframework.web.socket.WebSocketHandler;
public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() { public void invalidInput() {
new OriginHandshakeInterceptor(null); new OriginHandshakeInterceptor(null);
} }
@Test(expected = IllegalArgumentException.class)
public void invalidAllowedOrigin() {
new OriginHandshakeInterceptor(Arrays.asList("domain.com"));
}
@Test
public void emtpyAllowedOriginList() {
new OriginHandshakeInterceptor(Arrays.asList());
}
@Test
public void validAllowedOrigins() {
new OriginHandshakeInterceptor(Arrays.asList("http://domain.com", "https://domain.com", "*"));
}
@Test @Test
public void originValueMatch() throws Exception { public void originValueMatch() throws Exception {
Map<String, Object> attributes = new HashMap<String, Object>(); Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com")); List<String> allowed = Collections.singletonList("http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@ -73,7 +61,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>(); Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com")); List<String> allowed = Collections.singletonList("http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@ -83,7 +72,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>(); Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); List<String> allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@ -93,7 +83,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
Map<String, Object> attributes = new HashMap<String, Object>(); Map<String, Object> attributes = new HashMap<String, Object>();
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain4.com"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain4.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); List<String> allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@ -117,7 +108,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(); OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
interceptor.setAllowedOrigins(Arrays.asList("*")); interceptor.setAllowedOrigins(Collections.singletonList("*"));
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@ -128,7 +119,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
this.servletRequest.setServerName("mydomain2.com"); this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList()); OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList());
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }
@ -139,7 +130,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain3.com"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain3.com");
this.servletRequest.setServerName("mydomain2.com"); this.servletRequest.setServerName("mydomain2.com");
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList()); OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList());
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes)); assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
} }

View File

@ -16,11 +16,14 @@
package org.springframework.web.socket.sockjs.transport.handler; package org.springframework.web.socket.sockjs.transport.handler;
import static org.junit.Assert.*;
import static org.mockito.BDDMockito.*;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import java.util.Map; import java.util.Map;
import org.hamcrest.Matchers;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mock; import org.mockito.Mock;
@ -41,9 +44,6 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
import org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig; import org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig;
import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession; import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession;
import static org.junit.Assert.*;
import static org.mockito.BDDMockito.*;
/** /**
* Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}. * Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}.
* *
@ -125,26 +125,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void nullAllowedOriginList() { public void invalidAllowedOrigins() {
this.service.setAllowedOrigins(null); this.service.setAllowedOrigins(null);
} }
@Test
public void emptyAllowedOriginList() {
this.service.setAllowedOrigins(Arrays.asList());
assertThat(this.service.getAllowedOrigins(), Matchers.empty());
}
@Test(expected = IllegalArgumentException.class)
public void invalidAllowedOrigin() {
this.service.setAllowedOrigins(Arrays.asList("domain.com"));
}
@Test
public void validAllowedOrigins() {
this.service.setAllowedOrigins(Arrays.asList("http://domain.com", "https://domain.com", "*"));
}
@Test @Test
public void customizedTransportHandlerList() { public void customizedTransportHandlerList() {
TransportHandlingSockJsService service = new TransportHandlingSockJsService( TransportHandlingSockJsService service = new TransportHandlingSockJsService(
@ -268,13 +252,13 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertEquals(404, this.servletResponse.getStatus()); assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse(); resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); jsonpService.setAllowedOrigins(Collections.singletonList("http://mydomain1.com"));
setRequest("GET", sockJsPrefix + sockJsPath); setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus()); assertEquals(404, this.servletResponse.getStatus());
resetRequestAndResponse(); resetRequestAndResponse();
jsonpService.setAllowedOrigins(Arrays.asList("*")); jsonpService.setAllowedOrigins(Collections.singletonList("*"));
setRequest("GET", sockJsPrefix + sockJsPath); setRequest("GET", sockJsPrefix + sockJsPath);
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus()); assertNotEquals(404, this.servletResponse.getStatus());
@ -289,8 +273,9 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
assertNotEquals(403, this.servletResponse.getStatus()); assertNotEquals(403, this.servletResponse.getStatus());
resetRequestAndResponse(); resetRequestAndResponse();
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com")); List<String> allowed = Collections.singletonList("http://mydomain1.com");
wsService.setHandshakeInterceptors(Arrays.asList(interceptor)); OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
wsService.setHandshakeInterceptors(Collections.singletonList(interceptor));
setRequest("GET", sockJsPrefix + sockJsPath); setRequest("GET", sockJsPrefix + sockJsPath);
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com"); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
@ -313,14 +298,14 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
resetRequestAndResponse(); resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath); setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); this.service.setAllowedOrigins(Collections.singletonList("http://mydomain1.com"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertEquals(404, this.servletResponse.getStatus()); assertEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options")); assertNull(this.servletResponse.getHeader("X-Frame-Options"));
resetRequestAndResponse(); resetRequestAndResponse();
setRequest("GET", sockJsPrefix + sockJsPath); setRequest("GET", sockJsPrefix + sockJsPath);
this.service.setAllowedOrigins(Arrays.asList("*")); this.service.setAllowedOrigins(Collections.singletonList("*"));
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
assertNotEquals(404, this.servletResponse.getStatus()); assertNotEquals(404, this.servletResponse.getStatus());
assertNull(this.servletResponse.getHeader("X-Frame-Options")); assertNull(this.servletResponse.getHeader("X-Frame-Options"));