Polishing
Optimize same origin check when the request is an instance of ServletServerHttpRequest and when there is no forwarded headers. This commit also optimizes the getPort methods and ForwardedHeaderFilter forwarded headers checks. Issue: SPR-16262
This commit is contained in:
parent
c326e44488
commit
9c7de232b8
|
|
@ -64,20 +64,19 @@ public abstract class CorsUtils {
|
||||||
UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
|
UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
|
||||||
UriComponents actualUrl = urlBuilder.build();
|
UriComponents actualUrl = urlBuilder.build();
|
||||||
String actualHost = actualUrl.getHost();
|
String actualHost = actualUrl.getHost();
|
||||||
int actualPort = getPort(actualUrl);
|
int actualPort = getPort(actualUrl.getScheme(), actualUrl.getPort());
|
||||||
Assert.notNull(actualHost, "Actual request host must not be null");
|
Assert.notNull(actualHost, "Actual request host must not be null");
|
||||||
Assert.isTrue(actualPort != -1, "Actual request port must not be undefined");
|
Assert.isTrue(actualPort != -1, "Actual request port must not be undefined");
|
||||||
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
|
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
|
||||||
return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl));
|
return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl.getScheme(), originUrl.getPort()));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static int getPort(UriComponents uri) {
|
private static int getPort(String scheme, int port) {
|
||||||
int port = uri.getPort();
|
|
||||||
if (port == -1) {
|
if (port == -1) {
|
||||||
if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) {
|
if ("http".equals(scheme) || "ws".equals(scheme)) {
|
||||||
port = 80;
|
port = 80;
|
||||||
}
|
}
|
||||||
else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) {
|
else if ("https".equals(scheme) || "wss".equals(scheme)) {
|
||||||
port = 443;
|
port = 443;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -118,10 +118,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
|
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
|
||||||
Enumeration<String> names = request.getHeaderNames();
|
for (String headerName : FORWARDED_HEADER_NAMES) {
|
||||||
while (names.hasMoreElements()) {
|
if (request.getHeader(headerName) != null) {
|
||||||
String name = names.nextElement();
|
|
||||||
if (FORWARDED_HEADER_NAMES.contains(name)) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,7 @@
|
||||||
package org.springframework.web.filter.reactive;
|
package org.springframework.web.filter.reactive;
|
||||||
|
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.util.Collections;
|
import java.util.LinkedHashSet;
|
||||||
import java.util.Locale;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
|
@ -26,7 +25,6 @@ import reactor.core.publisher.Mono;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.server.reactive.ServerHttpRequest;
|
import org.springframework.http.server.reactive.ServerHttpRequest;
|
||||||
import org.springframework.lang.Nullable;
|
import org.springframework.lang.Nullable;
|
||||||
import org.springframework.util.LinkedCaseInsensitiveMap;
|
|
||||||
import org.springframework.web.server.ServerWebExchange;
|
import org.springframework.web.server.ServerWebExchange;
|
||||||
import org.springframework.web.server.WebFilter;
|
import org.springframework.web.server.WebFilter;
|
||||||
import org.springframework.web.server.WebFilterChain;
|
import org.springframework.web.server.WebFilterChain;
|
||||||
|
|
@ -47,8 +45,7 @@ import org.springframework.web.util.UriComponentsBuilder;
|
||||||
*/
|
*/
|
||||||
public class ForwardedHeaderFilter implements WebFilter {
|
public class ForwardedHeaderFilter implements WebFilter {
|
||||||
|
|
||||||
private static final Set<String> FORWARDED_HEADER_NAMES =
|
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
|
||||||
Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(5, Locale.ENGLISH));
|
|
||||||
|
|
||||||
static {
|
static {
|
||||||
FORWARDED_HEADER_NAMES.add("Forwarded");
|
FORWARDED_HEADER_NAMES.add("Forwarded");
|
||||||
|
|
@ -104,8 +101,13 @@ public class ForwardedHeaderFilter implements WebFilter {
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean shouldNotFilter(ServerHttpRequest request) {
|
private boolean shouldNotFilter(ServerHttpRequest request) {
|
||||||
return request.getHeaders().keySet().stream()
|
HttpHeaders headers = request.getHeaders();
|
||||||
.noneMatch(FORWARDED_HEADER_NAMES::contains);
|
for (String headerName : FORWARDED_HEADER_NAMES) {
|
||||||
|
if (headers.containsKey(headerName)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
@Nullable
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,9 @@ import java.io.File;
|
||||||
import java.io.FileNotFoundException;
|
import java.io.FileNotFoundException;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Enumeration;
|
import java.util.Enumeration;
|
||||||
|
import java.util.LinkedHashSet;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.StringTokenizer;
|
import java.util.StringTokenizer;
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
import javax.servlet.ServletContext;
|
import javax.servlet.ServletContext;
|
||||||
|
|
@ -33,6 +35,7 @@ import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
import javax.servlet.http.HttpSession;
|
import javax.servlet.http.HttpSession;
|
||||||
|
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpRequest;
|
import org.springframework.http.HttpRequest;
|
||||||
import org.springframework.http.server.ServletServerHttpRequest;
|
import org.springframework.http.server.ServletServerHttpRequest;
|
||||||
import org.springframework.lang.Nullable;
|
import org.springframework.lang.Nullable;
|
||||||
|
|
@ -135,6 +138,16 @@ public abstract class WebUtils {
|
||||||
/** Key for the mutex session attribute */
|
/** Key for the mutex session attribute */
|
||||||
public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX";
|
public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX";
|
||||||
|
|
||||||
|
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
|
||||||
|
|
||||||
|
static {
|
||||||
|
FORWARDED_HEADER_NAMES.add("Forwarded");
|
||||||
|
FORWARDED_HEADER_NAMES.add("X-Forwarded-Host");
|
||||||
|
FORWARDED_HEADER_NAMES.add("X-Forwarded-Port");
|
||||||
|
FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto");
|
||||||
|
FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set a system property to the web application root directory.
|
* Set a system property to the web application root directory.
|
||||||
|
|
@ -693,36 +706,60 @@ public abstract class WebUtils {
|
||||||
* @since 4.2
|
* @since 4.2
|
||||||
*/
|
*/
|
||||||
public static boolean isSameOrigin(HttpRequest request) {
|
public static boolean isSameOrigin(HttpRequest request) {
|
||||||
String origin = request.getHeaders().getOrigin();
|
HttpHeaders headers = request.getHeaders();
|
||||||
|
String origin = headers.getOrigin();
|
||||||
if (origin == null) {
|
if (origin == null) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
UriComponentsBuilder urlBuilder;
|
String scheme;
|
||||||
|
String host;
|
||||||
|
int port;
|
||||||
if (request instanceof ServletServerHttpRequest) {
|
if (request instanceof ServletServerHttpRequest) {
|
||||||
// Build more efficiently if we can: we only need scheme, host, port for origin comparison
|
// Build more efficiently if we can: we only need scheme, host, port for origin comparison
|
||||||
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
|
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
|
||||||
urlBuilder = new UriComponentsBuilder().
|
scheme = servletRequest.getScheme();
|
||||||
scheme(servletRequest.getScheme()).
|
host = servletRequest.getServerName();
|
||||||
host(servletRequest.getServerName()).
|
port = servletRequest.getServerPort();
|
||||||
port(servletRequest.getServerPort()).
|
|
||||||
adaptFromForwardedHeaders(request.getHeaders());
|
if(containsForwardedHeaders(servletRequest)) {
|
||||||
|
UriComponents actualUrl = new UriComponentsBuilder()
|
||||||
|
.scheme(scheme)
|
||||||
|
.host(host)
|
||||||
|
.port(port)
|
||||||
|
.adaptFromForwardedHeaders(headers)
|
||||||
|
.build();
|
||||||
|
scheme = actualUrl.getScheme();
|
||||||
|
host = actualUrl.getHost();
|
||||||
|
port = actualUrl.getPort();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
|
UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build();
|
||||||
|
scheme = actualUrl.getScheme();
|
||||||
|
host = actualUrl.getHost();
|
||||||
|
port = actualUrl.getPort();
|
||||||
}
|
}
|
||||||
UriComponents actualUrl = urlBuilder.build();
|
|
||||||
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
|
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
|
||||||
return (ObjectUtils.nullSafeEquals(actualUrl.getHost(), originUrl.getHost()) &&
|
return (ObjectUtils.nullSafeEquals(host, originUrl.getHost()) &&
|
||||||
getPort(actualUrl) == getPort(originUrl));
|
getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort()));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static int getPort(UriComponents uri) {
|
private static boolean containsForwardedHeaders(HttpServletRequest request) {
|
||||||
int port = uri.getPort();
|
for (String headerName : FORWARDED_HEADER_NAMES) {
|
||||||
|
if (request.getHeader(headerName) != null) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int getPort(String scheme, int port) {
|
||||||
if (port == -1) {
|
if (port == -1) {
|
||||||
if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) {
|
if ("http".equals(scheme) || "ws".equals(scheme)) {
|
||||||
port = 80;
|
port = 80;
|
||||||
}
|
}
|
||||||
else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) {
|
else if ("https".equals(scheme) || "wss".equals(scheme)) {
|
||||||
port = 443;
|
port = 443;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -168,7 +168,7 @@ public class WebUtilsTests {
|
||||||
if (port != -1) {
|
if (port != -1) {
|
||||||
servletRequest.setServerPort(port);
|
servletRequest.setServerPort(port);
|
||||||
}
|
}
|
||||||
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
|
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
|
||||||
return WebUtils.isValidOrigin(request, allowed);
|
return WebUtils.isValidOrigin(request, allowed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -179,7 +179,7 @@ public class WebUtilsTests {
|
||||||
if (port != -1) {
|
if (port != -1) {
|
||||||
servletRequest.setServerPort(port);
|
servletRequest.setServerPort(port);
|
||||||
}
|
}
|
||||||
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
|
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
|
||||||
return WebUtils.isSameOrigin(request);
|
return WebUtils.isSameOrigin(request);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -191,15 +191,15 @@ public class WebUtilsTests {
|
||||||
servletRequest.setServerPort(port);
|
servletRequest.setServerPort(port);
|
||||||
}
|
}
|
||||||
if (forwardedProto != null) {
|
if (forwardedProto != null) {
|
||||||
request.getHeaders().set("X-Forwarded-Proto", forwardedProto);
|
servletRequest.addHeader("X-Forwarded-Proto", forwardedProto);
|
||||||
}
|
}
|
||||||
if (forwardedHost != null) {
|
if (forwardedHost != null) {
|
||||||
request.getHeaders().set("X-Forwarded-Host", forwardedHost);
|
servletRequest.addHeader("X-Forwarded-Host", forwardedHost);
|
||||||
}
|
}
|
||||||
if (forwardedPort != -1) {
|
if (forwardedPort != -1) {
|
||||||
request.getHeaders().set("X-Forwarded-Port", String.valueOf(forwardedPort));
|
servletRequest.addHeader("X-Forwarded-Port", String.valueOf(forwardedPort));
|
||||||
}
|
}
|
||||||
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
|
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
|
||||||
return WebUtils.isSameOrigin(request);
|
return WebUtils.isSameOrigin(request);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -210,8 +210,8 @@ public class WebUtilsTests {
|
||||||
if (port != -1) {
|
if (port != -1) {
|
||||||
servletRequest.setServerPort(port);
|
servletRequest.setServerPort(port);
|
||||||
}
|
}
|
||||||
request.getHeaders().set("Forwarded", forwardedHeader);
|
servletRequest.addHeader("Forwarded", forwardedHeader);
|
||||||
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
|
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
|
||||||
return WebUtils.isSameOrigin(request);
|
return WebUtils.isSameOrigin(request);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue