Remove individual detection of forwarded headers

This commit removes all places where forwarded headers are checked
implicitly, on an ad-hoc basis.

ForwardedHeaderFilter is expected to be used instead providing
centralized control over using or discarding such headers.

Issue: SPR-16668
This commit is contained in:
Rossen Stoyanchev 2018-05-11 09:31:39 -04:00
parent 82a8e42ff9
commit 4da43de7e1
14 changed files with 251 additions and 299 deletions

View File

@ -38,7 +38,6 @@ import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRange; import org.springframework.http.HttpRange;
import org.springframework.http.HttpRequest;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.multipart.Part; import org.springframework.http.codec.multipart.Part;
@ -141,7 +140,7 @@ public class MockServerRequest implements ServerRequest {
@Override @Override
public UriBuilder uriBuilder() { public UriBuilder uriBuilder() {
return UriComponentsBuilder.fromHttpRequest(new ServerRequestAdapter()); return UriComponentsBuilder.fromUri(this.uri);
} }
@Override @Override
@ -571,22 +570,4 @@ public class MockServerRequest implements ServerRequest {
} }
private final class ServerRequestAdapter implements HttpRequest {
@Override
public String getMethodValue() {
return methodName();
}
@Override
public URI getURI() {
return MockServerRequest.this.uri;
}
@Override
public HttpHeaders getHeaders() {
return MockServerRequest.this.headers.headers;
}
}
} }

View File

@ -16,6 +16,8 @@
package org.springframework.web.cors.reactive; package org.springframework.web.cors.reactive;
import java.net.URI;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
@ -49,18 +51,16 @@ public abstract class CorsUtils {
} }
/** /**
* Check if the request is a same-origin one, based on {@code Origin}, {@code Host}, * Check if the request is a same-origin one, based on {@code Origin}, and
* {@code Forwarded}, {@code X-Forwarded-Proto}, {@code X-Forwarded-Host} and * {@code Host} headers.
* @code X-Forwarded-Port} headers. *
* <p><strong>Note:</strong> as of 5.1 this method ignores
* {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the
* client-originated address. Consider using the {@code ForwardedHeaderFilter}
* to extract and use, or to discard such headers.
*
* @return {@code true} if the request is a same-origin one, {@code false} in case * @return {@code true} if the request is a same-origin one, {@code false} in case
* of a cross-origin request * of a cross-origin request
* <p><strong>Note:</strong> this method uses values from "Forwarded"
* (<a href="http://tools.ietf.org/html/rfc7239">RFC 7239</a>),
* "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" headers,
* if present, in order to reflect the client-originated address.
* Consider using the {@code ForwardedHeaderFilter} in order to choose from a
* central place whether to extract and use, or to discard such headers.
* See the Spring Framework reference for more on this filter.
*/ */
public static boolean isSameOrigin(ServerHttpRequest request) { public static boolean isSameOrigin(ServerHttpRequest request) {
String origin = request.getHeaders().getOrigin(); String origin = request.getHeaders().getOrigin();
@ -68,9 +68,9 @@ public abstract class CorsUtils {
return true; return true;
} }
UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build(); URI uri = request.getURI();
String actualHost = actualUrl.getHost(); String actualHost = uri.getHost();
int actualPort = getPort(actualUrl.getScheme(), actualUrl.getPort()); int actualPort = getPort(uri.getScheme(), uri.getPort());
Assert.notNull(actualHost, "Actual request host must not be null"); Assert.notNull(actualHost, "Actual request host must not be null");
Assert.isTrue(actualPort != -1, "Actual request port must not be undefined"); Assert.isTrue(actualPort != -1, "Actual request port must not be undefined");

View File

@ -18,11 +18,10 @@ package org.springframework.web.util;
import java.io.File; import java.io.File;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
import java.net.URI;
import java.util.Collection; import java.util.Collection;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.LinkedHashSet;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer; import java.util.StringTokenizer;
import java.util.TreeMap; import java.util.TreeMap;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
@ -138,16 +137,6 @@ public abstract class WebUtils {
/** Key for the mutex session attribute */ /** Key for the mutex session attribute */
public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX"; public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX";
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
static {
FORWARDED_HEADER_NAMES.add("Forwarded");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Host");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Port");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix");
}
/** /**
* Set a system property to the web application root directory. * Set a system property to the web application root directory.
@ -677,13 +666,12 @@ public abstract class WebUtils {
* Check the given request origin against a list of allowed origins. * Check the given request origin against a list of allowed origins.
* A list containing "*" means that all origins are allowed. * A list containing "*" means that all origins are allowed.
* An empty list means only same origin is allowed. * An empty list means only same origin is allowed.
* <p><strong>Note:</strong> this method may use values from "Forwarded" *
* (<a href="http://tools.ietf.org/html/rfc7239">RFC 7239</a>), * <p><strong>Note:</strong> as of 5.1 this method ignores
* "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" headers, * {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the
* if present, in order to reflect the client-originated address. * client-originated address. Consider using the {@code ForwardedHeaderFilter}
* Consider using the {@code ForwardedHeaderFilter} in order to choose from a * to extract and use, or to discard such headers.
* central place whether to extract and use, or to discard such headers. *
* See the Spring Framework reference for more on this filter.
* @return {@code true} if the request origin is valid, {@code false} otherwise * @return {@code true} if the request origin is valid, {@code false} otherwise
* @since 4.1.5 * @since 4.1.5
* @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a> * @see <a href="https://tools.ietf.org/html/rfc6454">RFC 6454: The Web Origin Concept</a>
@ -708,13 +696,12 @@ public abstract class WebUtils {
* Check if the request is a same-origin one, based on {@code Origin}, {@code Host}, * Check if the request is a same-origin one, based on {@code Origin}, {@code Host},
* {@code Forwarded}, {@code X-Forwarded-Proto}, {@code X-Forwarded-Host} and * {@code Forwarded}, {@code X-Forwarded-Proto}, {@code X-Forwarded-Host} and
* @code X-Forwarded-Port} headers. * @code X-Forwarded-Port} headers.
* <p><strong>Note:</strong> this method uses values from "Forwarded" *
* (<a href="http://tools.ietf.org/html/rfc7239">RFC 7239</a>), * <p><strong>Note:</strong> as of 5.1 this method ignores
* "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" headers, * {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the
* if present, in order to reflect the client-originated address. * client-originated address. Consider using the {@code ForwardedHeaderFilter}
* Consider using the {@code ForwardedHeaderFilter} in order to choose from a * to extract and use, or to discard such headers.
* central place whether to extract and use, or to discard such headers.
* See the Spring Framework reference for more on this filter.
* @return {@code true} if the request is a same-origin one, {@code false} in case * @return {@code true} if the request is a same-origin one, {@code false} in case
* of cross-origin request * of cross-origin request
* @since 4.2 * @since 4.2
@ -735,21 +722,12 @@ public abstract class WebUtils {
scheme = servletRequest.getScheme(); scheme = servletRequest.getScheme();
host = servletRequest.getServerName(); host = servletRequest.getServerName();
port = servletRequest.getServerPort(); port = servletRequest.getServerPort();
if (containsForwardedHeaders(servletRequest)) {
UriComponents actualUrl = new UriComponentsBuilder()
.scheme(scheme).host(host).port(port)
.adaptFromForwardedHeaders(headers)
.build();
scheme = actualUrl.getScheme();
host = actualUrl.getHost();
port = actualUrl.getPort();
}
} }
else { else {
UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build(); URI uri = request.getURI();
scheme = actualUrl.getScheme(); scheme = uri.getScheme();
host = actualUrl.getHost(); host = uri.getHost();
port = actualUrl.getPort(); port = uri.getPort();
} }
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
@ -757,15 +735,6 @@ public abstract class WebUtils {
getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort())); getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort()));
} }
private static boolean containsForwardedHeaders(HttpServletRequest request) {
for (String headerName : FORWARDED_HEADER_NAMES) {
if (request.getHeader(headerName) != null) {
return true;
}
}
return false;
}
private static int getPort(@Nullable String scheme, int port) { private static int getPort(@Nullable String scheme, int port) {
if (port == -1) { if (port == -1) {
if ("http".equals(scheme) || "ws".equals(scheme)) { if ("http".equals(scheme) || "ws".equals(scheme)) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2015 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,15 +16,19 @@
package org.springframework.web.cors.reactive; package org.springframework.web.cors.reactive;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.web.test.server.MockServerWebExchange;
import org.springframework.web.filter.reactive.ForwardedHeaderFilter;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.*;
import static org.junit.Assert.assertTrue; import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.*;
import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.get;
import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.options;
/** /**
* Test case for reactive {@link CorsUtils}. * Test case for reactive {@link CorsUtils}.
@ -35,19 +39,19 @@ public class CorsUtilsTests {
@Test @Test
public void isCorsRequest() { public void isCorsRequest() {
MockServerHttpRequest request = get("/").header(HttpHeaders.ORIGIN, "http://domain.com").build(); ServerHttpRequest request = get("/").header(HttpHeaders.ORIGIN, "http://domain.com").build();
assertTrue(CorsUtils.isCorsRequest(request)); assertTrue(CorsUtils.isCorsRequest(request));
} }
@Test @Test
public void isNotCorsRequest() { public void isNotCorsRequest() {
MockServerHttpRequest request = get("/").build(); ServerHttpRequest request = get("/").build();
assertFalse(CorsUtils.isCorsRequest(request)); assertFalse(CorsUtils.isCorsRequest(request));
} }
@Test @Test
public void isPreFlightRequest() { public void isPreFlightRequest() {
MockServerHttpRequest request = options("/") ServerHttpRequest request = options("/")
.header(HttpHeaders.ORIGIN, "http://domain.com") .header(HttpHeaders.ORIGIN, "http://domain.com")
.header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET") .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET")
.build(); .build();
@ -56,7 +60,7 @@ public class CorsUtilsTests {
@Test @Test
public void isNotPreFlightRequest() { public void isNotPreFlightRequest() {
MockServerHttpRequest request = get("/").build(); ServerHttpRequest request = get("/").build();
assertFalse(CorsUtils.isPreFlightRequest(request)); assertFalse(CorsUtils.isPreFlightRequest(request));
request = options("/").header(HttpHeaders.ORIGIN, "http://domain.com").build(); request = options("/").header(HttpHeaders.ORIGIN, "http://domain.com").build();
@ -68,31 +72,35 @@ public class CorsUtilsTests {
@Test // SPR-16262 @Test // SPR-16262
public void isSameOriginWithXForwardedHeaders() { public void isSameOriginWithXForwardedHeaders() {
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", -1, "https", null, -1, "https://mydomain1.com")); String server = "mydomain1.com";
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", 123, "https", null, -1, "https://mydomain1.com")); testWithXForwardedHeaders(server, -1, "https", null, -1, "https://mydomain1.com");
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", -1, "https", "mydomain2.com", -1, "https://mydomain2.com")); testWithXForwardedHeaders(server, 123, "https", null, -1, "https://mydomain1.com");
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", 123, "https", "mydomain2.com", -1, "https://mydomain2.com")); testWithXForwardedHeaders(server, -1, "https", "mydomain2.com", -1, "https://mydomain2.com");
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", -1, "https", "mydomain2.com", 456, "https://mydomain2.com:456")); testWithXForwardedHeaders(server, 123, "https", "mydomain2.com", -1, "https://mydomain2.com");
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", 123, "https", "mydomain2.com", 456, "https://mydomain2.com:456")); testWithXForwardedHeaders(server, -1, "https", "mydomain2.com", 456, "https://mydomain2.com:456");
testWithXForwardedHeaders(server, 123, "https", "mydomain2.com", 456, "https://mydomain2.com:456");
} }
@Test // SPR-16262 @Test // SPR-16262
public void isSameOriginWithForwardedHeader() { public void isSameOriginWithForwardedHeader() {
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", -1, "proto=https", "https://mydomain1.com")); String server = "mydomain1.com";
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", 123, "proto=https", "https://mydomain1.com")); testWithForwardedHeader(server, -1, "proto=https", "https://mydomain1.com");
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", -1, "proto=https; host=mydomain2.com", "https://mydomain2.com")); testWithForwardedHeader(server, 123, "proto=https", "https://mydomain1.com");
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", 123, "proto=https; host=mydomain2.com", "https://mydomain2.com")); testWithForwardedHeader(server, -1, "proto=https; host=mydomain2.com", "https://mydomain2.com");
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", -1, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456")); testWithForwardedHeader(server, 123, "proto=https; host=mydomain2.com", "https://mydomain2.com");
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", 123, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456")); testWithForwardedHeader(server, -1, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456");
testWithForwardedHeader(server, 123, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456");
} }
private boolean checkSameOriginWithXForwardedHeaders(String serverName, int port, String forwardedProto, String forwardedHost, int forwardedPort, String originHeader) { private void testWithXForwardedHeaders(String serverName, int port,
String forwardedProto, String forwardedHost, int forwardedPort, String originHeader) {
String url = "http://" + serverName; String url = "http://" + serverName;
if (port != -1) { if (port != -1) {
url = url + ":" + port; url = url + ":" + port;
} }
MockServerHttpRequest.BaseBuilder<?> builder = get(url)
.header(HttpHeaders.ORIGIN, originHeader); MockServerHttpRequest.BaseBuilder<?> builder = get(url).header(HttpHeaders.ORIGIN, originHeader);
if (forwardedProto != null) { if (forwardedProto != null) {
builder.header("X-Forwarded-Proto", forwardedProto); builder.header("X-Forwarded-Proto", forwardedProto);
} }
@ -102,18 +110,36 @@ public class CorsUtilsTests {
if (forwardedPort != -1) { if (forwardedPort != -1) {
builder.header("X-Forwarded-Port", String.valueOf(forwardedPort)); builder.header("X-Forwarded-Port", String.valueOf(forwardedPort));
} }
return CorsUtils.isSameOrigin(builder.build());
ServerHttpRequest request = adaptFromForwardedHeaders(builder);
assertTrue(CorsUtils.isSameOrigin(request));
} }
private boolean checkSameOriginWithForwardedHeader(String serverName, int port, String forwardedHeader, String originHeader) { private void testWithForwardedHeader(String serverName, int port,
String forwardedHeader, String originHeader) {
String url = "http://" + serverName; String url = "http://" + serverName;
if (port != -1) { if (port != -1) {
url = url + ":" + port; url = url + ":" + port;
} }
MockServerHttpRequest.BaseBuilder<?> builder = get(url) MockServerHttpRequest.BaseBuilder<?> builder = get(url)
.header("Forwarded", forwardedHeader) .header("Forwarded", forwardedHeader)
.header(HttpHeaders.ORIGIN, originHeader); .header(HttpHeaders.ORIGIN, originHeader);
return CorsUtils.isSameOrigin(builder.build());
ServerHttpRequest request = adaptFromForwardedHeaders(builder);
assertTrue(CorsUtils.isSameOrigin(request));
}
// SPR-16668
private ServerHttpRequest adaptFromForwardedHeaders(MockServerHttpRequest.BaseBuilder<?> builder) {
AtomicReference<ServerHttpRequest> requestRef = new AtomicReference<>();
MockServerWebExchange exchange = MockServerWebExchange.from(builder);
new ForwardedHeaderFilter().filter(exchange, exchange2 -> {
requestRef.set(exchange2.getRequest());
return Mono.empty();
}).block();
return requestRef.get();
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2016 the original author or authors. * Copyright 2002-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -21,19 +21,20 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import org.junit.Test; import org.junit.Test;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.mock.web.test.MockFilterChain;
import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.filter.ForwardedHeaderFilter;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.*;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
/** /**
* @author Juergen Hoeller * @author Juergen Hoeller
@ -141,23 +142,25 @@ public class WebUtilsTests {
} }
@Test // SPR-16262 @Test // SPR-16262
public void isSameOriginWithXForwardedHeaders() { public void isSameOriginWithXForwardedHeaders() throws Exception {
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", -1, "https", null, -1, "https://mydomain1.com")); String server = "mydomain1.com";
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", 123, "https", null, -1, "https://mydomain1.com")); testWithXForwardedHeaders(server, -1, "https", null, -1, "https://mydomain1.com");
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", -1, "https", "mydomain2.com", -1, "https://mydomain2.com")); testWithXForwardedHeaders(server, 123, "https", null, -1, "https://mydomain1.com");
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", 123, "https", "mydomain2.com", -1, "https://mydomain2.com")); testWithXForwardedHeaders(server, -1, "https", "mydomain2.com", -1, "https://mydomain2.com");
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", -1, "https", "mydomain2.com", 456, "https://mydomain2.com:456")); testWithXForwardedHeaders(server, 123, "https", "mydomain2.com", -1, "https://mydomain2.com");
assertTrue(checkSameOriginWithXForwardedHeaders("mydomain1.com", 123, "https", "mydomain2.com", 456, "https://mydomain2.com:456")); testWithXForwardedHeaders(server, -1, "https", "mydomain2.com", 456, "https://mydomain2.com:456");
testWithXForwardedHeaders(server, 123, "https", "mydomain2.com", 456, "https://mydomain2.com:456");
} }
@Test // SPR-16262 @Test // SPR-16262
public void isSameOriginWithForwardedHeader() { public void isSameOriginWithForwardedHeader() throws Exception {
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", -1, "proto=https", "https://mydomain1.com")); String server = "mydomain1.com";
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", 123, "proto=https", "https://mydomain1.com")); testWithForwardedHeader(server, -1, "proto=https", "https://mydomain1.com");
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", -1, "proto=https; host=mydomain2.com", "https://mydomain2.com")); testWithForwardedHeader(server, 123, "proto=https", "https://mydomain1.com");
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", 123, "proto=https; host=mydomain2.com", "https://mydomain2.com")); testWithForwardedHeader(server, -1, "proto=https; host=mydomain2.com", "https://mydomain2.com");
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", -1, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456")); testWithForwardedHeader(server, 123, "proto=https; host=mydomain2.com", "https://mydomain2.com");
assertTrue(checkSameOriginWithForwardedHeader("mydomain1.com", 123, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456")); testWithForwardedHeader(server, -1, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456");
testWithForwardedHeader(server, 123, "proto=https; host=mydomain2.com:456", "https://mydomain2.com:456");
} }
@ -183,36 +186,53 @@ public class WebUtilsTests {
return WebUtils.isSameOrigin(request); return WebUtils.isSameOrigin(request);
} }
private boolean checkSameOriginWithXForwardedHeaders(String serverName, int port, String forwardedProto, String forwardedHost, int forwardedPort, String originHeader) { private void testWithXForwardedHeaders(String serverName, int port, String forwardedProto,
MockHttpServletRequest servletRequest = new MockHttpServletRequest(); String forwardedHost, int forwardedPort, String originHeader) throws Exception {
ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
servletRequest.setServerName(serverName); MockHttpServletRequest request = new MockHttpServletRequest();
request.setServerName(serverName);
if (port != -1) { if (port != -1) {
servletRequest.setServerPort(port); request.setServerPort(port);
} }
if (forwardedProto != null) { if (forwardedProto != null) {
servletRequest.addHeader("X-Forwarded-Proto", forwardedProto); request.addHeader("X-Forwarded-Proto", forwardedProto);
} }
if (forwardedHost != null) { if (forwardedHost != null) {
servletRequest.addHeader("X-Forwarded-Host", forwardedHost); request.addHeader("X-Forwarded-Host", forwardedHost);
} }
if (forwardedPort != -1) { if (forwardedPort != -1) {
servletRequest.addHeader("X-Forwarded-Port", String.valueOf(forwardedPort)); request.addHeader("X-Forwarded-Port", String.valueOf(forwardedPort));
} }
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); request.addHeader(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isSameOrigin(request);
HttpServletRequest requestToUse = adaptFromForwardedHeaders(request);
ServerHttpRequest httpRequest = new ServletServerHttpRequest(requestToUse);
assertTrue(WebUtils.isSameOrigin(httpRequest));
} }
private boolean checkSameOriginWithForwardedHeader(String serverName, int port, String forwardedHeader, String originHeader) { private void testWithForwardedHeader(String serverName, int port, String forwardedHeader,
MockHttpServletRequest servletRequest = new MockHttpServletRequest(); String originHeader) throws Exception {
ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
servletRequest.setServerName(serverName); MockHttpServletRequest request = new MockHttpServletRequest();
request.setServerName(serverName);
if (port != -1) { if (port != -1) {
servletRequest.setServerPort(port); request.setServerPort(port);
} }
servletRequest.addHeader("Forwarded", forwardedHeader); request.addHeader("Forwarded", forwardedHeader);
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader); request.addHeader(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isSameOrigin(request);
HttpServletRequest requestToUse = adaptFromForwardedHeaders(request);
ServerHttpRequest httpRequest = new ServletServerHttpRequest(requestToUse);
assertTrue(WebUtils.isSameOrigin(httpRequest));
}
// SPR-16668
private HttpServletRequest adaptFromForwardedHeaders(HttpServletRequest request) throws Exception {
MockFilterChain chain = new MockFilterChain();
new ForwardedHeaderFilter().doFilter(request, new MockHttpServletResponse(), chain);
return (HttpServletRequest) chain.getRequest();
} }
} }

View File

@ -36,7 +36,6 @@ import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpCookie; import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpRange; import org.springframework.http.HttpRange;
import org.springframework.http.HttpRequest;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.multipart.Part; import org.springframework.http.codec.multipart.Part;
@ -93,7 +92,7 @@ class DefaultServerRequest implements ServerRequest {
@Override @Override
public UriBuilder uriBuilder() { public UriBuilder uriBuilder() {
return UriComponentsBuilder.fromHttpRequest(new ServerRequestAdapter()); return UriComponentsBuilder.fromUri(uri());
} }
@Override @Override
@ -279,23 +278,4 @@ class DefaultServerRequest implements ServerRequest {
} }
} }
private final class ServerRequestAdapter implements HttpRequest {
@Override
public String getMethodValue() {
return methodName();
}
@Override
public URI getURI() {
return uri();
}
@Override
public HttpHeaders getHeaders() {
return request().getHeaders();
}
}
} }

View File

@ -84,10 +84,13 @@ public interface ServerRequest {
/** /**
* Return a {@code UriBuilderComponents} from the URI associated with this * Return a {@code UriBuilderComponents} from the URI associated with this
* {@code ServerRequest}, while also overlaying with values from the headers * {@code ServerRequest}.
* "Forwarded" (<a href="http://tools.ietf.org/html/rfc7239">RFC 7239</a>), *
* or "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" if * <p><strong>Note:</strong> as of 5.1 this method ignores
* "Forwarded" is not found. * {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the
* client-originated address. Consider using the {@code ForwardedHeaderFilter}
* to extract and use, or to discard such headers.
*
* @return a URI builder * @return a URI builder
*/ */
UriBuilder uriBuilder(); UriBuilder uriBuilder();

View File

@ -109,7 +109,7 @@ public class ServerWebExchangeArgumentResolver extends HandlerMethodArgumentReso
return timeZone != null ? timeZone.toZoneId() : ZoneId.systemDefault(); return timeZone != null ? timeZone.toZoneId() : ZoneId.systemDefault();
} }
else if (UriBuilder.class == paramType || UriComponentsBuilder.class == paramType) { else if (UriBuilder.class == paramType || UriComponentsBuilder.class == paramType) {
return UriComponentsBuilder.fromHttpRequest(exchange.getRequest()); return UriComponentsBuilder.fromUri(exchange.getRequest().getURI());
} }
else { else {
// should never happen... // should never happen...

View File

@ -38,7 +38,6 @@ import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRange; import org.springframework.http.HttpRange;
import org.springframework.http.HttpRequest;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.multipart.Part; import org.springframework.http.codec.multipart.Part;
@ -139,7 +138,7 @@ public class MockServerRequest implements ServerRequest {
@Override @Override
public UriBuilder uriBuilder() { public UriBuilder uriBuilder() {
return UriComponentsBuilder.fromHttpRequest(new ServerRequestAdapter()); return UriComponentsBuilder.fromUri(this.uri);
} }
@Override @Override
@ -569,23 +568,4 @@ public class MockServerRequest implements ServerRequest {
} }
private final class ServerRequestAdapter implements HttpRequest {
@Override
public String getMethodValue() {
return methodName();
}
@Override
public URI getURI() {
return MockServerRequest.this.uri;
}
@Override
public HttpHeaders getHeaders() {
return MockServerRequest.this.headers.headers;
}
}
} }

