From d27b5d0ab6e8b91a77e272ad57ae83c7d81d810b Mon Sep 17 00:00:00 2001 From: Sebastien Deleuze Date: Mon, 1 Apr 2019 14:36:38 +0200 Subject: [PATCH] Improve CORS handling This commit improves CORS support by: - Using CORS processing only for CORS-enabled endpoints - Skipping CORS processing for same-origin requests - Adding Vary headers for non-CORS requests It introduces an AbstractHandlerMapping#hasCorsConfigurationSource method in order to be able to check CORS endpoints efficiently. Closes gh-22273 Closes gh-22496 --- .../springframework/web/cors/CorsUtils.java | 38 ++++++++++++-- .../web/cors/DefaultCorsProcessor.java | 38 ++++---------- .../web/cors/reactive/CorsUtils.java | 15 ++++-- .../web/cors/reactive/CorsWebFilter.java | 14 ++--- .../cors/reactive/DefaultCorsProcessor.java | 23 +++------ .../web/filter/CorsFilter.java | 15 ++---- .../web/cors/CorsUtilsTests.java | 7 +-- .../web/cors/DefaultCorsProcessorTests.java | 24 ++++++++- .../web/cors/reactive/CorsUtilsTests.java | 7 +-- .../web/cors/reactive/CorsWebFilterTests.java | 51 ++++++++++++++++--- .../reactive/DefaultCorsProcessorTests.java | 31 +++++++++++ .../web/filter/CorsFilterTests.java | 32 +++++++++++- .../handler/AbstractHandlerMapping.java | 35 +++++++++---- .../method/AbstractHandlerMethodMapping.java | 8 +++ .../handler/CorsUrlHandlerMappingTests.java | 6 +-- .../handler/AbstractHandlerMapping.java | 34 +++++++++---- .../handler/AbstractHandlerMethodMapping.java | 8 +++ ...MvcConfigurationSupportExtensionTests.java | 6 +-- .../CorsAbstractHandlerMappingTests.java | 5 +- .../sockjs/support/AbstractSockJsService.java | 4 +- 20 files changed, 278 insertions(+), 123 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java index 2e31588101..d24594cef0 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,10 @@ import javax.servlet.http.HttpServletRequest; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.ObjectUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; /** * Utility class for CORS request handling based on the @@ -31,17 +35,43 @@ import org.springframework.http.HttpMethod; public abstract class CorsUtils { /** - * Returns {@code true} if the request is a valid CORS one. + * Returns {@code true} if the request is a valid CORS one by checking {@code Origin} + * header presence and ensuring that origins are different. */ public static boolean isCorsRequest(HttpServletRequest request) { - return (request.getHeader(HttpHeaders.ORIGIN) != null); + String origin = request.getHeader(HttpHeaders.ORIGIN); + if (origin == null) { + return false; + } + UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); + String scheme = request.getScheme(); + String host = request.getServerName(); + int port = request.getServerPort(); + return !(ObjectUtils.nullSafeEquals(scheme, originUrl.getScheme()) && + ObjectUtils.nullSafeEquals(host, originUrl.getHost()) && + getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort())); + + } + + private static int getPort(@Nullable String scheme, int port) { + if (port == -1) { + if ("http".equals(scheme) || "ws".equals(scheme)) { + port = 80; + } + else if ("https".equals(scheme) || "wss".equals(scheme)) { + port = 443; + } + } + return port; } /** * Returns {@code true} if the request is a valid CORS pre-flight one. + * To be used in combination with {@link #isCorsRequest(HttpServletRequest)} since + * regular CORS checks are not invoked here for performance reasons. */ public static boolean isPreFlightRequest(HttpServletRequest request) { - return (isCorsRequest(request) && HttpMethod.OPTIONS.matches(request.getMethod()) && + return (HttpMethod.OPTIONS.matches(request.getMethod()) && request.getHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD) != null); } diff --git a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java index 3063488343..f446307750 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/DefaultCorsProcessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ package org.springframework.web.cors; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -36,7 +35,6 @@ import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; -import org.springframework.web.util.WebUtils; /** * The default implementation of {@link CorsProcessor}, as defined by the @@ -45,8 +43,7 @@ import org.springframework.web.util.WebUtils; *

