diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java index 60b3c34f78..0008222f14 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java @@ -93,7 +93,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor { @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) throws Exception { - if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) { + if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) { response.setStatusCode(HttpStatus.FORBIDDEN); if (logger.isDebugEnabled()) { logger.debug("Handshake request rejected, Origin header value " diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index 62c85f014b..bb3bc9e5ac 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -448,13 +448,12 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) throws IOException { - String origin = request.getHeaders().getOrigin(); - - if (origin == null) { + if (WebUtils.isSameOrigin(request)) { return true; } if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) { + String origin = request.getHeaders().getOrigin(); logger.debug("Request rejected, Origin header value " + origin + " not allowed"); response.setStatusCode(HttpStatus.FORBIDDEN); return false; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java index ec87b68154..03177e288f 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptorTests.java @@ -114,7 +114,7 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { } @Test - public void sameOriginMatch() throws Exception { + public void sameOriginMatchWithEmptyAllowedOrigins() throws Exception { Map attributes = new HashMap(); WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com"); @@ -124,6 +124,17 @@ public class OriginHandshakeInterceptorTests extends AbstractHttpRequestTests { assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); } + @Test + public void sameOriginMatchWithAllowedOrigins() throws Exception { + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com"); + this.servletRequest.setServerName("mydomain2.com"); + OriginHandshakeInterceptor interceptor = new OriginHandshakeInterceptor(Arrays.asList("http://mydomain1.com")); + assertTrue(interceptor.beforeHandshake(request, response, wsHandler, attributes)); + assertNotEquals(servletResponse.getStatus(), HttpStatus.FORBIDDEN.value()); + } + @Test public void sameOriginNoMatch() throws Exception { Map attributes = new HashMap(); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java index 0209f62366..14813c34c5 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java @@ -121,13 +121,17 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":true}", body.substring(body.indexOf(','))); this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); - resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); this.service.setAllowedOrigins(Arrays.asList("*")); resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); + + this.servletRequest.setServerName("mydomain3.com"); + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.FORBIDDEN); } @Test // SPR-11443 @@ -176,7 +180,8 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { assertNotNull(this.service.getCorsConfiguration(this.servletRequest)); this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); - resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN); + resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); + assertNotNull(this.service.getCorsConfiguration(this.servletRequest)); this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com", "http://mydomain2.com", "http://mydomain3.com")); resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); @@ -185,6 +190,10 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { this.service.setAllowedOrigins(Arrays.asList("*")); resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); assertNotNull(this.service.getCorsConfiguration(this.servletRequest)); + + this.servletRequest.setServerName("mydomain3.com"); + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN); } @Test // SPR-12283 diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java index 74abf35539..6b07be8241 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/transport/handler/DefaultSockJsServiceTests.java @@ -174,6 +174,18 @@ public class DefaultSockJsServiceTests extends AbstractHttpRequestTests { assertEquals(403, this.servletResponse.getStatus()); } + @Test // SPR-13464 + public void handleTransportRequestXhrSameOrigin() throws Exception { + String sockJsPath = sessionUrlPrefix + "xhr"; + setRequest("POST", sockJsPrefix + sockJsPath); + this.service.setAllowedOrigins(Arrays.asList("http://mydomain1.com")); + this.servletRequest.addHeader(HttpHeaders.ORIGIN, "http://mydomain2.com"); + this.servletRequest.setServerName("mydomain2.com"); + this.service.handleRequest(this.request, this.response, sockJsPath, this.wsHandler); + + assertEquals(200, this.servletResponse.getStatus()); + } + @Test public void handleTransportRequestXhrOptions() throws Exception { String sockJsPath = sessionUrlPrefix + "xhr";