Optimize header removal in ForwardedHeaderFilter
The current implementation suggests that the request's headers are not expected to change. Hence, it's not necessary to copy them. Furthermore, it might be costly to do so if there are many headers. Instead, cache only the request's header names for method getHeaderNames. Methods getHeader and getHeaders delegate to the respective methods of request if the header name is not in FORWARDED_HEADER_NAMES. Otherwise, they return null or an empty Enumeration respectively. See gh-27466
This commit is contained in:
parent
fab9abd7fe
commit
6605953eb5
|
@ -20,9 +20,7 @@ import java.io.IOException;
|
|||
import java.net.InetSocketAddress;
|
||||
import java.util.Collections;
|
||||
import java.util.Enumeration;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
|
@ -37,7 +35,6 @@ import org.springframework.http.HttpStatus;
|
|||
import org.springframework.http.server.ServerHttpRequest;
|
||||
import org.springframework.http.server.ServletServerHttpRequest;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.LinkedCaseInsensitiveMap;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.util.UriComponents;
|
||||
|
@ -169,23 +166,26 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
*/
|
||||
private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper {
|
||||
|
||||
private final Map<String, List<String>> headers;
|
||||
private final Set<String> headerNames;
|
||||
|
||||
public ForwardedHeaderRemovingRequest(HttpServletRequest request) {
|
||||
super(request);
|
||||
this.headers = initHeaders(request);
|
||||
|
||||
this.headerNames = headerNames(request);
|
||||
}
|
||||
|
||||
private static Map<String, List<String>> initHeaders(HttpServletRequest request) {
|
||||
Map<String, List<String>> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH);
|
||||
Enumeration<String> names = request.getHeaderNames();
|
||||
private static Set<String> headerNames(HttpServletRequest request) {
|
||||
final var headerNames = Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(Locale.ENGLISH));
|
||||
final var names = request.getHeaderNames();
|
||||
|
||||
while (names.hasMoreElements()) {
|
||||
String name = names.nextElement();
|
||||
if (!FORWARDED_HEADER_NAMES.contains(name)) {
|
||||
headers.put(name, Collections.list(request.getHeaders(name)));
|
||||
}
|
||||
final var name = names.nextElement();
|
||||
headerNames.add(name);
|
||||
}
|
||||
return headers;
|
||||
|
||||
headerNames.removeAll(FORWARDED_HEADER_NAMES);
|
||||
|
||||
return Collections.unmodifiableSet(headerNames);
|
||||
}
|
||||
|
||||
// Override header accessors to not expose forwarded headers
|
||||
|
@ -193,19 +193,25 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
@Override
|
||||
@Nullable
|
||||
public String getHeader(String name) {
|
||||
List<String> value = this.headers.get(name);
|
||||
return (CollectionUtils.isEmpty(value) ? null : value.get(0));
|
||||
if (FORWARDED_HEADER_NAMES.contains(name)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return super.getHeader(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaders(String name) {
|
||||
List<String> value = this.headers.get(name);
|
||||
return (Collections.enumeration(value != null ? value : Collections.emptySet()));
|
||||
if (FORWARDED_HEADER_NAMES.contains(name)) {
|
||||
return Collections.emptyEnumeration();
|
||||
}
|
||||
|
||||
return super.getHeaders(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaderNames() {
|
||||
return Collections.enumeration(this.headers.keySet());
|
||||
return Collections.enumeration(this.headerNames);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue