Add removeOnly mode to ForwardedHeaderFilter
Issue: SPR-15610
This commit is contained in:
parent
cf1bc81199
commit
895fa2ea7b
|
|
@ -40,20 +40,25 @@ import org.springframework.web.util.UriComponentsBuilder;
|
|||
import org.springframework.web.util.UrlPathHelper;
|
||||
|
||||
/**
|
||||
* Filter that wraps the request and response in order to override its
|
||||
* Extract values from "Forwarded" and "X-Forwarded-*" headers in order to wrap
|
||||
* and override the following from the request and response:
|
||||
* {@link HttpServletRequest#getServerName() getServerName()},
|
||||
* {@link HttpServletRequest#getServerPort() getServerPort()},
|
||||
* {@link HttpServletRequest#getScheme() getScheme()},
|
||||
* {@link HttpServletRequest#isSecure() isSecure()},
|
||||
* {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)},
|
||||
* methods with values derived from "Forwarded" or "X-Forwarded-*"
|
||||
* headers. In effect the wrapped request and response reflects the
|
||||
* client-originated protocol and address.
|
||||
* {@link HttpServletRequest#isSecure() isSecure()}, and
|
||||
* {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}.
|
||||
* In effect the wrapped request and response reflect the client-originated
|
||||
* protocol and address.
|
||||
*
|
||||
* <p><strong>Note:</strong> This filter can also be used in a
|
||||
* {@link #setRemoveOnly removeOnly} mode where "Forwarded" and "X-Forwarded-*"
|
||||
* headers are only eliminated without being used.
|
||||
*
|
||||
* @author Rossen Stoyanchev
|
||||
* @author Eddú Meléndez
|
||||
* @author Rob Winch
|
||||
* @since 4.3
|
||||
* @see <a href="https://tools.ietf.org/html/rfc7239">https://tools.ietf.org/html/rfc7239</a>
|
||||
*/
|
||||
public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
||||
|
||||
|
|
@ -71,6 +76,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
|
||||
private final UrlPathHelper pathHelper;
|
||||
|
||||
private boolean removeOnly;
|
||||
|
||||
|
||||
public ForwardedHeaderFilter() {
|
||||
this.pathHelper = new UrlPathHelper();
|
||||
|
|
@ -79,6 +86,17 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* Enables mode in which any "Forwarded" or "X-Forwarded-*" headers are
|
||||
* removed only and the information in them ignored.
|
||||
* @param removeOnly whether to discard and ingore forwarded headers
|
||||
* @since 4.3.9
|
||||
*/
|
||||
public void setRemoveOnly(boolean removeOnly) {
|
||||
this.removeOnly = removeOnly;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
|
||||
Enumeration<String> names = request.getHeaderNames();
|
||||
|
|
@ -105,13 +123,67 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
|
||||
FilterChain filterChain) throws ServletException, IOException {
|
||||
|
||||
ForwardedHeaderRequestWrapper wrappedRequest = new ForwardedHeaderRequestWrapper(request, this.pathHelper);
|
||||
ForwardedHeaderResponseWrapper wrappedResponse = new ForwardedHeaderResponseWrapper(response, wrappedRequest);
|
||||
filterChain.doFilter(wrappedRequest, wrappedResponse);
|
||||
if (this.removeOnly) {
|
||||
ForwardedHeaderRemovingRequest theRequest = new ForwardedHeaderRemovingRequest(request);
|
||||
filterChain.doFilter(theRequest, response);
|
||||
}
|
||||
else {
|
||||
HttpServletRequest theRequest = new ForwardedHeaderExtractingRequest(request, this.pathHelper);
|
||||
HttpServletResponse theResponse = new ForwardedHeaderExtractingResponse(response, theRequest);
|
||||
filterChain.doFilter(theRequest, theResponse);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static class ForwardedHeaderRequestWrapper extends HttpServletRequestWrapper {
|
||||
/**
|
||||
* Hide "Forwarded" or "X-Forwarded-*" headers.
|
||||
*/
|
||||
private static class ForwardedHeaderRemovingRequest extends HttpServletRequestWrapper {
|
||||
|
||||
private final Map<String, List<String>> headers;
|
||||
|
||||
|
||||
public ForwardedHeaderRemovingRequest(HttpServletRequest request) {
|
||||
super(request);
|
||||
this.headers = initHeaders(request);
|
||||
}
|
||||
|
||||
private static Map<String, List<String>> initHeaders(HttpServletRequest request) {
|
||||
Map<String, List<String>> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH);
|
||||
Enumeration<String> names = request.getHeaderNames();
|
||||
while (names.hasMoreElements()) {
|
||||
String name = names.nextElement();
|
||||
if (!FORWARDED_HEADER_NAMES.contains(name)) {
|
||||
headers.put(name, Collections.list(request.getHeaders(name)));
|
||||
}
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
// Override header accessors to not expose forwarded headers
|
||||
|
||||
@Override
|
||||
public String getHeader(String name) {
|
||||
List<String> value = this.headers.get(name);
|
||||
return (CollectionUtils.isEmpty(value) ? null : value.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaders(String name) {
|
||||
List<String> value = this.headers.get(name);
|
||||
return (Collections.enumeration(value != null ? value : Collections.emptySet()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaderNames() {
|
||||
return Collections.enumeration(this.headers.keySet());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract and use "Forwarded" or "X-Forwarded-*" headers.
|
||||
*/
|
||||
private static class ForwardedHeaderExtractingRequest extends ForwardedHeaderRemovingRequest {
|
||||
|
||||
private final String scheme;
|
||||
|
||||
|
|
@ -127,9 +199,8 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
|
||||
private final String requestUrl;
|
||||
|
||||
private final Map<String, List<String>> headers;
|
||||
|
||||
public ForwardedHeaderRequestWrapper(HttpServletRequest request, UrlPathHelper pathHelper) {
|
||||
public ForwardedHeaderExtractingRequest(HttpServletRequest request, UrlPathHelper pathHelper) {
|
||||
super(request);
|
||||
|
||||
HttpRequest httpRequest = new ServletServerHttpRequest(request);
|
||||
|
|
@ -145,7 +216,6 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
this.contextPath = (prefix != null ? prefix : request.getContextPath());
|
||||
this.requestUri = this.contextPath + pathHelper.getPathWithinApplication(request);
|
||||
this.requestUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port) + this.requestUri;
|
||||
this.headers = initHeaders(request);
|
||||
}
|
||||
|
||||
private static String getForwardedPrefix(HttpServletRequest request) {
|
||||
|
|
@ -165,21 +235,6 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
return prefix;
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy the headers excluding any {@link #FORWARDED_HEADER_NAMES}.
|
||||
*/
|
||||
private static Map<String, List<String>> initHeaders(HttpServletRequest request) {
|
||||
Map<String, List<String>> headers = new LinkedCaseInsensitiveMap<>(Locale.ENGLISH);
|
||||
Enumeration<String> names = request.getHeaderNames();
|
||||
while (names.hasMoreElements()) {
|
||||
String name = names.nextElement();
|
||||
if (!FORWARDED_HEADER_NAMES.contains(name)) {
|
||||
headers.put(name, Collections.list(request.getHeaders(name)));
|
||||
}
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getScheme() {
|
||||
return this.scheme;
|
||||
|
|
@ -214,35 +269,18 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
|
|||
public StringBuffer getRequestURL() {
|
||||
return new StringBuffer(this.requestUrl);
|
||||
}
|
||||
|
||||
// Override header accessors to not expose forwarded headers
|
||||
|
||||
@Override
|
||||
public String getHeader(String name) {
|
||||
List<String> value = this.headers.get(name);
|
||||
return (CollectionUtils.isEmpty(value) ? null : value.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaders(String name) {
|
||||
List<String> value = this.headers.get(name);
|
||||
return (Collections.enumeration(value != null ? value : Collections.emptySet()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaderNames() {
|
||||
return Collections.enumeration(this.headers.keySet());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static class ForwardedHeaderResponseWrapper extends HttpServletResponseWrapper {
|
||||
private static class ForwardedHeaderExtractingResponse extends HttpServletResponseWrapper {
|
||||
|
||||
private static final String FOLDER_SEPARATOR = "/";
|
||||
|
||||
|
||||
private final HttpServletRequest request;
|
||||
|
||||
public ForwardedHeaderResponseWrapper(HttpServletResponse response, HttpServletRequest request) {
|
||||
|
||||
public ForwardedHeaderExtractingResponse(HttpServletResponse response, HttpServletRequest request) {
|
||||
super(response);
|
||||
this.request = request;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,10 @@ import org.springframework.mock.web.test.MockFilterChain;
|
|||
import org.springframework.mock.web.test.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.test.MockHttpServletResponse;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link ForwardedHeaderFilter}.
|
||||
|
|
@ -239,6 +242,30 @@ public class ForwardedHeaderFilterTests {
|
|||
assertEquals("bar", actual.getHeader("foo"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void forwardedRequestInRemoveOnlyMode() throws Exception {
|
||||
this.request.setRequestURI("/mvc-showcase");
|
||||
this.request.addHeader(X_FORWARDED_PROTO, "https");
|
||||
this.request.addHeader(X_FORWARDED_HOST, "84.198.58.199");
|
||||
this.request.addHeader(X_FORWARDED_PORT, "443");
|
||||
this.request.addHeader("foo", "bar");
|
||||
|
||||
this.filter.setRemoveOnly(true);
|
||||
this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
|
||||
HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest();
|
||||
|
||||
assertEquals("http://localhost/mvc-showcase", actual.getRequestURL().toString());
|
||||
assertEquals("http", actual.getScheme());
|
||||
assertEquals("localhost", actual.getServerName());
|
||||
assertEquals(80, actual.getServerPort());
|
||||
assertFalse(actual.isSecure());
|
||||
|
||||
assertNull(actual.getHeader(X_FORWARDED_PROTO));
|
||||
assertNull(actual.getHeader(X_FORWARDED_HOST));
|
||||
assertNull(actual.getHeader(X_FORWARDED_PORT));
|
||||
assertEquals("bar", actual.getHeader("foo"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void requestUriWithForwardedPrefix() throws Exception {
|
||||
this.request.addHeader(X_FORWARDED_PREFIX, "/prefix");
|
||||
|
|
|
|||
Loading…
Reference in New Issue