Polishing
Optimize same origin check when the request is an instance of ServletServerHttpRequest and when there is no forwarded headers. This commit also optimizes the getPort methods and ForwardedHeaderFilter forwarded headers checks. Issue: SPR-16262
This commit is contained in:
parent
c326e44488
commit
9c7de232b8
|
|
@ -64,20 +64,19 @@ public abstract class CorsUtils {
|
|||
UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
|
||||
UriComponents actualUrl = urlBuilder.build();
|
||||
String actualHost = actualUrl.getHost();
|
||||
int actualPort = getPort(actualUrl);
|
||||
int actualPort = getPort(actualUrl.getScheme(), actualUrl.getPort());
|
||||
Assert.notNull(actualHost, "Actual request host must not be null");
|
||||
Assert.isTrue(actualPort != -1, "Actual request port must not be undefined");
|
||||
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
|
||||
return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl));
|
||||
return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl.getScheme(), originUrl.getPort()));
|
||||
}
|
||||
|
||||
private static int getPort(UriComponents uri) {
|
||||
int port = uri.getPort();
|
||||
private static int getPort(String scheme, int port) {
|
||||
if (port == -1) {
|
||||
if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) {
|
||||
if ("http".equals(scheme) || "ws".equals(scheme)) {
|
||||
port = 80;
|
||||
}
|
||||
else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) {
|
||||
else if ("https".equals(scheme) || "wss".equals(scheme)) {
|
||||
port = 443;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -118,10 +118,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
|
||||
@Override
|
||||
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
|
||||
Enumeration<String> names = request.getHeaderNames();
|
||||
while (names.hasMoreElements()) {
|
||||
String name = names.nextElement();
|
||||
if (FORWARDED_HEADER_NAMES.contains(name)) {
|
||||
for (String headerName : FORWARDED_HEADER_NAMES) {
|
||||
if (request.getHeader(headerName) != null) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,8 +17,7 @@
|
|||
package org.springframework.web.filter.reactive;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.Collections;
|
||||
import java.util.Locale;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import reactor.core.publisher.Mono;
|
||||
|
|
@ -26,7 +25,6 @@ import reactor.core.publisher.Mono;
|
|||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.server.reactive.ServerHttpRequest;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.LinkedCaseInsensitiveMap;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
import org.springframework.web.server.WebFilter;
|
||||
import org.springframework.web.server.WebFilterChain;
|
||||
|
|
@ -47,8 +45,7 @@ import org.springframework.web.util.UriComponentsBuilder;
|
|||
*/
|
||||
public class ForwardedHeaderFilter implements WebFilter {
|
||||
|
||||
private static final Set<String> FORWARDED_HEADER_NAMES =
|
||||
Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(5, Locale.ENGLISH));
|
||||
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
|
||||
|
||||
static {
|
||||
FORWARDED_HEADER_NAMES.add("Forwarded");
|
||||
|
|
@ -104,8 +101,13 @@ public class ForwardedHeaderFilter implements WebFilter {
|
|||
}
|
||||
|
||||
private boolean shouldNotFilter(ServerHttpRequest request) {
|
||||
return request.getHeaders().keySet().stream()
|
||||
.noneMatch(FORWARDED_HEADER_NAMES::contains);
|
||||
HttpHeaders headers = request.getHeaders();
|
||||
for (String headerName : FORWARDED_HEADER_NAMES) {
|
||||
if (headers.containsKey(headerName)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ import java.io.File;
|
|||
import java.io.FileNotFoundException;
|
||||
import java.util.Collection;
|
||||
import java.util.Enumeration;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.StringTokenizer;
|
||||
import java.util.TreeMap;
|
||||
import javax.servlet.ServletContext;
|
||||
|
|
@ -33,6 +35,7 @@ import javax.servlet.http.HttpServletRequest;
|
|||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.servlet.http.HttpSession;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpRequest;
|
||||
import org.springframework.http.server.ServletServerHttpRequest;
|
||||
import org.springframework.lang.Nullable;
|
||||
|
|
@ -135,6 +138,16 @@ public abstract class WebUtils {
|
|||
/** Key for the mutex session attribute */
|
||||
public static final String SESSION_MUTEX_ATTRIBUTE = WebUtils.class.getName() + ".MUTEX";
|
||||
|
||||
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
|
||||
|
||||
static {
|
||||
FORWARDED_HEADER_NAMES.add("Forwarded");
|
||||
FORWARDED_HEADER_NAMES.add("X-Forwarded-Host");
|
||||
FORWARDED_HEADER_NAMES.add("X-Forwarded-Port");
|
||||
FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto");
|
||||
FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix");
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Set a system property to the web application root directory.
|
||||
|
|
@ -693,36 +706,60 @@ public abstract class WebUtils {
|
|||
* @since 4.2
|
||||
*/
|
||||
public static boolean isSameOrigin(HttpRequest request) {
|
||||
String origin = request.getHeaders().getOrigin();
|
||||
HttpHeaders headers = request.getHeaders();
|
||||
String origin = headers.getOrigin();
|
||||
if (origin == null) {
|
||||
return true;
|
||||
}
|
||||
UriComponentsBuilder urlBuilder;
|
||||
String scheme;
|
||||
String host;
|
||||
int port;
|
||||
if (request instanceof ServletServerHttpRequest) {
|
||||
// Build more efficiently if we can: we only need scheme, host, port for origin comparison
|
||||
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
|
||||
urlBuilder = new UriComponentsBuilder().
|
||||
scheme(servletRequest.getScheme()).
|
||||
host(servletRequest.getServerName()).
|
||||
port(servletRequest.getServerPort()).
|
||||
adaptFromForwardedHeaders(request.getHeaders());
|
||||
scheme = servletRequest.getScheme();
|
||||
host = servletRequest.getServerName();
|
||||
port = servletRequest.getServerPort();
|
||||
|
||||
if(containsForwardedHeaders(servletRequest)) {
|
||||
UriComponents actualUrl = new UriComponentsBuilder()
|
||||
.scheme(scheme)
|
||||
.host(host)
|
||||
.port(port)
|
||||
.adaptFromForwardedHeaders(headers)
|
||||
.build();
|
||||
scheme = actualUrl.getScheme();
|
||||
host = actualUrl.getHost();
|
||||
port = actualUrl.getPort();
|
||||
}
|
||||
}
|
||||
else {
|
||||
urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
|
||||
}
|
||||
UriComponents actualUrl = urlBuilder.build();
|
||||
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
|
||||
return (ObjectUtils.nullSafeEquals(actualUrl.getHost(), originUrl.getHost()) &&
|
||||
getPort(actualUrl) == getPort(originUrl));
|
||||
UriComponents actualUrl = UriComponentsBuilder.fromHttpRequest(request).build();
|
||||
scheme = actualUrl.getScheme();
|
||||
host = actualUrl.getHost();
|
||||
port = actualUrl.getPort();
|
||||
}
|
||||
|
||||
private static int getPort(UriComponents uri) {
|
||||
int port = uri.getPort();
|
||||
UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
|
||||
return (ObjectUtils.nullSafeEquals(host, originUrl.getHost()) &&
|
||||
getPort(scheme, port) == getPort(originUrl.getScheme(), originUrl.getPort()));
|
||||
}
|
||||
|
||||
private static boolean containsForwardedHeaders(HttpServletRequest request) {
|
||||
for (String headerName : FORWARDED_HEADER_NAMES) {
|
||||
if (request.getHeader(headerName) != null) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static int getPort(String scheme, int port) {
|
||||
if (port == -1) {
|
||||
if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) {
|
||||
if ("http".equals(scheme) || "ws".equals(scheme)) {
|
||||
port = 80;
|
||||
}
|
||||
else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) {
|
||||
else if ("https".equals(scheme) || "wss".equals(scheme)) {
|
||||
port = 443;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ public class WebUtilsTests {
|
|||
if (port != -1) {
|
||||
servletRequest.setServerPort(port);
|
||||
}
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
|
||||
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
|
||||
return WebUtils.isValidOrigin(request, allowed);
|
||||
}
|
||||
|
||||
|
|
@ -179,7 +179,7 @@ public class WebUtilsTests {
|
|||
if (port != -1) {
|
||||
servletRequest.setServerPort(port);
|
||||
}
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
|
||||
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
|
||||
return WebUtils.isSameOrigin(request);
|
||||
}
|
||||
|
||||
|
|
@ -191,15 +191,15 @@ public class WebUtilsTests {
|
|||
servletRequest.setServerPort(port);
|
||||
}
|
||||
if (forwardedProto != null) {
|
||||
request.getHeaders().set("X-Forwarded-Proto", forwardedProto);
|
||||
servletRequest.addHeader("X-Forwarded-Proto", forwardedProto);
|
||||
}
|
||||
if (forwardedHost != null) {
|
||||
request.getHeaders().set("X-Forwarded-Host", forwardedHost);
|
||||
servletRequest.addHeader("X-Forwarded-Host", forwardedHost);
|
||||
}
|
||||
if (forwardedPort != -1) {
|
||||
request.getHeaders().set("X-Forwarded-Port", String.valueOf(forwardedPort));
|
||||
servletRequest.addHeader("X-Forwarded-Port", String.valueOf(forwardedPort));
|
||||
}
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
|
||||
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
|
||||
return WebUtils.isSameOrigin(request);
|
||||
}
|
||||
|
||||
|
|
@ -210,8 +210,8 @@ public class WebUtilsTests {
|
|||
if (port != -1) {
|
||||
servletRequest.setServerPort(port);
|
||||
}
|
||||
request.getHeaders().set("Forwarded", forwardedHeader);
|
||||
request.getHeaders().set(HttpHeaders.ORIGIN, originHeader);
|
||||
servletRequest.addHeader("Forwarded", forwardedHeader);
|
||||
servletRequest.addHeader(HttpHeaders.ORIGIN, originHeader);
|
||||
return WebUtils.isSameOrigin(request);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue