Support contextPath override in ForwardedHeaderFilter

Issue: SPR-13614
This commit is contained in:
Rossen Stoyanchev 2016-03-02 18:37:22 -05:00
parent 6fcc869338
commit 36e2dd90a7
2 changed files with 202 additions and 16 deletions

View File

@ -31,9 +31,11 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.http.HttpRequest;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UrlPathHelper;
/**
@ -61,6 +63,28 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
}
private ContextPathHelper contextPathHelper;
/**
* Configure a contextPath value that will replace the contextPath of
* proxy-forwarded requests.
*
* <p>This is useful when external clients are not aware of the application
* context path. However a proxy forwards the request to a URL that includes
* a contextPath.
*
* @param contextPath the context path; the given value will be sanitized to
* ensure it starts with a '/' but does not end with one, or if the context
* path is empty (default, root context) it is left as-is.
*/
public void setContextPath(String contextPath) {
Assert.notNull(contextPath, "'contextPath' must not be null");
this.contextPathHelper = new ContextPathHelper(contextPath);
}
@Override
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
Enumeration<String> headerNames = request.getHeaderNames();
@ -87,7 +111,7 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
filterChain.doFilter(new ForwardedHeaderRequestWrapper(request), response);
filterChain.doFilter(new ForwardedHeaderRequestWrapper(request, this.contextPathHelper), response);
}
@ -105,12 +129,16 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
private final int port;
private final String contextPath;
private final String requestUri;
private final StringBuffer requestUrl;
private final Map<String, List<String>> headers;
public ForwardedHeaderRequestWrapper(HttpServletRequest request) {
public ForwardedHeaderRequestWrapper(HttpServletRequest request, ContextPathHelper pathHelper) {
super(request);
HttpRequest httpRequest = new ServletServerHttpRequest(request);
@ -121,7 +149,11 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
this.secure = "https".equals(scheme);
this.host = uriComponents.getHost();
this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
this.requestUrl = initRequestUrl(this.scheme, this.host, port, request.getRequestURI());
this.contextPath = (pathHelper != null ? pathHelper.getContextPath(request) : request.getContextPath());
this.requestUri = (pathHelper != null ? pathHelper.getRequestUri(request) : request.getRequestURI());
this.requestUrl = initRequestUrl(this.scheme, this.host, port, this.requestUri);
this.headers = initHeaders(request);
}
@ -170,6 +202,16 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
return this.secure;
}
@Override
public String getContextPath() {
return this.contextPath;
}
@Override
public String getRequestURI() {
return this.requestUri;
}
@Override
public StringBuffer getRequestURL() {
return this.requestUrl;
@ -195,4 +237,50 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter {
}
}
private static class ContextPathHelper {
private final String contextPath;
private final UrlPathHelper urlPathHelper;
public ContextPathHelper(String contextPath) {
Assert.notNull(contextPath);
this.contextPath = sanitizeContextPath(contextPath);
this.urlPathHelper = new UrlPathHelper();
this.urlPathHelper.setUrlDecode(false);
this.urlPathHelper.setRemoveSemicolonContent(false);
}
private static String sanitizeContextPath(String contextPath) {
contextPath = contextPath.trim();
if (contextPath.isEmpty()) {
return contextPath;
}
if (contextPath.equals("/")) {
return "/";
}
if (contextPath.charAt(0) != '/') {
contextPath = "/" + contextPath;
}
while (contextPath.endsWith("/")) {
contextPath = contextPath.substring(0, contextPath.length() -1);
}
return contextPath;
}
public String getContextPath(HttpServletRequest request) {
return this.contextPath;
}
public String getRequestUri(HttpServletRequest request) {
String pathWithinApplication = this.urlPathHelper.getPathWithinApplication(request);
if (this.contextPath.equals("/") && pathWithinApplication.startsWith("/")) {
return pathWithinApplication;
}
return this.contextPath + pathWithinApplication;
}
}
}

View File

