Support contextPath override in ForwardedHeaderFilter
Issue: SPR-13614
This commit is contained in:
		
							parent
							
								
									6fcc869338
								
							
						
					
					
						commit
						36e2dd90a7
					
				|  | @ -31,9 +31,11 @@ import javax.servlet.http.HttpServletResponse; | |||
| 
 | ||||
| import org.springframework.http.HttpRequest; | ||||
| import org.springframework.http.server.ServletServerHttpRequest; | ||||
| import org.springframework.util.Assert; | ||||
| import org.springframework.util.CollectionUtils; | ||||
| import org.springframework.web.util.UriComponents; | ||||
| import org.springframework.web.util.UriComponentsBuilder; | ||||
| import org.springframework.web.util.UrlPathHelper; | ||||
| 
 | ||||
| 
 | ||||
| /** | ||||
|  | @ -61,6 +63,28 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { | |||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	private ContextPathHelper contextPathHelper; | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 	/** | ||||
| 	 * Configure a contextPath value that will replace the contextPath of | ||||
| 	 * proxy-forwarded requests. | ||||
| 	 * | ||||
| 	 * <p>This is useful when external clients are not aware of the application | ||||
| 	 * context path. However a proxy forwards the request to a URL that includes | ||||
| 	 * a contextPath. | ||||
| 	 * | ||||
| 	 * @param contextPath the context path; the given value will be sanitized to | ||||
| 	 * ensure it starts with a '/' but does not end with one, or if the context | ||||
| 	 * path is empty (default, root context) it is left as-is. | ||||
| 	 */ | ||||
| 	public void setContextPath(String contextPath) { | ||||
| 		Assert.notNull(contextPath, "'contextPath' must not be null"); | ||||
| 		this.contextPathHelper = new ContextPathHelper(contextPath); | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	@Override | ||||
| 	protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException { | ||||
| 		Enumeration<String> headerNames = request.getHeaderNames(); | ||||
|  | @ -87,7 +111,7 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { | |||
| 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, | ||||
| 			FilterChain filterChain) throws ServletException, IOException { | ||||
| 
 | ||||
| 		filterChain.doFilter(new ForwardedHeaderRequestWrapper(request), response); | ||||
| 		filterChain.doFilter(new ForwardedHeaderRequestWrapper(request, this.contextPathHelper), response); | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
|  | @ -105,12 +129,16 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { | |||
| 
 | ||||
| 		private final int port; | ||||
| 
 | ||||
| 		private final String contextPath; | ||||
| 
 | ||||
| 		private final String requestUri; | ||||
| 
 | ||||
| 		private final StringBuffer requestUrl; | ||||
| 
 | ||||
| 		private final Map<String, List<String>> headers; | ||||
| 
 | ||||
| 
 | ||||
| 		public ForwardedHeaderRequestWrapper(HttpServletRequest request) { | ||||
| 		public ForwardedHeaderRequestWrapper(HttpServletRequest request, ContextPathHelper pathHelper) { | ||||
| 			super(request); | ||||
| 
 | ||||
| 			HttpRequest httpRequest = new ServletServerHttpRequest(request); | ||||
|  | @ -121,7 +149,11 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { | |||
| 			this.secure = "https".equals(scheme); | ||||
| 			this.host = uriComponents.getHost(); | ||||
| 			this.port = (port == -1 ? (this.secure ? 443 : 80) : port); | ||||
| 			this.requestUrl = initRequestUrl(this.scheme, this.host, port, request.getRequestURI()); | ||||
| 
 | ||||
| 			this.contextPath = (pathHelper != null ? pathHelper.getContextPath(request) : request.getContextPath()); | ||||
| 			this.requestUri = (pathHelper != null ? pathHelper.getRequestUri(request) : request.getRequestURI()); | ||||
| 			this.requestUrl = initRequestUrl(this.scheme, this.host, port, this.requestUri); | ||||
| 
 | ||||
| 			this.headers = initHeaders(request); | ||||
| 		} | ||||
| 
 | ||||
|  | @ -170,6 +202,16 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { | |||
| 			return this.secure; | ||||
| 		} | ||||
| 
 | ||||
| 		@Override | ||||
| 		public String getContextPath() { | ||||
| 			return this.contextPath; | ||||
| 		} | ||||
| 
 | ||||
| 		@Override | ||||
| 		public String getRequestURI() { | ||||
| 			return this.requestUri; | ||||
| 		} | ||||
| 
 | ||||
| 		@Override | ||||
| 		public StringBuffer getRequestURL() { | ||||
| 			return this.requestUrl; | ||||
|  | @ -195,4 +237,50 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	private static class ContextPathHelper { | ||||
| 
 | ||||
| 		private final String contextPath; | ||||
| 
 | ||||
| 		private final UrlPathHelper urlPathHelper; | ||||
| 
 | ||||
| 
 | ||||
| 		public ContextPathHelper(String contextPath) { | ||||
| 			Assert.notNull(contextPath); | ||||
| 			this.contextPath = sanitizeContextPath(contextPath); | ||||
| 			this.urlPathHelper = new UrlPathHelper(); | ||||
| 			this.urlPathHelper.setUrlDecode(false); | ||||
| 			this.urlPathHelper.setRemoveSemicolonContent(false); | ||||
| 		} | ||||
| 
 | ||||
| 		private static String sanitizeContextPath(String contextPath) { | ||||
| 			contextPath = contextPath.trim(); | ||||
| 			if (contextPath.isEmpty()) { | ||||
| 				return contextPath; | ||||
| 			} | ||||
| 			if (contextPath.equals("/")) { | ||||
| 				return "/"; | ||||
| 			} | ||||
| 			if (contextPath.charAt(0) != '/') { | ||||
| 				contextPath = "/"  + contextPath; | ||||
| 			} | ||||
| 			while (contextPath.endsWith("/")) { | ||||
| 				contextPath = contextPath.substring(0, contextPath.length() -1); | ||||
| 			} | ||||
| 			return contextPath; | ||||
| 		} | ||||
| 
 | ||||
| 		public String getContextPath(HttpServletRequest request) { | ||||
| 			return this.contextPath; | ||||
| 		} | ||||
| 
 | ||||
| 		public String getRequestUri(HttpServletRequest request) { | ||||
| 			String pathWithinApplication = this.urlPathHelper.getPathWithinApplication(request); | ||||
| 			if (this.contextPath.equals("/") && pathWithinApplication.startsWith("/")) { | ||||
| 				return pathWithinApplication; | ||||
| 			} | ||||
| 			return this.contextPath + pathWithinApplication; | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
|  |  | |||
|  | @ -15,10 +15,12 @@ | |||
|  */ | ||||
| package org.springframework.web.filter; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| import javax.servlet.ServletException; | ||||
| import javax.servlet.http.HttpServlet; | ||||
| import javax.servlet.http.HttpServletRequest; | ||||
| 
 | ||||
| import org.junit.Before; | ||||
| import org.junit.Test; | ||||
| 
 | ||||
| import org.springframework.mock.web.test.MockFilterChain; | ||||
|  | @ -38,6 +40,98 @@ public class ForwardedHeaderFilterTests { | |||
| 
 | ||||
| 	private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter(); | ||||
| 
 | ||||
| 	private MockHttpServletRequest request; | ||||
| 
 | ||||
| 	private MockFilterChain filterChain; | ||||
| 
 | ||||
| 
 | ||||
| 	@Before | ||||
| 	public void setUp() throws Exception { | ||||
| 		this.request = new MockHttpServletRequest(); | ||||
| 		this.request.setScheme("http"); | ||||
| 		this.request.setServerName("localhost"); | ||||
| 		this.request.setServerPort(80); | ||||
| 		this.filterChain = new MockFilterChain(new HttpServlet() {}); | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	@Test(expected = IllegalArgumentException.class) | ||||
| 	public void contextPathNull() { | ||||
| 		this.filter.setContextPath(null); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void contextPathEmpty() throws Exception { | ||||
| 		this.filter.setContextPath(""); | ||||
| 		assertEquals("", filterAndGetContextPath()); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void contextPathWithExtraSpaces() throws Exception { | ||||
| 		this.filter.setContextPath("  /foo  "); | ||||
| 		assertEquals("/foo", filterAndGetContextPath()); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void contextPathWithNoLeadingSlash() throws Exception { | ||||
| 		this.filter.setContextPath("foo"); | ||||
| 		assertEquals("/foo", filterAndGetContextPath()); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void contextPathWithTrailingSlash() throws Exception { | ||||
| 		this.filter.setContextPath("/foo/bar/"); | ||||
| 		assertEquals("/foo/bar", filterAndGetContextPath()); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void contextPathWithTrailingSlashes() throws Exception { | ||||
| 		this.filter.setContextPath("/foo/bar/baz///"); | ||||
| 		assertEquals("/foo/bar/baz", filterAndGetContextPath()); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void requestUri() throws Exception { | ||||
| 		this.filter.setContextPath("/"); | ||||
| 		this.request.setContextPath("/app"); | ||||
| 		this.request.setRequestURI("/app/path"); | ||||
| 		HttpServletRequest actual = filterAndGetWrappedRequest(); | ||||
| 
 | ||||
| 		assertEquals("/", actual.getContextPath()); | ||||
| 		assertEquals("/path", actual.getRequestURI()); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void requestUriWithTrailingSlash() throws Exception { | ||||
| 		this.filter.setContextPath("/"); | ||||
| 		this.request.setContextPath("/app"); | ||||
| 		this.request.setRequestURI("/app/path/"); | ||||
| 		HttpServletRequest actual = filterAndGetWrappedRequest(); | ||||
| 
 | ||||
| 		assertEquals("/", actual.getContextPath()); | ||||
| 		assertEquals("/path/", actual.getRequestURI()); | ||||
| 	} | ||||
| 	@Test | ||||
| 	public void requestUriEqualsContextPath() throws Exception { | ||||
| 		this.filter.setContextPath("/"); | ||||
| 		this.request.setContextPath("/app"); | ||||
| 		this.request.setRequestURI("/app"); | ||||
| 		HttpServletRequest actual = filterAndGetWrappedRequest(); | ||||
| 
 | ||||
| 		assertEquals("/", actual.getContextPath()); | ||||
| 		assertEquals("/", actual.getRequestURI()); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void requestUriRootUrl() throws Exception { | ||||
| 		this.filter.setContextPath("/"); | ||||
| 		this.request.setContextPath("/app"); | ||||
| 		this.request.setRequestURI("/app/"); | ||||
| 		HttpServletRequest actual = filterAndGetWrappedRequest(); | ||||
| 
 | ||||
| 		assertEquals("/", actual.getContextPath()); | ||||
| 		assertEquals("/", actual.getRequestURI()); | ||||
| 	} | ||||
| 
 | ||||
| 	@Test | ||||
| 	public void shouldFilter() throws Exception { | ||||
|  | @ -54,19 +148,14 @@ public class ForwardedHeaderFilterTests { | |||
| 
 | ||||
| 	@Test | ||||
| 	public void forwardedRequest() throws Exception { | ||||
| 		MockHttpServletRequest request = new MockHttpServletRequest(); | ||||
| 		request.setScheme("http"); | ||||
| 		request.setServerName("localhost"); | ||||
| 		request.setServerPort(80); | ||||
| 		request.setRequestURI("/mvc-showcase"); | ||||
| 		request.addHeader("X-Forwarded-Proto", "https"); | ||||
| 		request.addHeader("X-Forwarded-Host", "84.198.58.199"); | ||||
| 		request.addHeader("X-Forwarded-Port", "443"); | ||||
| 		request.addHeader("foo", "bar"); | ||||
| 		this.request.setRequestURI("/mvc-showcase"); | ||||
| 		this.request.addHeader("X-Forwarded-Proto", "https"); | ||||
| 		this.request.addHeader("X-Forwarded-Host", "84.198.58.199"); | ||||
| 		this.request.addHeader("X-Forwarded-Port", "443"); | ||||
| 		this.request.addHeader("foo", "bar"); | ||||
| 
 | ||||
| 		MockFilterChain chain = new MockFilterChain(new HttpServlet() {}); | ||||
| 		this.filter.doFilter(request, new MockHttpServletResponse(), chain); | ||||
| 		HttpServletRequest actual = (HttpServletRequest) chain.getRequest(); | ||||
| 		this.filter.doFilter(this.request, new MockHttpServletResponse(), this.filterChain); | ||||
| 		HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest(); | ||||
| 
 | ||||
| 		assertEquals("https://84.198.58.199/mvc-showcase", actual.getRequestURL().toString()); | ||||
| 		assertEquals("https", actual.getScheme()); | ||||
|  | @ -81,11 +170,20 @@ public class ForwardedHeaderFilterTests { | |||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| 	private String filterAndGetContextPath() throws ServletException, IOException { | ||||
| 		return filterAndGetWrappedRequest().getContextPath(); | ||||
| 	} | ||||
| 
 | ||||
| 	private HttpServletRequest filterAndGetWrappedRequest() throws ServletException, IOException { | ||||
| 		MockHttpServletResponse response = new MockHttpServletResponse(); | ||||
| 		this.filter.doFilterInternal(this.request, response, this.filterChain); | ||||
| 		return (HttpServletRequest) this.filterChain.getRequest(); | ||||
| 	} | ||||
| 
 | ||||
| 	private void testShouldFilter(String headerName) throws ServletException { | ||||
| 		MockHttpServletRequest request = new MockHttpServletRequest(); | ||||
| 		request.addHeader(headerName, "1"); | ||||
| 		assertFalse(this.filter.shouldNotFilter(request)); | ||||
| 	} | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue