Polishing

Optimize same origin check when the request is an instance of
ServletServerHttpRequest and when there is no forwarded headers.

This commit also optimizes the getPort methods and ForwardedHeaderFilter
forwarded headers checks.

Issue: SPR-16262
This commit is contained in:
sdeleuze 2018-01-09 12:40:34 +01:00
parent c326e44488
commit 9c7de232b8
5 changed files with 76 additions and 40 deletions

View File

@ -64,20 +64,19 @@ public abstract class CorsUtils {
UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request); UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
UriComponents actualUrl = urlBuilder.build(); UriComponents actualUrl = urlBuilder.build();
String actualHost = actualUrl.getHost(); String actualHost = actualUrl.getHost();
int actualPort = getPort(actualUrl); int actualPort = getPort(actualUrl.getScheme(), actualUrl.getPort());
Assert.notNull(actualHost, "Actual request host must not be null"); Assert.notNull(actualHost, "Actual request host must not be null");
Assert.isTrue(actualPort != -1, "Actual request port must not be undefined"); Assert.isTrue(actualPort != -1, "Actual request port must not be undefined");
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl)); return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl.getScheme(), originUrl.getPort()));
} }
private static int getPort(UriComponents uri) { private static int getPort(String scheme, int port) {
int port = uri.getPort();
if (port == -1) { if (port == -1) {
if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) { if ("http".equals(scheme) || "ws".equals(scheme)) {
port = 80; port = 80;
} }
else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) { else if ("https".equals(scheme) || "wss".equals(scheme)) {
port = 443; port = 443;
} }
} }

View File

@ -118,10 +118,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
@Override @Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
Enumeration<String> names = request.getHeaderNames(); for (String headerName : FORWARDED_HEADER_NAMES) {
while (names.hasMoreElements()) { if (request.getHeader(headerName) != null) {
String name = names.nextElement();
if (FORWARDED_HEADER_NAMES.contains(name)) {
return false; return false;
} }
} }

View File

@ -17,8 +17,7 @@
package org.springframework.web.filter.reactive; package org.springframework.web.filter.reactive;
import java.net.URI; import java.net.URI;
import java.util.Collections; import java.util.LinkedHashSet;
import java.util.Locale;
import java.util.Set; import java.util.Set;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -26,7 +25,6 @@ import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
@ -47,8 +45,7 @@ import org.springframework.web.util.UriComponentsBuilder;
*/ */
public class ForwardedHeaderFilter implements WebFilter { public class ForwardedHeaderFilter implements WebFilter {
private static final Set<String> FORWARDED_HEADER_NAMES = private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(5, Locale.ENGLISH));
static { static {
FORWARDED_HEADER_NAMES.add("Forwarded"); FORWARDED_HEADER_NAMES.add("Forwarded");
@ -104,8 +101,13 @@ public class ForwardedHeaderFilter implements WebFilter {
} }
private boolean shouldNotFilter(ServerHttpRequest request) { private boolean shouldNotFilter(ServerHttpRequest request) {
return request.getHeaders().keySet().stream() HttpHeaders headers = request.getHeaders();
.noneMatch(FORWARDED_HEADER_NAMES::contains); for (String headerName : FORWARDED_HEADER_NAMES) {
if (headers.containsKey(headerName)) {
return false;
}
}
return true;
} }
@Nullable @Nullable

View File

@ -20,7 +20,9 @@ import java.io.File;
import java.io.FileNotFoundException; import java.io.FileNotFoundException;
import java.util.Collection; import java.util.Collection;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.LinkedHashSet;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer; import java.util.StringTokenizer;
import java.util.TreeMap; import java.util.TreeMap;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
@ -33,6 +35,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession; import javax.servlet.http.HttpSession;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpRequest; import org.springframework.http.HttpRequest;
import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
@ -135,6 +138,16 @@ public abstract class WebUtils {
/** Key for the mutex session attribute */ /** Key for the mutex session attribute */
public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX"; public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX";
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
static {
FORWARDED_HEADER_NAMES.add("Forwarded");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Host");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Port");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix");
}
/** /**
* Set a system property to the web application root directory. * Set a system property to the web application root directory.
@ -693,36 +706,60 @@ public abstract class WebUtils {
* @since 4.2 * @since 4.2
*/ */
public static boolean isSameOrigin(HttpRequest request) { public static boolean isSameOrigin(HttpRequest request) {
String origin = request.getHeaders().getOrigin(); HttpHeaders headers = request.getHeaders();
String origin = headers.getOrigin();
if (origin == null) { if (origin == null) {
return true; return true;
} }
UriComponentsBuilder urlBuilder; String scheme;
String host;
int port;
if (request instanceof ServletServerHttpRequest) { if (request instanceof ServletServerHttpRequest) {
// Build more efficiently if we can: we only need scheme, host, port for origin comparison // Build more efficiently if we can: we only need scheme, host, port for origin comparison
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest(); HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
urlBuilder = new UriComponentsBuilder(). scheme = servletRequest.getScheme();
scheme(servletRequest.getScheme()). host = servletRequest.getServerName();
host(servletRequest.getServerName()). port = servletRequest.getServerPort();
port(servletRequest.getServerPort()).
adaptFromForwardedHeaders(request.getHeaders()); if(containsForwardedHeaders(servletRequest)) {
UriComponents actualUrl = new UriComponentsBuilder()
.scheme(scheme)
.host(host)
.port(port)
.adaptFromForwardedHeaders(headers)
.build();
scheme = actualUrl.getScheme();
host = actualUrl.getHost();
port = actualUrl.getPort();
}
} }
else { else {
urlBuilder = UriComponentsBuilder.fromHttpRequest(request); UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build();
scheme = actualUrl.getScheme();
host = actualUrl.getHost();
port = actualUrl.getPort();
} }
UriComponents actualUrl = urlBuilder.build();
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build(); UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
return (ObjectUtils.nullSafeEquals(actualUrl.getHost(), originUrl.getHost()) && return (ObjectUtils.nullSafeEquals(host, originUrl.getHost()) &&
getPort(actualUrl) == getPort(originUrl)); getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort()));
} }
private static int getPort(UriComponents uri) { private static boolean containsForwardedHeaders(HttpServletRequest request) {
int port = uri.getPort(); for (String headerName : FORWARDED_HEADER_NAMES) {
if (request.getHeader(headerName) != null) {
return true;
}
}
return false;
}
private static int getPort(String scheme, int port) {
if (port == -1) { if (port == -1) {
if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) { if ("http".equals(scheme) || "ws".equals(scheme)) {
port = 80; port = 80;
} }
else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) { else if ("https".equals(scheme) || "wss".equals(scheme)) {
port = 443; port = 443;
} }
} }

View File

@ -168,7 +168,7 @@ public class WebUtilsTests {
if (port != -1) { if (port != -1) {
servletRequest.setServerPort(port); servletRequest.setServerPort(port);
} }
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isValidOrigin(request, allowed); return WebUtils.isValidOrigin(request, allowed);
} }
@ -179,7 +179,7 @@ public class WebUtilsTests {
if (port != -1) { if (port != -1) {
servletRequest.setServerPort(port); servletRequest.setServerPort(port);
} }
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isSameOrigin(request); return WebUtils.isSameOrigin(request);
} }
@ -191,15 +191,15 @@ public class WebUtilsTests {
servletRequest.setServerPort(port); servletRequest.setServerPort(port);
} }
if (forwardedProto != null) { if (forwardedProto != null) {
request.getHeaders().set("X-Forwarded-Proto", forwardedProto); servletRequest.addHeader("X-Forwarded-Proto", forwardedProto);
} }
if (forwardedHost != null) { if (forwardedHost != null) {
request.getHeaders().set("X-Forwarded-Host", forwardedHost); servletRequest.addHeader("X-Forwarded-Host", forwardedHost);
} }
if (forwardedPort != -1) { if (forwardedPort != -1) {
request.getHeaders().set("X-Forwarded-Port", String.valueOf(forwardedPort)); servletRequest.addHeader("X-Forwarded-Port", String.valueOf(forwardedPort));
} }
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isSameOrigin(request); return WebUtils.isSameOrigin(request);
} }
@ -210,8 +210,8 @@ public class WebUtilsTests {
if (port != -1) { if (port != -1) {
servletRequest.setServerPort(port); servletRequest.setServerPort(port);
} }
request.getHeaders().set("Forwarded", forwardedHeader); servletRequest.addHeader("Forwarded", forwardedHeader);
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader); servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
return WebUtils.isSameOrigin(request); return WebUtils.isSameOrigin(request);
} }