@ -15,10 +15,12 @@
*/
package org.springframework.web.filter;
import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.test.MockFilterChain;
@ -38,6 +40,98 @@ public class ForwardedHeaderFilterTests {
private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter();
private MockHttpServletRequest request;
private MockFilterChain filterChain;
@Before
public void setUp() throws Exception {
this.request = new MockHttpServletRequest();
this.request.setScheme("http");
this.request.setServerName("localhost");
this.request.setServerPort(80);
this.filterChain = new MockFilterChain(new HttpServlet() {});
}
@Test(expected = IllegalArgumentException.class)
public void contextPathNull() {
this.filter.setContextPath(null);
}
@Test
public void contextPathEmpty() throws Exception {
this.filter.setContextPath("");
assertEquals("", filterAndGetContextPath());
}
@Test
public void contextPathWithExtraSpaces() throws Exception {
this.filter.setContextPath(" /foo ");
assertEquals("/foo", filterAndGetContextPath());
}
@Test
public void contextPathWithNoLeadingSlash() throws Exception {
this.filter.setContextPath("foo");
assertEquals("/foo", filterAndGetContextPath());
}
@Test
public void contextPathWithTrailingSlash() throws Exception {
this.filter.setContextPath("/foo/bar/");
assertEquals("/foo/bar", filterAndGetContextPath());
}
@Test
public void contextPathWithTrailingSlashes() throws Exception {
this.filter.setContextPath("/foo/bar/baz///");
assertEquals("/foo/bar/baz", filterAndGetContextPath());
}
@Test
public void requestUri() throws Exception {
this.filter.setContextPath("/");
this.request.setContextPath("/app");
this.request.setRequestURI("/app/path");
HttpServletRequest actual = filterAndGetWrappedRequest();
assertEquals("/", actual.getContextPath());
assertEquals("/path", actual.getRequestURI());
}
@Test
public void requestUriWithTrailingSlash() throws Exception {
this.filter.setContextPath("/");
this.request.setContextPath("/app");
this.request.setRequestURI("/app/path/");
HttpServletRequest actual = filterAndGetWrappedRequest();
assertEquals("/", actual.getContextPath());
assertEquals("/path/", actual.getRequestURI());
}
@Test
public void requestUriEqualsContextPath() throws Exception {
this.filter.setContextPath("/");
this.request.setContextPath("/app");
this.request.setRequestURI("/app");
HttpServletRequest actual = filterAndGetWrappedRequest();
assertEquals("/", actual.getContextPath());
assertEquals("/", actual.getRequestURI());
}
@Test
public void requestUriRootUrl() throws Exception {
this.filter.setContextPath("/");
this.request.setContextPath("/app");
this.request.setRequestURI("/app/");
HttpServletRequest actual = filterAndGetWrappedRequest();
assertEquals("/", actual.getContextPath());
assertEquals("/", actual.getRequestURI());
}
@Test
public void shouldFilter() throws Exception {
@ -54,19 +148,14 @@ public class ForwardedHeaderFilterTests {
@Test
public void forwardedRequest() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setScheme("http");
request.setServerName("localhost");
request.setServerPort(80);
request.setRequestURI("/mvc-showcase");
request.addHeader("X-Forwarded-Proto", "https");
request.addHeader("X-Forwarded-Host", "84.198.58.199");
request.addHeader("X-Forwarded-Port", "443");
request.addHeader("foo", "bar");
this.request.setRequestURI("/mvc-showcase");
this.request.addHeader("X-Forwarded-Proto", "https");
this.request.addHeader("X-Forwarded-Host", "84.198.58.199");
this.request.addHeader("X-Forwarded-Port", "443");
this.request.addHeader("foo", "bar");
MockFilterChain chain = new MockFilterChain(new HttpServlet() {});
this.filter.doFilter(request, new MockHttpServletResponse(), chain);
HttpServletRequest actual = (HttpServletRequest) chain.getRequest();
this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain);
HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest();
assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString());
assertEquals("https", actual.getScheme());
@ -81,11 +170,20 @@ public class ForwardedHeaderFilterTests {
}
private String filterAndGetContextPath() throws ServletException, IOException {
return filterAndGetWrappedRequest().getContextPath();
}
private HttpServletRequest filterAndGetWrappedRequest() throws ServletException, IOException {
MockHttpServletResponse response = new MockHttpServletResponse();
this.filter.doFilterInternal(this.request, response, this.filterChain);
return (HttpServletRequest) this.filterChain.getRequest();
}
private void testShouldFilter(String headerName) throws ServletException {
MockHttpServletRequest request = new MockHttpServletRequest();
request.addHeader(headerName, "1");
assertFalse(this.filter.shouldNotFilter(request));
}
}