diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index 3cd84070472..66d2522acd6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -25,7 +25,6 @@ import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; import java.util.Random; -import java.util.Set; import java.util.concurrent.TimeUnit; import javax.servlet.http.HttpServletRequest; @@ -97,9 +96,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig private boolean suppressCors = false; - protected final Set allowedOrigins = new LinkedHashSet<>(); - - protected final Set allowedOriginPatterns = new LinkedHashSet<>(); + protected final CorsConfiguration corsConfiguration; private final SockJsRequestHandler infoHandler = new InfoHandler(); @@ -109,6 +106,18 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig public AbstractSockJsService(TaskScheduler scheduler) { Assert.notNull(scheduler, "TaskScheduler must not be null"); this.taskScheduler = scheduler; + this.corsConfiguration = initCorsConfiguration(); + } + + private static CorsConfiguration initCorsConfiguration() { + CorsConfiguration config = new CorsConfiguration(); + config.addAllowedMethod("*"); + config.setAllowedOrigins(Collections.emptyList()); + config.setAllowedOriginPatterns(Collections.emptyList()); + config.setAllowCredentials(true); + config.setMaxAge(ONE_YEAR); + config.addAllowedHeader("*"); + return config; } @@ -317,10 +326,18 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig */ public void setAllowedOrigins(Collection allowedOrigins) { Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null"); - this.allowedOrigins.clear(); - this.allowedOrigins.addAll(allowedOrigins); + this.corsConfiguration.setAllowedOrigins(new ArrayList<>(allowedOrigins)); } + /** + * Return configure allowed {@code Origin} header values. + * @since 4.1.2 + * @see #setAllowedOrigins + */ + @SuppressWarnings("ConstantConditions") + public Collection getAllowedOrigins() { + return this.corsConfiguration.getAllowedOrigins(); + } /** * A variant of {@link #setAllowedOrigins(Collection)} that accepts flexible * domain patterns, e.g. {@code "https://*.domain1.com"}. Furthermore it @@ -331,26 +348,17 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig */ public void setAllowedOriginPatterns(Collection allowedOriginPatterns) { Assert.notNull(allowedOriginPatterns, "Allowed origin patterns Collection must not be null"); - this.allowedOriginPatterns.clear(); - this.allowedOriginPatterns.addAll(allowedOriginPatterns); + this.corsConfiguration.setAllowedOriginPatterns(new ArrayList<>(allowedOriginPatterns)); } /** - * Return configure allowed {@code Origin} header values. - * @since 4.1.2 - * @see #setAllowedOrigins - */ - public Collection getAllowedOrigins() { - return Collections.unmodifiableSet(this.allowedOrigins); - } - - /** - * Return configure allowed {@code Origin} pattern header values. + * Return {@link #setAllowedOriginPatterns(Collection) configured} origin patterns. * @since 5.3.2 * @see #setAllowedOriginPatterns */ + @SuppressWarnings("ConstantConditions") public Collection getAllowedOriginPatterns() { - return Collections.unmodifiableSet(this.allowedOriginPatterns); + return this.corsConfiguration.getAllowedOriginPatterns(); } @@ -396,7 +404,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig } else if (sockJsPath.matches("/iframe[0-9-.a-z_]*.html")) { - if (!this.allowedOrigins.isEmpty() && !this.allowedOrigins.contains("*")) { + if (!getAllowedOrigins().isEmpty() && !getAllowedOrigins().contains("*") || + !getAllowedOriginPatterns().isEmpty()) { if (requestInfo != null) { logger.debug("Iframe support is disabled when an origin check is required. " + "Ignoring transport request: " + requestInfo); @@ -404,7 +413,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig response.setStatusCode(HttpStatus.NOT_FOUND); return; } - if (this.allowedOrigins.isEmpty()) { + if (getAllowedOrigins().isEmpty()) { response.getHeaders().add(XFRAME_OPTIONS_HEADER, "SAMEORIGIN"); } if (requestInfo != null) { @@ -506,7 +515,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig return true; } - if (!WebUtils.isValidOrigin(request, this.allowedOrigins)) { + if (this.corsConfiguration.checkOrigin(request.getHeaders().getOrigin()) == null) { if (logger.isWarnEnabled()) { logger.warn("Origin header value '" + request.getHeaders().getOrigin() + "' not allowed."); } @@ -521,14 +530,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig @Nullable public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { if (!this.suppressCors && (request.getHeader(HttpHeaders.ORIGIN) != null)) { - CorsConfiguration config = new CorsConfiguration(); - config.setAllowedOrigins(new ArrayList<>(this.allowedOrigins)); - config.setAllowedOriginPatterns(new ArrayList<>(this.allowedOriginPatterns)); - config.addAllowedMethod("*"); - config.setAllowCredentials(true); - config.setMaxAge(ONE_YEAR); - config.addAllowedHeader("*"); - return config; + return this.corsConfiguration; } return null; } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index 272700b49a4..8968980ef6a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -344,7 +344,8 @@ public class TransportHandlingSockJsService extends AbstractSockJsService implem return false; } - if (!this.allowedOrigins.contains("*")) { + if (!getAllowedOrigins().isEmpty() && !getAllowedOrigins().contains("*") || + !getAllowedOriginPatterns().isEmpty()) { TransportType transportType = TransportType.fromValue(transport); if (transportType == null || !transportType.supportsOrigin()) { if (logger.isWarnEnabled()) { diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java index f34c2f8e920..16106bdccc8 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java @@ -215,7 +215,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { @Test // SPR-12283 public void handleInfoOptionsWithOriginAndCorsHeadersDisabled() { this.servletRequest.addHeader(HttpHeaders.ORIGIN, "https://mydomain2.example"); - this.service.setAllowedOrigins(Collections.singletonList("*")); + this.service.setAllowedOriginPatterns(Collections.singletonList("*")); this.service.setSuppressCors(true); this.servletRequest.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Last-Modified"); @@ -223,10 +223,12 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { assertThat(this.service.getCorsConfiguration(this.servletRequest)).isNull(); this.service.setAllowedOrigins(Collections.singletonList("https://mydomain1.example")); + this.service.setAllowedOriginPatterns(Collections.emptyList()); resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.FORBIDDEN); assertThat(this.service.getCorsConfiguration(this.servletRequest)).isNull(); this.service.setAllowedOrigins(Arrays.asList("https://mydomain1.example", "https://mydomain2.example", "http://mydomain3.example")); + this.service.setAllowedOriginPatterns(Collections.emptyList()); resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); assertThat(this.service.getCorsConfiguration(this.servletRequest)).isNull(); }