Note that when input {@link CorsConfiguration} is {@code null}, this * implementation does not reject simple or actual requests outright but simply * avoid adding CORS headers to the response. CORS processing is also skipped - * if the response already contains CORS headers, or if the request is detected - * as a same-origin one. + * if the response already contains CORS headers. * * @author Sebastien Deleuze * @author Rossen Stoyanchev @@ -62,26 +59,23 @@ public class DefaultCorsProcessor implements CorsProcessor { public boolean processRequest(@Nullable CorsConfiguration config, HttpServletRequest request, HttpServletResponse response) throws IOException { + response.addHeader(HttpHeaders.VARY, HttpHeaders.ORIGIN); + response.addHeader(HttpHeaders.VARY, HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); + response.addHeader(HttpHeaders.VARY, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS); + if (!CorsUtils.isCorsRequest(request)) { return true; } - ServletServerHttpResponse serverResponse = new ServletServerHttpResponse(response); - if (responseHasCors(serverResponse)) { + if (response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null) { logger.trace("Skip: response already contains \"Access-Control-Allow-Origin\""); return true; } - ServletServerHttpRequest serverRequest = new ServletServerHttpRequest(request); - if (WebUtils.isSameOrigin(serverRequest)) { - logger.trace("Skip: request is from same origin"); - return true; - } - boolean preFlightRequest = CorsUtils.isPreFlightRequest(request); if (config == null) { if (preFlightRequest) { - rejectRequest(serverResponse); + rejectRequest(new ServletServerHttpResponse(response)); return false; } else { @@ -89,17 +83,7 @@ public class DefaultCorsProcessor implements CorsProcessor { } } - return handleInternal(serverRequest, serverResponse, config, preFlightRequest); - } - - private boolean responseHasCors(ServerHttpResponse response) { - try { - return (response.getHeaders().getAccessControlAllowOrigin() != null); - } - catch (NullPointerException npe) { - // SPR-11919 and https://issues.jboss.org/browse/WFLY-3474 - return false; - } + return handleInternal(new ServletServerHttpRequest(request), new ServletServerHttpResponse(response), config, preFlightRequest); } /** @@ -110,6 +94,7 @@ public class DefaultCorsProcessor implements CorsProcessor { protected void rejectRequest(ServerHttpResponse response) throws IOException { response.setStatusCode(HttpStatus.FORBIDDEN); response.getBody().write("Invalid CORS request".getBytes(StandardCharsets.UTF_8)); + response.flush(); } /** @@ -122,9 +107,6 @@ public class DefaultCorsProcessor implements CorsProcessor { String allowOrigin = checkOrigin(config, requestOrigin); HttpHeaders responseHeaders = response.getHeaders(); - responseHeaders.addAll(HttpHeaders.VARY, Arrays.asList(HttpHeaders.ORIGIN, - HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); - if (allowOrigin == null) { logger.debug("Reject: '" + requestOrigin + "' origin is not allowed"); rejectRequest(response); diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java index 73107bb1dc..006f32f684 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,18 +36,21 @@ import org.springframework.web.util.UriComponentsBuilder; public abstract class CorsUtils { /** - * Returns {@code true} if the request is a valid CORS one. + * Returns {@code true} if the request is a valid CORS one by checking {@code Origin} + * header presence and ensuring that origins are different via {@link #isSameOrigin}. */ + @SuppressWarnings("deprecation") public static boolean isCorsRequest(ServerHttpRequest request) { - return (request.getHeaders().get(HttpHeaders.ORIGIN) != null); + return request.getHeaders().containsKey(HttpHeaders.ORIGIN) && !isSameOrigin(request); } /** * Returns {@code true} if the request is a valid CORS pre-flight one. + * To be used in combination with {@link #isCorsRequest(ServerHttpRequest)} since + * regular CORS checks are not invoked here for performance reasons. */ public static boolean isPreFlightRequest(ServerHttpRequest request) { - return (request.getMethod() == HttpMethod.OPTIONS && isCorsRequest(request) && - request.getHeaders().get(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD) != null); + return (request.getMethod() == HttpMethod.OPTIONS && request.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD)); } /** @@ -61,7 +64,9 @@ public abstract class CorsUtils { * * @return {@code true} if the request is a same-origin one, {@code false} in case * of a cross-origin request + * @deprecated as of 5.2, same-origin checks are performed directly by {@link #isCorsRequest} */ + @Deprecated public static boolean isSameOrigin(ServerHttpRequest request) { String origin = request.getHeaders().getOrigin(); if (origin == null) { diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsWebFilter.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsWebFilter.java index 4938d7842e..f3f0ea11e1 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsWebFilter.java +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsWebFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -75,14 +75,10 @@ public class CorsWebFilter implements WebFilter { @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { ServerHttpRequest request = exchange.getRequest(); - if (CorsUtils.isCorsRequest(request)) { - CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(exchange); - if (corsConfiguration != null) { - boolean isValid = this.processor.process(corsConfiguration, exchange); - if (!isValid || CorsUtils.isPreFlightRequest(request)) { - return Mono.empty(); - } - } + CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(exchange); + boolean isValid = this.processor.process(corsConfiguration, exchange); + if (!isValid || CorsUtils.isPreFlightRequest(request)) { + return Mono.empty(); } return chain.filter(exchange); } diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java index f52d9be190..840eddb80a 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,8 +40,7 @@ import org.springframework.web.server.ServerWebExchange; *

Note that when input {@link CorsConfiguration} is {@code null}, this * implementation does not reject simple or actual requests outright but simply * avoid adding CORS headers to the response. CORS processing is also skipped - * if the response already contains CORS headers, or if the request is detected - * as a same-origin one. + * if the response already contains CORS headers. * * @author Sebastien Deleuze * @author Rossen Stoyanchev @@ -51,27 +50,26 @@ public class DefaultCorsProcessor implements CorsProcessor { private static final Log logger = LogFactory.getLog(DefaultCorsProcessor.class); + private static final List VARY_HEADERS = Arrays.asList( + HttpHeaders.ORIGIN, HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS); + @Override public boolean process(@Nullable CorsConfiguration config, ServerWebExchange exchange) { ServerHttpRequest request = exchange.getRequest(); ServerHttpResponse response = exchange.getResponse(); + response.getHeaders().addAll(HttpHeaders.VARY, VARY_HEADERS); if (!CorsUtils.isCorsRequest(request)) { return true; } - if (responseHasCors(response)) { + if (response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null) { logger.trace("Skip: response already contains \"Access-Control-Allow-Origin\""); return true; } - if (CorsUtils.isSameOrigin(request)) { - logger.trace("Skip: request is from same origin"); - return true; - } - boolean preFlightRequest = CorsUtils.isPreFlightRequest(request); if (config == null) { if (preFlightRequest) { @@ -86,10 +84,6 @@ public class DefaultCorsProcessor implements CorsProcessor { return handleInternal(exchange, config, preFlightRequest); } - private boolean responseHasCors(ServerHttpResponse response) { - return response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null; - } - /** * Invoked when one of the CORS checks failed. */ @@ -107,9 +101,6 @@ public class DefaultCorsProcessor implements CorsProcessor { ServerHttpResponse response = exchange.getResponse(); HttpHeaders responseHeaders = response.getHeaders(); - response.getHeaders().addAll(HttpHeaders.VARY, Arrays.asList(HttpHeaders.ORIGIN, - HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); - String requestOrigin = request.getHeaders().getOrigin(); String allowOrigin = checkOrigin(config, requestOrigin); if (allowOrigin == null) { diff --git a/spring-web/src/main/java/org/springframework/web/filter/CorsFilter.java b/spring-web/src/main/java/org/springframework/web/filter/CorsFilter.java index d85fce5505..3a4bf501c5 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/CorsFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/CorsFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -83,16 +83,11 @@ public class CorsFilter extends OncePerRequestFilter { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - if (CorsUtils.isCorsRequest(request)) { - CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(request); - if (corsConfiguration != null) { - boolean isValid = this.processor.processRequest(corsConfiguration, request, response); - if (!isValid || CorsUtils.isPreFlightRequest(request)) { - return; - } - } + CorsConfiguration corsConfiguration = this.configSource.getCorsConfiguration(request); + boolean isValid = this.processor.processRequest(corsConfiguration, request, response); + if (!isValid || CorsUtils.isPreFlightRequest(request)) { + return; } - filterChain.doFilter(request, response); } diff --git a/spring-web/src/test/java/org/springframework/web/cors/CorsUtilsTests.java b/spring-web/src/test/java/org/springframework/web/cors/CorsUtilsTests.java index 0a923d67ff..61f0b330fc 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/CorsUtilsTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/CorsUtilsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,11 +62,6 @@ public class CorsUtilsTests { request.setMethod(HttpMethod.OPTIONS.name()); request.addHeader(HttpHeaders.ORIGIN, "https://domain.com"); assertFalse(CorsUtils.isPreFlightRequest(request)); - - request = new MockHttpServletRequest(); - request.setMethod(HttpMethod.OPTIONS.name()); - request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); - assertFalse(CorsUtils.isPreFlightRequest(request)); } } diff --git a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java index 0f49e8b5d0..57a519350e 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/DefaultCorsProcessorTests.java @@ -51,13 +51,35 @@ public class DefaultCorsProcessorTests { public void setup() { this.request = new MockHttpServletRequest(); this.request.setRequestURI("/test.html"); - this.request.setRemoteHost("domain1.com"); + this.request.setServerName("domain1.com"); this.conf = new CorsConfiguration(); this.response = new MockHttpServletResponse(); this.response.setStatus(HttpServletResponse.SC_OK); this.processor = new DefaultCorsProcessor(); } + @Test + public void requestWithoutOriginHeader() throws Exception { + this.request.setMethod(HttpMethod.GET.name()); + + this.processor.processRequest(this.conf, this.request, this.response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } + + @Test + public void sameOriginRequest() throws Exception { + this.request.setMethod(HttpMethod.GET.name()); + this.request.addHeader(HttpHeaders.ORIGIN, "http://domain1.com"); + + this.processor.processRequest(this.conf, this.request, this.response); + assertFalse(this.response.containsHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(this.response.getHeaders(HttpHeaders.VARY), contains(HttpHeaders.ORIGIN, + HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS)); + assertEquals(HttpServletResponse.SC_OK, this.response.getStatus()); + } @Test public void actualRequestWithOriginHeader() throws Exception { diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java index f7be0f8944..fbe406f098 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ public class CorsUtilsTests { @Test public void isCorsRequest() { - ServerHttpRequest request = get("/").header(HttpHeaders.ORIGIN, "https://domain.com").build(); + ServerHttpRequest request = get("http://domain.com/").header(HttpHeaders.ORIGIN, "https://domain.com").build(); assertTrue(CorsUtils.isCorsRequest(request)); } @@ -65,9 +65,6 @@ public class CorsUtilsTests { request = options("/").header(HttpHeaders.ORIGIN, "https://domain.com").build(); assertFalse(CorsUtils.isPreFlightRequest(request)); - - request = options("/").header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET").build(); - assertFalse(CorsUtils.isPreFlightRequest(request)); } @Test // SPR-16262 diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsWebFilterTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsWebFilterTests.java index a351fc3ea0..d212e05d9b 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsWebFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsWebFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,6 +63,46 @@ public class CorsWebFilterTests { filter = new CorsWebFilter(r -> config); } + @Test + public void nonCorsRequest() { + WebFilterChain filterChain = (filterExchange) -> { + try { + HttpHeaders headers = filterExchange.getResponse().getHeaders(); + assertNull(headers.getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertNull(headers.getFirst(ACCESS_CONTROL_EXPOSE_HEADERS)); + } catch (AssertionError ex) { + return Mono.error(ex); + } + return Mono.empty(); + + }; + MockServerWebExchange exchange = MockServerWebExchange.from( + MockServerHttpRequest + .get("https://domain1.com/test.html") + .header(HOST, "domain1.com")); + this.filter.filter(exchange, filterChain).block(); + } + + @Test + public void sameOriginRequest() { + WebFilterChain filterChain = (filterExchange) -> { + try { + HttpHeaders headers = filterExchange.getResponse().getHeaders(); + assertNull(headers.getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertNull(headers.getFirst(ACCESS_CONTROL_EXPOSE_HEADERS)); + } catch (AssertionError ex) { + return Mono.error(ex); + } + return Mono.empty(); + + }; + MockServerWebExchange exchange = MockServerWebExchange.from( + MockServerHttpRequest + .get("https://domain1.com/test.html") + .header(ORIGIN, "https://domain1.com")); + this.filter.filter(exchange, filterChain).block(); + } + @Test public void validActualRequest() { WebFilterChain filterChain = (filterExchange) -> { @@ -82,7 +122,7 @@ public class CorsWebFilterTests { .header(HOST, "domain1.com") .header(ORIGIN, "https://domain2.com") .header("header2", "foo")); - this.filter.filter(exchange, filterChain); + this.filter.filter(exchange, filterChain).block(); } @Test @@ -96,8 +136,7 @@ public class CorsWebFilterTests { WebFilterChain filterChain = (filterExchange) -> Mono.error( new AssertionError("Invalid requests must not be forwarded to the filter chain")); - filter.filter(exchange, filterChain); - + filter.filter(exchange, filterChain).block(); assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); } @@ -115,7 +154,7 @@ public class CorsWebFilterTests { WebFilterChain filterChain = (filterExchange) -> Mono.error( new AssertionError("Preflight requests must not be forwarded to the filter chain")); - filter.filter(exchange, filterChain); + filter.filter(exchange, filterChain).block(); HttpHeaders headers = exchange.getResponse().getHeaders(); assertEquals("https://domain2.com", headers.getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); @@ -138,7 +177,7 @@ public class CorsWebFilterTests { WebFilterChain filterChain = (filterExchange) -> Mono.error( new AssertionError("Preflight requests must not be forwarded to the filter chain")); - filter.filter(exchange, filterChain); + filter.filter(exchange, filterChain).block(); assertNull(exchange.getResponse().getHeaders().getFirst(ACCESS_CONTROL_ALLOW_ORIGIN)); } diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java index 675f23e12c..3058442066 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java @@ -60,6 +60,37 @@ public class DefaultCorsProcessorTests { } + @Test + public void requestWithoutOriginHeader() throws Exception { + MockServerHttpRequest request = MockServerHttpRequest + .method(HttpMethod.GET, "http://domain1.com/test.html") + .build(); + ServerWebExchange exchange = MockServerWebExchange.from(request); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + + @Test + public void sameOriginRequest() throws Exception { + MockServerHttpRequest request = MockServerHttpRequest + .method(HttpMethod.GET, "http://domain1.com/test.html") + .header(HttpHeaders.ORIGIN, "http://domain1.com") + .build(); + ServerWebExchange exchange = MockServerWebExchange.from(request); + this.processor.process(this.conf, exchange); + + ServerHttpResponse response = exchange.getResponse(); + assertFalse(response.getHeaders().containsKey(ACCESS_CONTROL_ALLOW_ORIGIN)); + assertThat(response.getHeaders().get(VARY), contains(ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS)); + assertNull(response.getStatusCode()); + } + @Test public void actualRequestWithOriginHeader() throws Exception { ServerWebExchange exchange = actualRequest(); diff --git a/spring-web/src/test/java/org/springframework/web/filter/CorsFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/CorsFilterTests.java index 77458b5ada..d2bc3f471b 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/CorsFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/CorsFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,6 +52,36 @@ public class CorsFilterTests { filter = new CorsFilter(r -> config); } + @Test + public void nonCorsRequest() throws ServletException, IOException { + + MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "/test.html"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertNull(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertNull(response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + }; + filter.doFilter(request, response, filterChain); + } + + @Test + public void sameOriginRequest() throws ServletException, IOException { + + MockHttpServletRequest request = new MockHttpServletRequest(HttpMethod.GET.name(), "https://domain1.com/test.html"); + request.addHeader(HttpHeaders.ORIGIN, "https://domain1.com"); + request.setScheme("https"); + request.setServerName("domain1.com"); + request.setServerPort(443); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = (filterRequest, filterResponse) -> { + assertNull(response.getHeader(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertNull(response.getHeader(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS)); + }; + filter.doFilter(request, response, filterChain); + } + @Test public void validActualRequest() throws ServletException, IOException { diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java index 3a736d50ff..a323217cd8 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import reactor.core.publisher.Mono; import org.springframework.beans.factory.BeanNameAware; import org.springframework.context.support.ApplicationObjectSupport; import org.springframework.core.Ordered; +import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.web.cors.CorsConfiguration; @@ -53,6 +54,7 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport private final PathPatternParser patternParser; + @Nullable private CorsConfigurationSource corsConfigurationSource; private CorsProcessor corsProcessor = new DefaultCorsProcessor(); @@ -65,7 +67,6 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport public AbstractHandlerMapping() { this.patternParser = new PathPatternParser(); - this.corsConfigurationSource = new UrlBasedCorsConfigurationSource(this.patternParser); } @@ -113,8 +114,14 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport */ public void setCorsConfigurations(Map corsConfigurations) { Assert.notNull(corsConfigurations, "corsConfigurations must not be null"); - this.corsConfigurationSource = new UrlBasedCorsConfigurationSource(this.patternParser); - ((UrlBasedCorsConfigurationSource) this.corsConfigurationSource).setCorsConfigurations(corsConfigurations); + if (!corsConfigurations.isEmpty()) { + UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(this.patternParser); + source.setCorsConfigurations(corsConfigurations); + this.corsConfigurationSource = source; + } + else { + this.corsConfigurationSource = null; + } } /** @@ -175,12 +182,12 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport if (logger.isDebugEnabled()) { logger.debug(exchange.getLogPrefix() + "Mapped to " + handler); } - if (CorsUtils.isCorsRequest(exchange.getRequest())) { - CorsConfiguration configA = this.corsConfigurationSource.getCorsConfiguration(exchange); - CorsConfiguration configB = getCorsConfiguration(handler, exchange); - CorsConfiguration config = (configA != null ? configA.combine(configB) : configB); - if (!getCorsProcessor().process(config, exchange) || - CorsUtils.isPreFlightRequest(exchange.getRequest())) { + if (hasCorsConfigurationSource(handler)) { + ServerHttpRequest request = exchange.getRequest(); + CorsConfiguration config = (this.corsConfigurationSource != null ? this.corsConfigurationSource.getCorsConfiguration(exchange) : null); + CorsConfiguration handlerConfig = getCorsConfiguration(handler, exchange); + config = (config != null ? config.combine(handlerConfig) : handlerConfig); + if (!this.corsProcessor.process(config, exchange) || CorsUtils.isPreFlightRequest(request)) { return REQUEST_HANDLED_HANDLER; } } @@ -200,6 +207,14 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport */ protected abstract Mono getHandlerInternal(ServerWebExchange exchange); + /** + * Return {@code true} if there is a {@link CorsConfigurationSource} for this handler. + * @since 5.2 + */ + protected boolean hasCorsConfigurationSource(Object handler) { + return handler instanceof CorsConfigurationSource || this.corsConfigurationSource != null; + } + /** * Retrieve the CORS configuration for the given handler. * @param handler the handler to check (never {@code null}) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/AbstractHandlerMethodMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/AbstractHandlerMethodMapping.java index d884796a4b..a4c02f75f0 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/AbstractHandlerMethodMapping.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/AbstractHandlerMethodMapping.java @@ -370,6 +370,13 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap return null; } + @Override + protected boolean hasCorsConfigurationSource(Object handler) { + return super.hasCorsConfigurationSource(handler) || + (handler instanceof HandlerMethod && this.mappingRegistry.getCorsConfiguration((HandlerMethod) handler) != null) || + handler.equals(PREFLIGHT_AMBIGUOUS_MATCH); + } + @Override protected CorsConfiguration getCorsConfiguration(Object handler, ServerWebExchange exchange) { CorsConfiguration corsConfig = super.getCorsConfiguration(handler, exchange); @@ -451,6 +458,7 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap /** * Return CORS configuration. Thread-safe for concurrent use. */ + @Nullable public CorsConfiguration getCorsConfiguration(HandlerMethod handlerMethod) { HandlerMethod original = handlerMethod.getResolvedFromHandlerMethod(); return this.corsLookup.get(original != null ? original : handlerMethod); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/handler/CorsUrlHandlerMappingTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/handler/CorsUrlHandlerMappingTests.java index 80fc7c4062..ca03dd9244 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/handler/CorsUrlHandlerMappingTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/handler/CorsUrlHandlerMappingTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,7 +31,6 @@ import org.springframework.web.server.ServerWebExchange; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; /** @@ -74,8 +73,7 @@ public class CorsUrlHandlerMappingTests { Object actual = this.handlerMapping.getHandler(exchange).block(); assertNotNull(actual); - assertNotSame(this.welcomeController, actual); - assertNull(exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN)); + assertSame(this.welcomeController, actual); } @Test diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java index cbf6a1a317..d84e9eb19e 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMapping.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,7 +81,8 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport private final List adaptedInterceptors = new ArrayList<>(); - private CorsConfigurationSource corsConfigurationSource = new UrlBasedCorsConfigurationSource(); + @Nullable + private CorsConfigurationSource corsConfigurationSource; private CorsProcessor corsProcessor = new DefaultCorsProcessor(); @@ -206,11 +207,16 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport */ public void setCorsConfigurations(Map corsConfigurations) { Assert.notNull(corsConfigurations, "corsConfigurations must not be null"); - UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(); - source.setCorsConfigurations(corsConfigurations); - source.setPathMatcher(this.pathMatcher); - source.setUrlPathHelper(this.urlPathHelper); - this.corsConfigurationSource = source; + if (!corsConfigurations.isEmpty()) { + UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(); + source.setCorsConfigurations(corsConfigurations); + source.setPathMatcher(this.pathMatcher); + source.setUrlPathHelper(this.urlPathHelper); + this.corsConfigurationSource = source; + } + else { + this.corsConfigurationSource = null; + } } /** @@ -420,10 +426,10 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport logger.debug("Mapped to " + executionChain.getHandler()); } - if (CorsUtils.isCorsRequest(request)) { - CorsConfiguration globalConfig = this.corsConfigurationSource.getCorsConfiguration(request); + if (hasCorsConfigurationSource(handler)) { + CorsConfiguration config = (this.corsConfigurationSource != null ? this.corsConfigurationSource.getCorsConfiguration(request) : null); CorsConfiguration handlerConfig = getCorsConfiguration(handler, request); - CorsConfiguration config = (globalConfig != null ? globalConfig.combine(handlerConfig) : handlerConfig); + config = (config != null ? config.combine(handlerConfig) : handlerConfig); executionChain = getCorsHandlerExecutionChain(request, executionChain, config); } @@ -488,6 +494,14 @@ public abstract class AbstractHandlerMapping extends WebApplicationObjectSupport return chain; } + /** + * Return {@code true} if there is a {@link CorsConfigurationSource} for this handler. + * @since 5.2 + */ + protected boolean hasCorsConfigurationSource(Object handler) { + return handler instanceof CorsConfigurationSource || this.corsConfigurationSource != null; + } + /** * Retrieve the CORS configuration for the given handler. * @param handler the handler to check (never {@code null}). diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMethodMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMethodMapping.java index d97ec37383..cf4f430ca6 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMethodMapping.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/AbstractHandlerMethodMapping.java @@ -448,6 +448,13 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap return null; } + @Override + protected boolean hasCorsConfigurationSource(Object handler) { + return super.hasCorsConfigurationSource(handler) || + (handler instanceof HandlerMethod && this.mappingRegistry.getCorsConfiguration((HandlerMethod) handler) != null) || + handler.equals(PREFLIGHT_AMBIGUOUS_MATCH); + } + @Override protected CorsConfiguration getCorsConfiguration(Object handler, HttpServletRequest request) { CorsConfiguration corsConfig = super.getCorsConfiguration(handler, request); @@ -555,6 +562,7 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap /** * Return CORS configuration. Thread-safe for concurrent use. */ + @Nullable public CorsConfiguration getCorsConfiguration(HandlerMethod handlerMethod) { HandlerMethod original = handlerMethod.getResolvedFromHandlerMethod(); return this.corsLookup.get(original != null ? original : handlerMethod); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java index 6a485cbc89..71b6617efc 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/config/annotation/WebMvcConfigurationSupportExtensionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -138,7 +138,7 @@ public class WebMvcConfigurationSupportExtensionTests { HandlerExecutionChain chain = rmHandlerMapping.getHandler(new MockHttpServletRequest("GET", "/")); assertNotNull(chain); assertNotNull(chain.getInterceptors()); - assertEquals(3, chain.getInterceptors().length); + assertEquals(4, chain.getInterceptors().length); assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[0].getClass()); assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[1].getClass()); assertEquals(ResourceUrlProviderExposingInterceptor.class, chain.getInterceptors()[2].getClass()); @@ -177,7 +177,7 @@ public class WebMvcConfigurationSupportExtensionTests { chain = handlerMapping.getHandler(new MockHttpServletRequest("GET", "/resources/foo.gif")); assertNotNull(chain); assertNotNull(chain.getHandler()); - assertEquals(Arrays.toString(chain.getInterceptors()), 4, chain.getInterceptors().length); + assertEquals(Arrays.toString(chain.getInterceptors()), 5, chain.getInterceptors().length); // PathExposingHandlerInterceptor at chain.getInterceptors()[0] assertEquals(LocaleChangeInterceptor.class, chain.getInterceptors()[1].getClass()); assertEquals(ConversionServiceExposingInterceptor.class, chain.getInterceptors()[2].getClass()); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java index 160db8f4c5..d400c4ca3e 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/CorsAbstractHandlerMappingTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,8 +81,7 @@ public class CorsAbstractHandlerMappingTests { this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET"); HandlerExecutionChain chain = handlerMapping.getHandler(this.request); assertNotNull(chain); - assertNotNull(chain.getHandler()); - assertTrue(chain.getHandler().getClass().getSimpleName().equals("PreFlightHandler")); + assertTrue(chain.getHandler() instanceof SimpleHandler); } @Test diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index bb003bdebe..1d399b8049 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -33,6 +33,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.InvalidMediaTypeException; @@ -48,7 +49,6 @@ import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.CorsConfigurationSource; -import org.springframework.web.cors.CorsUtils; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.sockjs.SockJsException; import org.springframework.web.socket.sockjs.SockJsService; @@ -495,7 +495,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig @Override @Nullable public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { - if (!this.suppressCors && CorsUtils.isCorsRequest(request)) { + if (!this.suppressCors && (request.getHeader(HttpHeaders.ORIGIN) != null)) { CorsConfiguration config = new CorsConfiguration(); config.setAllowedOrigins(new ArrayList<>(this.allowedOrigins)); config.addAllowedMethod("*");