View File

@ -16,18 +16,14 @@
package org.springframework.web.servlet.support; package org.springframework.web.servlet.support;
import java.util.Enumeration;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import org.springframework.http.HttpRequest;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils; import org.springframework.web.util.UriUtils;
import org.springframework.web.util.UrlPathHelper; import org.springframework.web.util.UrlPathHelper;
@ -81,17 +77,14 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
* Prepare a builder from the host, port, scheme, and context path of the * Prepare a builder from the host, port, scheme, and context path of the
* given HttpServletRequest. * given HttpServletRequest.
* *
* <p><strong>Note:</strong> This method extracts values from "Forwarded" * <p><strong>Note:</strong> as of 5.1 this method ignores
* and "X-Forwarded-*" headers if found. See class-level docs. * {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the
* * client-originated address. Consider using the {@code ForwardedHeaderFilter}
* <p>As of 4.3.15, this method replaces the contextPath with the value * to extract and use, or to discard such headers.
* of "X-Forwarded-Prefix" rather than prepending, thus aligning with
* {@code ForwardedHeaderFiller}.
*/ */
public static ServletUriComponentsBuilder fromContextPath(HttpServletRequest request) { public static ServletUriComponentsBuilder fromContextPath(HttpServletRequest request) {
ServletUriComponentsBuilder builder = initFromRequest(request); ServletUriComponentsBuilder builder = initFromRequest(request);
String forwardedPrefix = getForwardedPrefix(request); builder.replacePath(request.getContextPath());
builder.replacePath(forwardedPrefix != null ? forwardedPrefix : request.getContextPath());
return builder; return builder;
} }
@ -103,12 +96,10 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
* {@code "/"} or {@code "*.do"}, the result will be the same as * {@code "/"} or {@code "*.do"}, the result will be the same as
* if calling {@link #fromContextPath(HttpServletRequest)}. * if calling {@link #fromContextPath(HttpServletRequest)}.
* *
* <p><strong>Note:</strong> This method extracts values from "Forwarded" * <p><strong>Note:</strong> as of 5.1 this method ignores
* and "X-Forwarded-*" headers if found. See class-level docs. * {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the
* * client-originated address. Consider using the {@code ForwardedHeaderFilter}
* <p>As of 4.3.15, this method replaces the contextPath with the value * to extract and use, or to discard such headers.
* of "X-Forwarded-Prefix" rather than prepending, thus aligning with
* {@code ForwardedHeaderFiller}.
*/ */
public static ServletUriComponentsBuilder fromServletMapping(HttpServletRequest request) { public static ServletUriComponentsBuilder fromServletMapping(HttpServletRequest request) {
ServletUriComponentsBuilder builder = fromContextPath(request); ServletUriComponentsBuilder builder = fromContextPath(request);
@ -122,16 +113,14 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
* Prepare a builder from the host, port, scheme, and path (but not the query) * Prepare a builder from the host, port, scheme, and path (but not the query)
* of the HttpServletRequest. * of the HttpServletRequest.
* *
* <p><strong>Note:</strong> This method extracts values from "Forwarded" * <p><strong>Note:</strong> as of 5.1 this method ignores
* and "X-Forwarded-*" headers if found. See class-level docs. * {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the
* * client-originated address. Consider using the {@code ForwardedHeaderFilter}
* <p>As of 4.3.15, this method replaces the contextPath with the value * to extract and use, or to discard such headers.
* of "X-Forwarded-Prefix" rather than prepending, thus aligning with
* {@code ForwardedHeaderFiller}.
*/ */
public static ServletUriComponentsBuilder fromRequestUri(HttpServletRequest request) { public static ServletUriComponentsBuilder fromRequestUri(HttpServletRequest request) {
ServletUriComponentsBuilder builder = initFromRequest(request); ServletUriComponentsBuilder builder = initFromRequest(request);
builder.initPath(getRequestUriWithForwardedPrefix(request)); builder.initPath(request.getRequestURI());
return builder; return builder;
} }
@ -139,16 +128,14 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
* Prepare a builder by copying the scheme, host, port, path, and * Prepare a builder by copying the scheme, host, port, path, and
* query string of an HttpServletRequest. * query string of an HttpServletRequest.
* *
* <p><strong>Note:</strong> This method extracts values from "Forwarded" * <p><strong>Note:</strong> as of 5.1 this method ignores
* and "X-Forwarded-*" headers if found. See class-level docs. * {@code "Forwarded"} and {@code "X-Forwarded-*"} headers that specify the
* * client-originated address. Consider using the {@code ForwardedHeaderFilter}
* <p>As of 4.3.15, this method replaces the contextPath with the value * to extract and use, or to discard such headers.
* of "X-Forwarded-Prefix" rather than prepending, thus aligning with
* {@code ForwardedHeaderFiller}.
*/ */
public static ServletUriComponentsBuilder fromRequest(HttpServletRequest request) { public static ServletUriComponentsBuilder fromRequest(HttpServletRequest request) {
ServletUriComponentsBuilder builder = initFromRequest(request); ServletUriComponentsBuilder builder = initFromRequest(request);
builder.initPath(getRequestUriWithForwardedPrefix(request)); builder.initPath(request.getRequestURI());
builder.query(request.getQueryString()); builder.query(request.getQueryString());
return builder; return builder;
} }
@ -157,11 +144,9 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
* Initialize a builder with a scheme, host,and port (but not path and query). * Initialize a builder with a scheme, host,and port (but not path and query).
*/ */
private static ServletUriComponentsBuilder initFromRequest(HttpServletRequest request) { private static ServletUriComponentsBuilder initFromRequest(HttpServletRequest request) {
HttpRequest httpRequest = new ServletServerHttpRequest(request); String scheme = request.getScheme();
UriComponents uriComponents = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); String host = request.getServerName();
String scheme = uriComponents.getScheme(); int port = request.getServerPort();
String host = uriComponents.getHost();
int port = uriComponents.getPort();
ServletUriComponentsBuilder builder = new ServletUriComponentsBuilder(); ServletUriComponentsBuilder builder = new ServletUriComponentsBuilder();
builder.scheme(scheme); builder.scheme(scheme);
@ -172,37 +157,6 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
return builder; return builder;
} }
@Nullable
private static String getForwardedPrefix(HttpServletRequest request) {
String prefix = null;
Enumeration<String> names = request.getHeaderNames();
while (names.hasMoreElements()) {
String name = names.nextElement();
if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) {
prefix = request.getHeader(name);
}
}
if (prefix != null) {
while (prefix.endsWith("/")) {
prefix = prefix.substring(0, prefix.length() - 1);
}
}
return prefix;
}
private static String getRequestUriWithForwardedPrefix(HttpServletRequest request) {
String path = request.getRequestURI();
String forwardedPrefix = getForwardedPrefix(request);
if (forwardedPrefix != null) {
String contextPath = request.getContextPath();
if (!StringUtils.isEmpty(contextPath) && !contextPath.equals("/") && path.startsWith(contextPath)) {
path = path.substring(contextPath.length());
}
path = forwardedPrefix + path;
}
return path;
}
// Alternative methods relying on RequestContextHolder to find the request // Alternative methods relying on RequestContextHolder to find the request

View File

@ -23,6 +23,7 @@ import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import javax.servlet.http.HttpServletRequest;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import org.joda.time.DateTime; import org.joda.time.DateTime;
@ -36,7 +37,9 @@ import org.springframework.format.annotation.DateTimeFormat;
import org.springframework.format.annotation.DateTimeFormat.ISO; import org.springframework.format.annotation.DateTimeFormat.ISO;
import org.springframework.http.HttpEntity; import org.springframework.http.HttpEntity;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.mock.web.test.MockFilterChain;
import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.mock.web.test.MockServletContext; import org.springframework.mock.web.test.MockServletContext;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
@ -49,6 +52,7 @@ import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.filter.ForwardedHeaderFilter;
import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.EnableWebMvc;
@ -136,29 +140,40 @@ public class MvcUriComponentsBuilderTests {
} }
@Test @Test
public void usesForwardedHostAsHostIfHeaderIsSet() { public void usesForwardedHostAsHostIfHeaderIsSet() throws Exception {
this.request.addHeader("X-Forwarded-Host", "somethingDifferent"); this.request.addHeader("X-Forwarded-Host", "somethingDifferent");
adaptRequestFromForwardedHeaders();
UriComponents uriComponents = fromController(PersonControllerImpl.class).build(); UriComponents uriComponents = fromController(PersonControllerImpl.class).build();
assertThat(uriComponents.toUriString(), startsWith("http://somethingDifferent")); assertThat(uriComponents.toUriString(), startsWith("http://somethingDifferent"));
} }
@Test @Test
public void usesForwardedHostAndPortFromHeader() { public void usesForwardedHostAndPortFromHeader() throws Exception {
request.addHeader("X-Forwarded-Host", "foobar:8088"); request.addHeader("X-Forwarded-Host", "foobar:8088");
adaptRequestFromForwardedHeaders();
UriComponents uriComponents = fromController(PersonControllerImpl.class).build(); UriComponents uriComponents = fromController(PersonControllerImpl.class).build();
assertThat(uriComponents.toUriString(), startsWith("http://foobar:8088")); assertThat(uriComponents.toUriString(), startsWith("http://foobar:8088"));
} }
@Test @Test
public void usesFirstHostOfXForwardedHost() { public void usesFirstHostOfXForwardedHost() throws Exception {
request.addHeader("X-Forwarded-Host", "barfoo:8888, localhost:8088"); this.request.addHeader("X-Forwarded-Host", "barfoo:8888, localhost:8088");
adaptRequestFromForwardedHeaders();
UriComponents uriComponents = fromController(PersonControllerImpl.class).build(); UriComponents uriComponents = fromController(PersonControllerImpl.class).build();
assertThat(uriComponents.toUriString(), startsWith("http://barfoo:8888")); assertThat(uriComponents.toUriString(), startsWith("http://barfoo:8888"));
} }
// SPR-16668
private void adaptRequestFromForwardedHeaders() throws Exception {
MockFilterChain chain = new MockFilterChain();
new ForwardedHeaderFilter().doFilter(this.request, new MockHttpServletResponse(), chain);
HttpServletRequest adaptedRequest = (HttpServletRequest) chain.getRequest();
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(adaptedRequest));
}
@Test @Test
public void fromMethodNamePathVariable() { public void fromMethodNamePathVariable() {
UriComponents uriComponents = fromMethodName(ControllerWithMethods.class, UriComponents uriComponents = fromMethodName(ControllerWithMethods.class,

View File

@ -16,12 +16,17 @@
package org.springframework.web.servlet.support; package org.springframework.web.servlet.support;
import javax.servlet.http.HttpServletRequest;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.mock.web.test.MockFilterChain;
import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.mock.web.test.MockHttpServletResponse;
import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.filter.ForwardedHeaderFilter;
import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponents;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -78,10 +83,10 @@ public class ServletUriComponentsBuilderTests {
assertEquals("https://localhost:9043/mvc-showcase", result); assertEquals("https://localhost:9043/mvc-showcase", result);
} }
// Most X-Forwarded-* tests in UriComponentsBuilderTests // Some X-Forwarded-* tests in addition to the ones in UriComponentsBuilderTests
@Test @Test
public void fromRequestWithForwardedHostAndPort() { public void fromRequestWithForwardedHostAndPort() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
request.setScheme("http"); request.setScheme("http");
request.setServerName("localhost"); request.setServerName("localhost");
@ -90,7 +95,10 @@ public class ServletUriComponentsBuilderTests {
request.addHeader("X-Forwarded-Proto", "https"); request.addHeader("X-Forwarded-Proto", "https");
request.addHeader("X-Forwarded-Host", "84.198.58.199"); request.addHeader("X-Forwarded-Host", "84.198.58.199");
request.addHeader("X-Forwarded-Port", "443"); request.addHeader("X-Forwarded-Port", "443");
UriComponents result = ServletUriComponentsBuilder.fromRequest(request).build();
HttpServletRequest requestToUse = adaptFromForwardedHeaders(request);
UriComponents result = ServletUriComponentsBuilder.fromRequest(requestToUse).build();
assertEquals("https://84.198.58.199/mvc-showcase", result.toString()); assertEquals("https://84.198.58.199/mvc-showcase", result.toString());
} }
@ -103,29 +111,38 @@ public class ServletUriComponentsBuilderTests {
} }
@Test // SPR-16650 @Test // SPR-16650
public void fromRequestWithForwardedPrefix() { public void fromRequestWithForwardedPrefix() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/prefix"); this.request.addHeader("X-Forwarded-Prefix", "/prefix");
this.request.setContextPath("/mvc-showcase"); this.request.setContextPath("/mvc-showcase");
this.request.setRequestURI("/mvc-showcase/bar"); this.request.setRequestURI("/mvc-showcase/bar");
UriComponents result = ServletUriComponentsBuilder.fromRequest(this.request).build();
HttpServletRequest requestToUse = adaptFromForwardedHeaders(this.request);
UriComponents result = ServletUriComponentsBuilder.fromRequest(requestToUse).build();
assertEquals("http://localhost/prefix/bar", result.toUriString()); assertEquals("http://localhost/prefix/bar", result.toUriString());
} }
@Test // SPR-16650 @Test // SPR-16650
public void fromRequestWithForwardedPrefixTrailingSlash() { public void fromRequestWithForwardedPrefixTrailingSlash() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/foo/"); this.request.addHeader("X-Forwarded-Prefix", "/foo/");
this.request.setContextPath("/spring-mvc-showcase"); this.request.setContextPath("/spring-mvc-showcase");
this.request.setRequestURI("/spring-mvc-showcase/bar"); this.request.setRequestURI("/spring-mvc-showcase/bar");
UriComponents result = ServletUriComponentsBuilder.fromRequest(this.request).build();
HttpServletRequest requestToUse = adaptFromForwardedHeaders(this.request);
UriComponents result = ServletUriComponentsBuilder.fromRequest(requestToUse).build();
assertEquals("http://localhost/foo/bar", result.toUriString()); assertEquals("http://localhost/foo/bar", result.toUriString());
} }
@Test // SPR-16650 @Test // SPR-16650
public void fromRequestWithForwardedPrefixRoot() { public void fromRequestWithForwardedPrefixRoot() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/"); this.request.addHeader("X-Forwarded-Prefix", "/");
this.request.setContextPath("/mvc-showcase"); this.request.setContextPath("/mvc-showcase");
this.request.setRequestURI("/mvc-showcase/bar"); this.request.setRequestURI("/mvc-showcase/bar");
UriComponents result = ServletUriComponentsBuilder.fromRequest(this.request).build();
HttpServletRequest requestToUse = adaptFromForwardedHeaders(this.request);
UriComponents result = ServletUriComponentsBuilder.fromRequest(requestToUse).build();
assertEquals("http://localhost/bar", result.toUriString()); assertEquals("http://localhost/bar", result.toUriString());
} }
@ -138,11 +155,14 @@ public class ServletUriComponentsBuilderTests {
} }
@Test // SPR-16650 @Test // SPR-16650
public void fromContextPathWithForwardedPrefix() { public void fromContextPathWithForwardedPrefix() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/prefix"); this.request.addHeader("X-Forwarded-Prefix", "/prefix");
this.request.setContextPath("/mvc-showcase"); this.request.setContextPath("/mvc-showcase");
this.request.setRequestURI("/mvc-showcase/simple"); this.request.setRequestURI("/mvc-showcase/simple");
String result = ServletUriComponentsBuilder.fromContextPath(this.request).build().toUriString();
HttpServletRequest requestToUse = adaptFromForwardedHeaders(this.request);
String result = ServletUriComponentsBuilder.fromContextPath(requestToUse).build().toUriString();
assertEquals("http://localhost/prefix", result); assertEquals("http://localhost/prefix", result);
} }
@ -156,12 +176,15 @@ public class ServletUriComponentsBuilderTests {
} }
@Test // SPR-16650 @Test // SPR-16650
public void fromServletMappingWithForwardedPrefix() { public void fromServletMappingWithForwardedPrefix() throws Exception {
this.request.addHeader("X-Forwarded-Prefix", "/prefix"); this.request.addHeader("X-Forwarded-Prefix", "/prefix");
this.request.setContextPath("/mvc-showcase"); this.request.setContextPath("/mvc-showcase");
this.request.setServletPath("/app"); this.request.setServletPath("/app");
this.request.setRequestURI("/mvc-showcase/app/simple"); this.request.setRequestURI("/mvc-showcase/app/simple");
String result = ServletUriComponentsBuilder.fromServletMapping(this.request).build().toUriString();
HttpServletRequest requestToUse = adaptFromForwardedHeaders(this.request);
String result = ServletUriComponentsBuilder.fromServletMapping(requestToUse).build().toUriString();
assertEquals("http://localhost/prefix/app", result); assertEquals("http://localhost/prefix/app", result);
} }
@ -194,4 +217,12 @@ public class ServletUriComponentsBuilderTests {
ServletUriComponentsBuilder builder = ServletUriComponentsBuilder.fromRequestUri(this.request); ServletUriComponentsBuilder builder = ServletUriComponentsBuilder.fromRequestUri(this.request);
assertNull(builder.removePathExtension()); assertNull(builder.removePathExtension());
} }
// SPR-16668
private HttpServletRequest adaptFromForwardedHeaders(HttpServletRequest request) throws Exception {
MockFilterChain chain = new MockFilterChain();
new ForwardedHeaderFilter().doFilter(request, new MockHttpServletResponse(), chain);
return (HttpServletRequest) chain.getRequest();
}
} }

View File

@ -1417,10 +1417,8 @@ etc, and is equivalent to `required=false`.
See <<webflux-ann-sessionattributes>> for more details. See <<webflux-ann-sessionattributes>> for more details.
| `UriComponentsBuilder` | `UriComponentsBuilder`
| For preparing a URL relative to the current request's host, port, scheme, context path, and | For preparing a URL relative to the current request's host, port, scheme, and path.
the literal part of the servlet mapping also taking into account `Forwarded` and See <<webflux-uri-building>>.
`X-Forwarded-*` headers.
// TODO: See <<webflux-uri-building>>.
| `@SessionAttribute` | `@SessionAttribute`
| For access to any session attribute; in contrast to model attributes stored in the session | For access to any session attribute; in contrast to model attributes stored in the session
@ -2499,7 +2497,7 @@ Javadoc for more details.
[[mvc-uri-building]] [[webflux-uri-building]]
== URI Links == URI Links
[.small]#<<web.adoc#mvc-uri-building,Same in Spring MVC>># [.small]#<<web.adoc#mvc-uri-building,Same in Spring MVC>>#

View File

@ -1689,8 +1689,7 @@ etc, and is equivalent to `required=false`.
| `UriComponentsBuilder` | `UriComponentsBuilder`
| For preparing a URL relative to the current request's host, port, scheme, context path, and | For preparing a URL relative to the current request's host, port, scheme, context path, and
the literal part of the servlet mapping also taking into account `Forwarded` and the literal part of the servlet mapping. See <<mvc-uri-building>>.
`X-Forwarded-*` headers. See <<mvc-uri-building>>.
| `@SessionAttribute` | `@SessionAttribute`
| For access to any session attribute; in contrast to model attributes stored in the session | For access to any session attribute; in contrast to model attributes stored in the session
@ -3098,7 +3097,7 @@ Javadoc for more details.
[[mvc-uri-building]] [[mvc-uri-building]]
== URI Links == URI Links
[.small]#<<web-reactive.adoc#mvc-uri-building,Same in Spring WebFlux>># [.small]#<<web-reactive.adoc#webflux-uri-building,Same in Spring WebFlux>>#
This section describes various options available in the Spring Framework to prepare URIs. This section describes various options available in the Spring Framework to prepare URIs.
@ -3148,14 +3147,12 @@ You can create URIs relative to a Servlet (e.g. `/main/{asterisk}`):
.path("/accounts").build() .path("/accounts").build()
---- ----
[CAUTION] [NOTE]
==== ====
`ServletUriComponentsBuilder` detects and uses information from the "Forwarded", As of 5.1 `ServletUriComponentsBuilder` ignores information from the "Forwarded",
"X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" headers, so the resulting "X-Forwarded-*" headers, that specify the client-originated address. Consider using the
links reflect the original request. You need to ensure that your application is behind <<filters-forwarded-headers,ForwardedHeaderFilter>> to extract and use, or to discard
a trusted proxy which filters out such headers coming from outside. Also consider using such headers.
the <<filters-forwarded-headers,ForwardedHeaderFilter>> which processes such headers once
per request, and also provides an option to remove and ignore such headers.
==== ====
@ -3243,14 +3240,12 @@ with a base URL and then use the instance-based "withXxx" methods. For example:
URI uri = uriComponents.encode().toUri(); URI uri = uriComponents.encode().toUri();
---- ----
[CAUTION] [NOTE]
==== ====
`MvcUriComponentsBuilder` detects and uses information from the "Forwarded", As of 5.1 `MvcUriComponentsBuilder` ignores information from the "Forwarded",
"X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" headers, so the resulting "X-Forwarded-*" headers, that specify the client-originated address. Consider using the
links reflect the original request. You need to ensure that your application is behind <<filters-forwarded-headers,ForwardedHeaderFilter>> to extract and use, or to discard
a trusted proxy which filters out such headers coming from outside. Also consider using such headers.
the <<filters-forwarded-headers,ForwardedHeaderFilter>> which processes such headers once
per request, and also provides an option to remove and ignore such headers.
==== ====