diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index e7e627bb96..8a7cc349ec 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -42,19 +42,19 @@ import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UrlPathHelper; /** - * Extract values from "Forwarded" and "X-Forwarded-*" headers in order to wrap - * and override the following from the request and response: - * {@link HttpServletRequest#getServerName() getServerName()}, - * {@link HttpServletRequest#getServerPort() getServerPort()}, - * {@link HttpServletRequest#getScheme() getScheme()}, - * {@link HttpServletRequest#isSecure() isSecure()}, and - * {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}. - * In effect the wrapped request and response reflect the client-originated - * protocol and address. + * Extract values from "Forwarded" and "X-Forwarded-*" headers, wrap the request + * and response, and make they reflect the client-originated protocol and + * address in the following methods: + * * - *

Note: This filter can also be used in a - * {@link #setRemoveOnly removeOnly} mode where "Forwarded" and "X-Forwarded-*" - * headers are only eliminated without being used. + *

This filter can also be used in a {@link #setRemoveOnly removeOnly} mode + * where "Forwarded" and "X-Forwarded-*" headers are eliminated, and not used. * * @author Rossen Stoyanchev * @author EddĂș MelĂ©ndez @@ -117,7 +117,7 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @Override - protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { + protected boolean shouldNotFilter(HttpServletRequest request) { for (String headerName : FORWARDED_HEADER_NAMES) { if (request.getHeader(headerName) != null) { return false; @@ -141,15 +141,18 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { FilterChain filterChain) throws ServletException, IOException { if (this.removeOnly) { - ForwardedHeaderRemovingRequest theRequest = new ForwardedHeaderRemovingRequest(request); - filterChain.doFilter(theRequest, response); + ForwardedHeaderRemovingRequest wrappedRequest = new ForwardedHeaderRemovingRequest(request); + filterChain.doFilter(wrappedRequest, response); } else { - HttpServletRequest theRequest = new ForwardedHeaderExtractingRequest(request, this.pathHelper); - HttpServletResponse theResponse = (this.relativeRedirects ? + HttpServletRequest wrappedRequest = + new ForwardedHeaderExtractingRequest(request, this.pathHelper); + + HttpServletResponse wrappedResponse = this.relativeRedirects ? RelativeRedirectResponseWrapper.wrapIfNecessary(response, HttpStatus.SEE_OTHER) : - new ForwardedHeaderExtractingResponse(response, theRequest)); - filterChain.doFilter(theRequest, theResponse); + new ForwardedHeaderExtractingResponse(response, wrappedRequest); + + filterChain.doFilter(wrappedRequest, wrappedResponse); } } @@ -221,7 +224,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { private final String requestUrl; - public ForwardedHeaderExtractingRequest(HttpServletRequest request, UrlPathHelper pathHelper) { + + ForwardedHeaderExtractingRequest(HttpServletRequest request, UrlPathHelper pathHelper) { super(request); HttpRequest httpRequest = new ServletServerHttpRequest(request); @@ -257,6 +261,7 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { return prefix; } + @Override @Nullable public String getScheme() { @@ -302,11 +307,13 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { private final HttpServletRequest request; - public ForwardedHeaderExtractingResponse(HttpServletResponse response, HttpServletRequest request) { + + ForwardedHeaderExtractingResponse(HttpServletResponse response, HttpServletRequest request) { super(response); this.request = request; } + @Override public void sendRedirect(String location) throws IOException { diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java index d48484a1aa..5a1d5c49c2 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,17 +31,16 @@ import org.springframework.web.server.WebFilterChain; import org.springframework.web.util.UriComponentsBuilder; /** - * Extract values from "Forwarded" and "X-Forwarded-*" headers in order to change - * and override {@link ServerHttpRequest#getURI()}. - * In effect the request URI will reflect the client-originated + * Extract values from "Forwarded" and "X-Forwarded-*" headers, and use them to + * override {@link ServerHttpRequest#getURI()} to reflect the client-originated * protocol and address. * - *

Note: This filter can also be used in a - * {@link #setRemoveOnly removeOnly} mode where "Forwarded" and "X-Forwarded-*" - * headers are only eliminated without being used. + *

This filter can also be used in a {@link #setRemoveOnly removeOnly} mode + * where "Forwarded" and "X-Forwarded-*" headers are eliminated, and not used. + * * @author Arjen Poutsma - * @see https://tools.ietf.org/html/rfc7239 * @since 5.0 + * @see https://tools.ietf.org/html/rfc7239 */ public class ForwardedHeaderFilter implements WebFilter { @@ -55,8 +54,10 @@ public class ForwardedHeaderFilter implements WebFilter { FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix"); } + private boolean removeOnly; + /** * Enables mode in which any "Forwarded" or "X-Forwarded-*" headers are * removed only and the information in them ignored. @@ -66,6 +67,7 @@ public class ForwardedHeaderFilter implements WebFilter { this.removeOnly = removeOnly; } + @Override public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { @@ -73,31 +75,29 @@ public class ForwardedHeaderFilter implements WebFilter { return chain.filter(exchange); } + ServerWebExchange mutatedExchange; + if (this.removeOnly) { - ServerWebExchange withoutForwardHeaders = exchange.mutate() - .request(builder -> builder.headers( - headers -> { - for (String headerName : FORWARDED_HEADER_NAMES) { - headers.remove(headerName); - } - })).build(); - return chain.filter(withoutForwardHeaders); + mutatedExchange = exchange.mutate().request(builder -> + builder.headers(headers -> { + FORWARDED_HEADER_NAMES.forEach(headers::remove); + })) + .build(); } else { URI uri = UriComponentsBuilder.fromHttpRequest(exchange.getRequest()).build().toUri(); String prefix = getForwardedPrefix(exchange.getRequest().getHeaders()); - ServerWebExchange withChangedUri = exchange.mutate() - .request(builder -> { - builder.uri(uri); - if (prefix != null) { - builder.path(prefix + uri.getPath()); - builder.contextPath(prefix); - } - }).build(); - return chain.filter(withChangedUri); + mutatedExchange = exchange.mutate().request(builder -> { + builder.uri(uri); + if (prefix != null) { + builder.path(prefix + uri.getPath()); + builder.contextPath(prefix); + } + }).build(); } + return chain.filter(mutatedExchange); } private boolean shouldNotFilter(ServerHttpRequest request) { diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java index ef14dbefbb..a5fadd84c6 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java @@ -103,6 +103,17 @@ public class ForwardedHeaderFilterTests { assertEquals("/prefix", actual); } + private String filterAndGetContextPath() throws ServletException, IOException { + return filterAndGetWrappedRequest().getContextPath(); + } + + private HttpServletRequest filterAndGetWrappedRequest() throws ServletException, IOException { + MockHttpServletResponse response = new MockHttpServletResponse(); + this.filter.doFilterInternal(this.request, response, this.filterChain); + return (HttpServletRequest) this.filterChain.getRequest(); + } + + @Test public void contextPathPreserveEncoding() throws Exception { this.request.setContextPath("/app%20"); @@ -183,8 +194,8 @@ public class ForwardedHeaderFilterTests { @Test public void caseInsensitiveForwardedPrefix() throws Exception { this.request = new MockHttpServletRequest() { - // Make it case-sensitive (SPR-14372) - @Override + + @Override // SPR-14372: make it case-sensitive public String getHeader(String header) { Enumeration names = getHeaderNames(); while (names.hasMoreElements()) { @@ -204,15 +215,21 @@ public class ForwardedHeaderFilterTests { } @Test - public void shouldFilter() throws Exception { + public void shouldFilter() { testShouldFilter("Forwarded"); testShouldFilter(X_FORWARDED_HOST); testShouldFilter(X_FORWARDED_PORT); testShouldFilter(X_FORWARDED_PROTO); } + private void testShouldFilter(String headerName) { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.addHeader(headerName, "1"); + assertFalse(this.filter.shouldNotFilter(request)); + } + @Test - public void shouldNotFilter() throws Exception { + public void shouldNotFilter() { assertTrue(this.filter.shouldNotFilter(new MockHttpServletRequest())); } @@ -417,7 +434,6 @@ public class ForwardedHeaderFilterTests { this.request.addHeader(X_FORWARDED_HOST, "example.com"); this.request.addHeader(X_FORWARDED_PORT, "443"); this.filter.setRelativeRedirects(true); - String location = sendRedirect("/a"); assertEquals("/a", location); @@ -426,7 +442,6 @@ public class ForwardedHeaderFilterTests { @Test public void sendRedirectWhenRequestOnlyAndNoXForwardedThenUsesRelativeRedirects() throws Exception { this.filter.setRelativeRedirects(true); - String location = sendRedirect("/a"); assertEquals("/a", location); @@ -441,34 +456,12 @@ public class ForwardedHeaderFilterTests { res.sendRedirect(location); } }; - MockHttpServletResponse response = doWithFiltersAndGetResponse(this.filter, filter); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = new MockFilterChain(new HttpServlet() {}, this.filter, filter); + filterChain.doFilter(request, response); + return response.getRedirectedUrl(); } - @SuppressWarnings("serial") - private MockHttpServletResponse doWithFiltersAndGetResponse(Filter... filters) - throws ServletException, IOException { - - MockHttpServletResponse response = new MockHttpServletResponse(); - FilterChain filterChain = new MockFilterChain(new HttpServlet() {}, filters); - filterChain.doFilter(request, response); - return response; - } - - private String filterAndGetContextPath() throws ServletException, IOException { - return filterAndGetWrappedRequest().getContextPath(); - } - - private HttpServletRequest filterAndGetWrappedRequest() throws ServletException, IOException { - MockHttpServletResponse response = new MockHttpServletResponse(); - this.filter.doFilterInternal(this.request, response, this.filterChain); - return (HttpServletRequest) this.filterChain.getRequest(); - } - - private void testShouldFilter(String headerName) throws ServletException { - MockHttpServletRequest request = new MockHttpServletRequest(); - request.addHeader(headerName, "1"); - assertFalse(this.filter.shouldNotFilter(request)); - } - } diff --git a/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java index 3ec757d185..dbff65137f 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,18 +24,21 @@ import reactor.core.publisher.Mono; import org.springframework.http.HttpHeaders; import org.springframework.lang.Nullable; -import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.web.test.server.MockServerWebExchange; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebFilterChain; import static org.junit.Assert.*; +import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.*; /** * @author Arjen Poutsma */ public class ForwardedHeaderFilterTests { + private static final String BASE_URL = "http://example.com/path"; + + private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter(); private final TestWebFilterChain filterChain = new TestWebFilterChain(); @@ -43,8 +46,7 @@ public class ForwardedHeaderFilterTests { @Test public void removeOnly() { - MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest - .get("/") + ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) .header("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43") .header("X-Forwarded-Host", "example.com") .header("X-Forwarded-Port", "8080") @@ -65,66 +67,57 @@ public class ForwardedHeaderFilterTests { @Test public void xForwardedRequest() throws Exception { - MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest - .get("http://example.com/path") + ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) .header("X-Forwarded-Host", "84.198.58.199") .header("X-Forwarded-Port", "443") .header("X-Forwarded-Proto", "https")); - this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); - - URI uri = this.filterChain.uri; - assertEquals(new URI("https://84.198.58.199/path"), uri); + assertEquals(new URI("https://84.198.58.199/path"), filterAndGetUri(exchange)); } @Test public void forwardedRequest() throws Exception { - MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest - .get("http://example.com/path") + ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) .header("Forwarded", "host=84.198.58.199;proto=https")); - this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); - - URI uri = this.filterChain.uri; - assertEquals(new URI("https://84.198.58.199/path"), uri); + assertEquals(new URI("https://84.198.58.199/path"), filterAndGetUri(exchange)); } @Test public void requestUriWithForwardedPrefix() throws Exception { - MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest - .get("http://example.com/path") + ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) .header("X-Forwarded-Prefix", "/prefix")); - this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); - - URI uri = this.filterChain.uri; - assertEquals(new URI("http://example.com/prefix/path"), uri); + assertEquals(new URI("http://example.com/prefix/path"), filterAndGetUri(exchange)); } @Test public void requestUriWithForwardedPrefixTrailingSlash() throws Exception { - MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest - .get("http://example.com/path") + ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) .header("X-Forwarded-Prefix", "/prefix/")); - this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); + assertEquals(new URI("http://example.com/prefix/path"), filterAndGetUri(exchange)); + } - URI uri = this.filterChain.uri; - assertEquals(new URI("http://example.com/prefix/path"), uri); + @Nullable + private URI filterAndGetUri(ServerWebExchange exchange) { + this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); + return this.filterChain.uri; } private static class TestWebFilterChain implements WebFilterChain { @Nullable - private HttpHeaders httpHeaders; + private HttpHeaders headers; @Nullable private URI uri; + @Nullable public HttpHeaders getHeaders() { - return this.httpHeaders; + return this.headers; } @Nullable @@ -134,12 +127,10 @@ public class ForwardedHeaderFilterTests { @Override public Mono filter(ServerWebExchange exchange) { - this.httpHeaders = exchange.getRequest().getHeaders(); + this.headers = exchange.getRequest().getHeaders(); this.uri = exchange.getRequest().getURI(); return Mono.empty(); } } - - } \ No newline at end of file