From 6605953eb5919775b16f666bf1865a0cb43caaea Mon Sep 17 00:00:00 2001 From: Daniel Le Date: Sat, 25 Sep 2021 20:18:54 +0800 Subject: [PATCH] Optimize header removal in ForwardedHeaderFilter The current implementation suggests that the request's headers are not expected to change. Hence, it's not necessary to copy them. Furthermore, it might be costly to do so if there are many headers. Instead, cache only the request's header names for method getHeaderNames. Methods getHeader and getHeaders delegate to the respective methods of request if the header name is not in FORWARDED_HEADER_NAMES. Otherwise, they return null or an empty Enumeration respectively. See gh-27466 --- .../web/filter/ForwardedHeaderFilter.java | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index 27db5087a2..9d6bb2eed6 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -20,9 +20,7 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.util.Collections; import java.util.Enumeration; -import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Set; import java.util.function.Supplier; @@ -37,7 +35,6 @@ import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.lang.Nullable; -import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedCaseInsensitiveMap; import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponents; @@ -169,23 +166,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { */ private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper { - private final Map> headers; + private final Set headerNames; public ForwardedHeaderRemovingRequest(HttpServletRequest request) { super(request); - this.headers = initHeaders(request); + + this.headerNames = headerNames(request); } - private static Map> initHeaders(HttpServletRequest request) { - Map> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH); - Enumeration names = request.getHeaderNames(); + private static Set headerNames(HttpServletRequest request) { + final var headerNames = Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(Locale.ENGLISH)); + final var names = request.getHeaderNames(); + while (names.hasMoreElements()) { - String name = names.nextElement(); - if (!FORWARDED_HEADER_NAMES.contains(name)) { - headers.put(name, Collections.list(request.getHeaders(name))); - } + final var name = names.nextElement(); + headerNames.add(name); } - return headers; + + headerNames.removeAll(FORWARDED_HEADER_NAMES); + + return Collections.unmodifiableSet(headerNames); } // Override header accessors to not expose forwarded headers @@ -193,19 +193,25 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { @Override @Nullable public String getHeader(String name) { - List value = this.headers.get(name); - return (CollectionUtils.isEmpty(value) ? null : value.get(0)); + if (FORWARDED_HEADER_NAMES.contains(name)) { + return null; + } + + return super.getHeader(name); } @Override public Enumeration getHeaders(String name) { - List value = this.headers.get(name); - return (Collections.enumeration(value != null ? value : Collections.emptySet())); + if (FORWARDED_HEADER_NAMES.contains(name)) { + return Collections.emptyEnumeration(); + } + + return super.getHeaders(name); } @Override public Enumeration getHeaderNames() { - return Collections.enumeration(this.headers.keySet()); + return Collections.enumeration(this.headerNames); } }