From 7d02ba0694b7b20f3af0ecb739851a008aff5b95 Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Tue, 29 Oct 2019 15:01:18 +0100 Subject: [PATCH] Add missing CORS headers defined in SockJS CORS config Prior to this commit and following changes done in d27b5d0, the CORS response headers would not be added for SockJS-related requests, even though a CORS configuration had been applied to SockJS/WebSocket. This was due to a missing case in our implementation: calling `AbstractHandlerMapping#getHandlerInternal` can return a Handler directly, but also a `HandlerExecutionChain` in some cases, as explained in the Javadoc. This commit ensures that, when checking for existing CORS configuration, the `AbstractHandlerMapping` class also considers the `HandlerExecutionChain` case and unwraps it to get the CORS configuration from the actual Handler. Fixes gh-23843 --- .../handler/AbstractHandlerMapping.java | 3 ++ .../CorsAbstractHandlerMappingTests.java | 37 ++++++++++++++----- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java index 42566bdaec0..28f059926f9 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java @@ -485,6 +485,9 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport * @since 5.2 */ protected boolean hasCorsConfigurationSource(Object handler) { + if (handler instanceof HandlerExecutionChain) { + handler = ((HandlerExecutionChain) handler).getHandler(); + } return (handler instanceof CorsConfigurationSource || this.corsConfigurationSource != null); } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java index 9cb12f794b2..b8213381b02 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java @@ -48,7 +48,7 @@ import static org.mockito.Mockito.mock; * @author Sebastien Deleuze * @author Rossen Stoyanchev */ -public class CorsAbstractHandlerMappingTests { +class CorsAbstractHandlerMappingTests { private MockHttpServletRequest request; @@ -56,7 +56,7 @@ public class CorsAbstractHandlerMappingTests { @BeforeEach - public void setup() { + void setup() { StaticWebApplicationContext context = new StaticWebApplicationContext(); this.handlerMapping = new TestHandlerMapping(); this.handlerMapping.setInterceptors(mock(HandlerInterceptor.class)); @@ -66,7 +66,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void actualRequestWithoutCorsConfigurationProvider() throws Exception { + void actualRequestWithoutCorsConfigurationProvider() throws Exception { this.request.setMethod(RequestMethod.GET.name()); this.request.setRequestURI("/foo"); this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); @@ -79,7 +79,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void preflightRequestWithoutCorsConfigurationProvider() throws Exception { + void preflightRequestWithoutCorsConfigurationProvider() throws Exception { this.request.setMethod(RequestMethod.OPTIONS.name()); this.request.setRequestURI("/foo"); this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); @@ -92,7 +92,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void actualRequestWithCorsConfigurationProvider() throws Exception { + void actualRequestWithCorsConfigurationProvider() throws Exception { this.request.setMethod(RequestMethod.GET.name()); this.request.setRequestURI("/cors"); this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); @@ -105,8 +105,22 @@ public class CorsAbstractHandlerMappingTests { assertThat(getRequiredCorsConfiguration(chain, false).getAllowedOrigins()).isEqualTo(Collections.singletonList("*")); } + @Test // see gh-23843 + void actualRequestWithCorsConfigurationProviderForHandlerChain() throws Exception { + this.request.setMethod(RequestMethod.GET.name()); + this.request.setRequestURI("/chain"); + this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); + this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); + HandlerExecutionChain chain = handlerMapping.getHandler(this.request); + + assertThat(chain).isNotNull(); + boolean condition = chain.getHandler() instanceof CorsAwareHandler; + assertThat(condition).isTrue(); + assertThat(getRequiredCorsConfiguration(chain, false).getAllowedOrigins()).isEqualTo(Collections.singletonList("*")); + } + @Test - public void preflightRequestWithCorsConfigurationProvider() throws Exception { + void preflightRequestWithCorsConfigurationProvider() throws Exception { this.request.setMethod(RequestMethod.OPTIONS.name()); this.request.setRequestURI("/cors"); this.request.addHeader(HttpHeaders.ORIGIN, "https://domain2.com"); @@ -120,7 +134,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void actualRequestWithMappedCorsConfiguration() throws Exception { + void actualRequestWithMappedCorsConfiguration() throws Exception { CorsConfiguration config = new CorsConfiguration(); config.addAllowedOrigin("*"); this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/foo", config)); @@ -137,7 +151,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void preflightRequestWithMappedCorsConfiguration() throws Exception { + void preflightRequestWithMappedCorsConfiguration() throws Exception { CorsConfiguration config = new CorsConfiguration(); config.addAllowedOrigin("*"); this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/foo", config)); @@ -154,7 +168,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void actualRequestWithCorsConfigurationSource() throws Exception { + void actualRequestWithCorsConfigurationSource() throws Exception { this.handlerMapping.setCorsConfigurationSource(new CustomCorsConfigurationSource()); this.request.setMethod(RequestMethod.GET.name()); this.request.setRequestURI("/foo"); @@ -172,7 +186,7 @@ public class CorsAbstractHandlerMappingTests { } @Test - public void preflightRequestWithCorsConfigurationSource() throws Exception { + void preflightRequestWithCorsConfigurationSource() throws Exception { this.handlerMapping.setCorsConfigurationSource(new CustomCorsConfigurationSource()); this.request.setMethod(RequestMethod.OPTIONS.name()); this.request.setRequestURI("/foo"); @@ -217,6 +231,9 @@ public class CorsAbstractHandlerMappingTests { if (request.getRequestURI().equals("/cors")) { return new CorsAwareHandler(); } + else if (request.getRequestURI().equals("/chain")) { + return new HandlerExecutionChain(new CorsAwareHandler()); + } return new SimpleHandler(); } }