Allow "ws" and "wss" for isValidCorsOrigin checks
Issue: SPR-12956
This commit is contained in:
parent
222f6998e4
commit
68ecb92d1f
|
@ -317,6 +317,29 @@ public class UriComponentsBuilder implements Cloneable {
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* Create an instance by parsing the "origin" header of an HTTP request.
|
||||
*/
|
||||
public static UriComponentsBuilder fromOriginHeader(String origin) {
|
||||
UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
|
||||
if (StringUtils.hasText(origin)) {
|
||||
int schemaIdx = origin.indexOf("://");
|
||||
String schema = (schemaIdx != -1 ? origin.substring(0, schemaIdx) : "http");
|
||||
builder.scheme(schema);
|
||||
String hostString = (schemaIdx != -1 ? origin.substring(schemaIdx + 3) : origin);
|
||||
if (hostString.contains(":")) {
|
||||
String[] hostAndPort = StringUtils.split(hostString, ":");
|
||||
builder.host(hostAndPort[0]);
|
||||
builder.port(Integer.parseInt(hostAndPort[1]));
|
||||
}
|
||||
else {
|
||||
builder.host(hostString);
|
||||
}
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
|
||||
// build methods
|
||||
|
||||
/**
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.Enumeration;
|
|||
import java.util.Map;
|
||||
import java.util.StringTokenizer;
|
||||
import java.util.TreeMap;
|
||||
|
||||
import javax.servlet.ServletContext;
|
||||
import javax.servlet.ServletRequest;
|
||||
import javax.servlet.ServletRequestWrapper;
|
||||
|
@ -38,6 +39,7 @@ import org.apache.commons.logging.LogFactory;
|
|||
|
||||
import org.springframework.http.HttpRequest;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
@ -790,21 +792,10 @@ public abstract class WebUtils {
|
|||
if (origin == null || allowedOrigins.contains("*")) {
|
||||
return true;
|
||||
}
|
||||
else if (allowedOrigins.isEmpty()) {
|
||||
UriComponents originComponents;
|
||||
try {
|
||||
originComponents = UriComponentsBuilder.fromHttpUrl(origin).build();
|
||||
}
|
||||
catch (IllegalArgumentException ex) {
|
||||
if (logger.isWarnEnabled()) {
|
||||
logger.warn("Failed to parse Origin header value [" + origin + "]");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
UriComponents requestComponents = UriComponentsBuilder.fromHttpRequest(request).build();
|
||||
int originPort = getPort(originComponents);
|
||||
int requestPort = getPort(requestComponents);
|
||||
return (originComponents.getHost().equals(requestComponents.getHost()) && originPort == requestPort);
|
||||
else if (CollectionUtils.isEmpty(allowedOrigins)) {
|
||||
UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build();
|
||||
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
|
||||
return (actualUrl.getHost().equals(originUrl.getHost()) && getPort(actualUrl) == getPort(originUrl));
|
||||
}
|
||||
else {
|
||||
return allowedOrigins.contains(origin);
|
||||
|
@ -814,10 +805,10 @@ public abstract class WebUtils {
|
|||
private static int getPort(UriComponents component) {
|
||||
int port = component.getPort();
|
||||
if (port == -1) {
|
||||
if ("http".equals(component.getScheme())) {
|
||||
if ("http".equals(component.getScheme()) || "ws".equals(component.getScheme())) {
|
||||
port = 80;
|
||||
}
|
||||
else if ("https".equals(component.getScheme())) {
|
||||
else if ("https".equals(component.getScheme()) || "wss".equals(component.getScheme())) {
|
||||
port = 443;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
|
||||
package org.springframework.web.util;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
@ -106,60 +106,45 @@ public class WebUtilsTests {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void isValidOrigin() {
|
||||
List<String> allowedOrigins = new ArrayList<>();
|
||||
public void isValidOriginSuccess() {
|
||||
|
||||
List<String> allowed = Collections.emptyList();
|
||||
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com", allowed));
|
||||
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain1.com:80", allowed));
|
||||
assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com", allowed));
|
||||
assertTrue(checkOrigin("mydomain1.com", 443, "https://mydomain1.com:443", allowed));
|
||||
assertTrue(checkOrigin("mydomain1.com", 123, "http://mydomain1.com:123", allowed));
|
||||
assertTrue(checkOrigin("mydomain1.com", -1, "ws://mydomain1.com", allowed));
|
||||
assertTrue(checkOrigin("mydomain1.com", 443, "wss://mydomain1.com", allowed));
|
||||
|
||||
allowed = Collections.singletonList("*");
|
||||
assertTrue(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed));
|
||||
|
||||
allowed = Collections.singletonList("http://mydomain1.com");
|
||||
assertTrue(checkOrigin("mydomain2.com", -1, "http://mydomain1.com", allowed));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void isValidOriginFailure() {
|
||||
|
||||
List<String> allowed = Collections.emptyList();
|
||||
assertFalse(checkOrigin("mydomain1.com", -1, "http://mydomain2.com", allowed));
|
||||
assertFalse(checkOrigin("mydomain1.com", -1, "https://mydomain1.com", allowed));
|
||||
assertFalse(checkOrigin("mydomain1.com", -1, "invalid-origin", allowed));
|
||||
|
||||
allowed = Collections.singletonList("http://mydomain1.com");
|
||||
assertFalse(checkOrigin("mydomain2.com", -1, "http://mydomain3.com", allowed));
|
||||
}
|
||||
|
||||
private boolean checkOrigin(String serverName, int port, String originHeader, List<String> allowed) {
|
||||
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
|
||||
ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
|
||||
|
||||
servletRequest.setServerName("mydomain1.com");
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
|
||||
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
|
||||
servletRequest.setServerName("mydomain1.com");
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:80");
|
||||
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
|
||||
servletRequest.setServerName("mydomain1.com");
|
||||
servletRequest.setServerPort(443);
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
|
||||
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
|
||||
servletRequest.setServerName("mydomain1.com");
|
||||
servletRequest.setServerPort(443);
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com:443");
|
||||
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
|
||||
servletRequest.setServerName("mydomain1.com");
|
||||
servletRequest.setServerPort(123);
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com:123");
|
||||
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
|
||||
servletRequest.setServerName("mydomain1.com");
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
|
||||
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
|
||||
servletRequest.setServerName("mydomain1.com");
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "https://mydomain1.com");
|
||||
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
|
||||
servletRequest.setServerName("invalid-origin");
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "invalid-origin");
|
||||
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
|
||||
allowedOrigins = Arrays.asList("*");
|
||||
servletRequest.setServerName("mydomain1.com");
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain2.com");
|
||||
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
|
||||
allowedOrigins = Arrays.asList("http://mydomain1.com");
|
||||
servletRequest.setServerName("mydomain2.com");
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain1.com");
|
||||
assertTrue(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
|
||||
allowedOrigins = Arrays.asList("http://mydomain1.com");
|
||||
servletRequest.setServerName("mydomain2.com");
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, "http://mydomain3.com");
|
||||
assertFalse(WebUtils.isValidOrigin(request, allowedOrigins));
|
||||
servletRequest.setServerName(serverName);
|
||||
if (port != -1) {
|
||||
servletRequest.setServerPort(port);
|
||||
}
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
|
||||
return WebUtils.isValidOrigin(request, allowed);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -65,22 +65,18 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
|
|||
}
|
||||
|
||||
/**
|
||||
* Configure allowed {@code Origin} header values. This check is mostly designed for
|
||||
* browser clients. There is nothing preventing other types of client to modify the
|
||||
* {@code Origin} header value.
|
||||
* Configure allowed {@code Origin} header values. This check is mostly
|
||||
* designed for browsers. There is nothing preventing other types of client
|
||||
* to modify the {@code Origin} header value.
|
||||
*
|
||||
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
|
||||
* (means that all origins are allowed).
|
||||
* <p>Each provided allowed origin must have a scheme, and optionally a port
|
||||
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
|
||||
* string may also be "*" in which case all origins are allowed.
|
||||
*
|
||||
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
|
||||
*/
|
||||
public void setAllowedOrigins(Collection<String> allowedOrigins) {
|
||||
Assert.notNull(allowedOrigins, "Allowed origin Collection must not be null");
|
||||
for (String allowedOrigin : allowedOrigins) {
|
||||
Assert.isTrue(allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") ||
|
||||
allowedOrigin.startsWith("https://"), "Invalid allowed origin provided: \"" +
|
||||
allowedOrigin + "\". It must start with \"http://\", \"https://\" or be \"*\"");
|
||||
}
|
||||
this.allowedOrigins.clear();
|
||||
this.allowedOrigins.addAll(allowedOrigins);
|
||||
}
|
||||
|
@ -93,6 +89,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
|
|||
return Collections.unmodifiableList(this.allowedOrigins);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
|
||||
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
|
||||
|
|
|
@ -276,16 +276,18 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
|
|||
}
|
||||
|
||||
/**
|
||||
* Configure allowed {@code Origin} header values. This check is mostly designed for
|
||||
* browser clients. There is nothing preventing other types of client to modify the
|
||||
* {@code Origin} header value.
|
||||
* Configure allowed {@code Origin} header values. This check is mostly
|
||||
* designed for browsers. There is nothing preventing other types of client
|
||||
* to modify the {@code Origin} header value.
|
||||
*
|
||||
* <p>When SockJS is enabled and origins are restricted, transport types that do not
|
||||
* allow to check request origin (JSONP and Iframe based transports) are disabled.
|
||||
* As a consequence, IE 6 to 9 are not supported when origins are restricted.
|
||||
* <p>When SockJS is enabled and origins are restricted, transport types
|
||||
* that do not allow to check request origin (JSONP and Iframe based
|
||||
* transports) are disabled. As a consequence, IE 6 to 9 are not supported
|
||||
* when origins are restricted.
|
||||
*
|
||||
* <p>Each provided allowed origin must start by "http://", "https://" or be "*"
|
||||
* (means that all origins are allowed).
|
||||
* <p>Each provided allowed origin must have a scheme, and optionally a port
|
||||
* (e.g. "http://example.org", "http://example.org:9090"). An allowed origin
|
||||
* string may also be "*" in which case all origins are allowed.
|
||||
*
|
||||
* @since 4.1.2
|
||||
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
|
||||
|
@ -293,14 +295,6 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
|
|||
*/
|
||||
public void setAllowedOrigins(List<String> allowedOrigins) {
|
||||
Assert.notNull(allowedOrigins, "Allowed origin List must not be null");
|
||||
for (String allowedOrigin : allowedOrigins) {
|
||||
Assert.isTrue(
|
||||
allowedOrigin.equals("*") || allowedOrigin.startsWith("http://") ||
|
||||
allowedOrigin.startsWith("https://"),
|
||||
"Invalid allowed origin provided: \"" +
|
||||
allowedOrigin +
|
||||
"\". It must start with \"http://\", \"https://\" or be \"*\"");
|
||||
}
|
||||
this.allowedOrigins.clear();
|
||||
this.allowedOrigins.addAll(allowedOrigins);
|
||||
}
|
||||
|
@ -451,7 +445,9 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
|
|||
protected abstract void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
|
||||
WebSocketHandler webSocketHandler, String sessionId, String transport) throws SockJsException;
|
||||
|
||||
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) throws IOException {
|
||||
protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response,
|
||||
HttpMethod... httpMethods) throws IOException {
|
||||
|
||||
String origin = request.getHeaders().getOrigin();
|
||||
|
||||
if (origin == null) {
|
||||
|
@ -514,7 +510,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
|
|||
addNoCacheHeaders(response);
|
||||
if (checkOrigin(request, response)) {
|
||||
response.getHeaders().setContentType(new MediaType("application", "json", UTF8_CHARSET));
|
||||
String content = String.format(INFO_CONTENT, random.nextInt(), isSessionCookieNeeded(), isWebSocketEnabled());
|
||||
String content = String.format(INFO_CONTENT, random.nextInt(),
|
||||
isSessionCookieNeeded(), isWebSocketEnabled());
|
||||
response.getBody().write(content.getBytes());
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,9 @@
|
|||
package org.springframework.web.socket.server.support;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentSkipListSet;
|
||||
|
@ -39,31 +41,17 @@ import org.springframework.web.socket.WebSocketHandler;
|
|||
public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void nullAllowedOriginList() {
|
||||
public void invalidInput() {
|
||||
new OriginHandshakeInterceptor(null);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void invalidAllowedOrigin() {
|
||||
new OriginHandshakeInterceptor(Arrays.asList("domain.com"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void emtpyAllowedOriginList() {
|
||||
new OriginHandshakeInterceptor(Arrays.asList());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void validAllowedOrigins() {
|
||||
new OriginHandshakeInterceptor(Arrays.asList("http://domain.com", "https://domain.com", "*"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void originValueMatch() throws Exception {
|
||||
Map<String, Object> attributes = new HashMap<String, Object>();
|
||||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
|
||||
List<String> allowed = Collections.singletonList("http://mydomain1.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
|
||||
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
|
||||
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
|
||||
}
|
||||
|
@ -73,7 +61,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
|
|||
Map<String, Object> attributes = new HashMap<String, Object>();
|
||||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain2.com"));
|
||||
List<String> allowed = Collections.singletonList("http://mydomain2.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
|
||||
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
|
||||
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
|
||||
}
|
||||
|
@ -83,7 +72,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
|
|||
Map<String, Object> attributes = new HashMap<String, Object>();
|
||||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
|
||||
List<String> allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
|
||||
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
|
||||
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
|
||||
}
|
||||
|
@ -93,7 +83,8 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
|
|||
Map<String, Object> attributes = new HashMap<String, Object>();
|
||||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain4.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com"));
|
||||
List<String> allowed = Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
|
||||
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
|
||||
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
|
||||
}
|
||||
|
@ -117,7 +108,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
|
|||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor();
|
||||
interceptor.setAllowedOrigins(Arrays.asList("*"));
|
||||
interceptor.setAllowedOrigins(Collections.singletonList("*"));
|
||||
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
|
||||
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
|
||||
}
|
||||
|
@ -128,7 +119,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
|
|||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com");
|
||||
this.servletRequest.setServerName("mydomain2.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList());
|
||||
assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes));
|
||||
assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
|
||||
}
|
||||
|
@ -139,7 +130,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests {
|
|||
WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class);
|
||||
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain3.com");
|
||||
this.servletRequest.setServerName("mydomain2.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList());
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Collections.emptyList());
|
||||
assertFalse(interceptor.beforeHandshake(request, response, wsHandler, attributes));
|
||||
assertEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value());
|
||||
}
|
||||
|
|
|
@ -16,11 +16,14 @@
|
|||
|
||||
package org.springframework.web.socket.sockjs.transport.handler;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.mockito.BDDMockito.*;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.hamcrest.Matchers;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.mockito.Mock;
|
||||
|
@ -41,9 +44,6 @@ import org.springframework.web.socket.sockjs.transport.TransportType;
|
|||
import org.springframework.web.socket.sockjs.transport.session.StubSockJsServiceConfig;
|
||||
import org.springframework.web.socket.sockjs.transport.session.TestSockJsSession;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.mockito.BDDMockito.*;
|
||||
|
||||
/**
|
||||
* Test fixture for {@link org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService}.
|
||||
*
|
||||
|
@ -125,26 +125,10 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
|
|||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void nullAllowedOriginList() {
|
||||
public void invalidAllowedOrigins() {
|
||||
this.service.setAllowedOrigins(null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void emptyAllowedOriginList() {
|
||||
this.service.setAllowedOrigins(Arrays.asList());
|
||||
assertThat(this.service.getAllowedOrigins(), Matchers.empty());
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void invalidAllowedOrigin() {
|
||||
this.service.setAllowedOrigins(Arrays.asList("domain.com"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void validAllowedOrigins() {
|
||||
this.service.setAllowedOrigins(Arrays.asList("http://domain.com", "https://domain.com", "*"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void customizedTransportHandlerList() {
|
||||
TransportHandlingSockJsService service = new TransportHandlingSockJsService(
|
||||
|
@ -268,13 +252,13 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
|
|||
assertEquals(404, this.servletResponse.getStatus());
|
||||
|
||||
resetRequestAndResponse();
|
||||
jsonpService.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
|
||||
jsonpService.setAllowedOrigins(Collections.singletonList("http://mydomain1.com"));
|
||||
setRequest("GET", sockJsPrefix + sockJsPath);
|
||||
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
|
||||
assertEquals(404, this.servletResponse.getStatus());
|
||||
|
||||
resetRequestAndResponse();
|
||||
jsonpService.setAllowedOrigins(Arrays.asList("*"));
|
||||
jsonpService.setAllowedOrigins(Collections.singletonList("*"));
|
||||
setRequest("GET", sockJsPrefix + sockJsPath);
|
||||
jsonpService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
|
||||
assertNotEquals(404, this.servletResponse.getStatus());
|
||||
|
@ -289,8 +273,9 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
|
|||
assertNotEquals(403, this.servletResponse.getStatus());
|
||||
|
||||
resetRequestAndResponse();
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com"));
|
||||
wsService.setHandshakeInterceptors(Arrays.asList(interceptor));
|
||||
List<String> allowed = Collections.singletonList("http://mydomain1.com");
|
||||
OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(allowed);
|
||||
wsService.setHandshakeInterceptors(Collections.singletonList(interceptor));
|
||||
setRequest("GET", sockJsPrefix + sockJsPath);
|
||||
this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain1.com");
|
||||
wsService.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
|
||||
|
@ -313,14 +298,14 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests {
|
|||
|
||||
resetRequestAndResponse();
|
||||
setRequest("GET", sockJsPrefix + sockJsPath);
|
||||
this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com"));
|
||||
this.service.setAllowedOrigins(Collections.singletonList("http://mydomain1.com"));
|
||||
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
|
||||
assertEquals(404, this.servletResponse.getStatus());
|
||||
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
|
||||
|
||||
resetRequestAndResponse();
|
||||
setRequest("GET", sockJsPrefix + sockJsPath);
|
||||
this.service.setAllowedOrigins(Arrays.asList("*"));
|
||||
this.service.setAllowedOrigins(Collections.singletonList("*"));
|
||||
this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler);
|
||||
assertNotEquals(404, this.servletResponse.getStatus());
|
||||
assertNull(this.servletResponse.getHeader("X-Frame-Options"));
|
||||
|
|
Loading…
Reference in New Issue