From aa51ed19403129bd9f0b8f5373391ba6bc8baaaa Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Mon, 10 May 2021 14:15:16 +0100 Subject: [PATCH] Fix failing tests This commit ensures that if an Origin is returned as it was provided, possibly with a trailing slash. See gh-26892 --- .../web/cors/CorsConfiguration.java | 16 ++++++++-------- .../web/cors/CorsConfigurationTests.java | 4 ++-- .../mvc/method/annotation/CrossOriginTests.java | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java b/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java index ddbe9d5ba6..1eee79898c 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java +++ b/spring-web/src/main/java/org/springframework/web/cors/CorsConfiguration.java @@ -549,31 +549,31 @@ public class CorsConfiguration { /** * Check the origin of the request against the configured allowed origins. - * @param requestOrigin the origin to check + * @param origin the origin to check * @return the origin to use for the response, or {@code null} which * means the request origin is not allowed */ @Nullable - public String checkOrigin(@Nullable String requestOrigin) { - if (!StringUtils.hasText(requestOrigin)) { + public String checkOrigin(@Nullable String origin) { + if (!StringUtils.hasText(origin)) { return null; } - requestOrigin = trimTrailingSlash(requestOrigin); + String originToCheck = trimTrailingSlash(origin); if (!ObjectUtils.isEmpty(this.allowedOrigins)) { if (this.allowedOrigins.contains(ALL)) { validateAllowCredentials(); return ALL; } for (String allowedOrigin : this.allowedOrigins) { - if (requestOrigin.equalsIgnoreCase(allowedOrigin)) { - return requestOrigin; + if (originToCheck.equalsIgnoreCase(allowedOrigin)) { + return origin; } } } if (!ObjectUtils.isEmpty(this.allowedOriginPatterns)) { for (OriginPattern p : this.allowedOriginPatterns) { - if (p.getDeclaredPattern().equals(ALL) || p.getPattern().matcher(requestOrigin).matches()) { - return requestOrigin; + if (p.getDeclaredPattern().equals(ALL) || p.getPattern().matcher(originToCheck).matches()) { + return origin; } } } diff --git a/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java b/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java index 4cfdf1cc35..b920a9f167 100644 --- a/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java +++ b/spring-web/src/test/java/org/springframework/web/cors/CorsConfigurationTests.java @@ -294,12 +294,12 @@ public class CorsConfigurationTests { // specific origin matches Origin header with or without trailing "/" config.setAllowedOrigins(Collections.singletonList("https://domain.com")); assertThat(config.checkOrigin("https://domain.com")).isEqualTo("https://domain.com"); - assertThat(config.checkOrigin("https://domain.com/")).isEqualTo("https://domain.com"); + assertThat(config.checkOrigin("https://domain.com/")).isEqualTo("https://domain.com/"); // specific origin with trailing "/" matches Origin header with or without trailing "/" config.setAllowedOrigins(Collections.singletonList("https://domain.com/")); assertThat(config.checkOrigin("https://domain.com")).isEqualTo("https://domain.com"); - assertThat(config.checkOrigin("https://domain.com/")).isEqualTo("https://domain.com"); + assertThat(config.checkOrigin("https://domain.com/")).isEqualTo("https://domain.com/"); config.setAllowCredentials(false); assertThat(config.checkOrigin("https://domain.com")).isEqualTo("https://domain.com"); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java index cb9e9f2538..3f1fce6612 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/CrossOriginTests.java @@ -284,7 +284,7 @@ class CrossOriginTests { CorsConfiguration config = getCorsConfiguration(chain, false); assertThat(config).isNotNull(); assertThat(config.getAllowedMethods()).containsExactly("GET"); - assertThat(config.getAllowedOrigins()).containsExactly("http://www.foo.example/"); + assertThat(config.getAllowedOrigins()).containsExactly("http://www.foo.example"); assertThat(config.getAllowCredentials()).isTrue(); } @@ -297,7 +297,7 @@ class CrossOriginTests { CorsConfiguration config = getCorsConfiguration(chain, false); assertThat(config).isNotNull(); assertThat(config.getAllowedMethods()).containsExactly("GET"); - assertThat(config.getAllowedOrigins()).containsExactly("http://www.foo.example/"); + assertThat(config.getAllowedOrigins()).containsExactly("http://www.foo.example"); assertThat(config.getAllowCredentials()).isTrue(); }