Use X-Forwarded-Prefix in ServletUriComponentsBuilder

Issue: SPR-12500
This commit is contained in:
Rossen Stoyanchev 2014-12-03 12:07:20 -05:00
parent af2782aa79
commit d322bcfbf4
2 changed files with 39 additions and 16 deletions

View File

@ -35,7 +35,7 @@ import org.springframework.web.util.WebUtils;
*/ */
public class ServletUriComponentsBuilder extends UriComponentsBuilder { public class ServletUriComponentsBuilder extends UriComponentsBuilder {
private String servletRequestURI; private String originalPath;
/** /**
@ -86,7 +86,6 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
*/ */
public static ServletUriComponentsBuilder fromRequestUri(HttpServletRequest request) { public static ServletUriComponentsBuilder fromRequestUri(HttpServletRequest request) {
ServletUriComponentsBuilder builder = fromRequest(request); ServletUriComponentsBuilder builder = fromRequest(request);
builder.pathFromRequest(request);
builder.replaceQuery(null); builder.replaceQuery(null);
return builder; return builder;
} }
@ -99,6 +98,7 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
String scheme = request.getScheme(); String scheme = request.getScheme();
String host = request.getServerName(); String host = request.getServerName();
int port = request.getServerPort(); int port = request.getServerPort();
String path = request.getRequestURI();
String hostHeader = request.getHeader("X-Forwarded-Host"); String hostHeader = request.getHeader("X-Forwarded-Host");
if (StringUtils.hasText(hostHeader)) { if (StringUtils.hasText(hostHeader)) {
@ -125,13 +125,18 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
scheme = protocolHeader; scheme = protocolHeader;
} }
String prefix = request.getHeader("X-Forwarded-Prefix");
if (StringUtils.hasText(prefix)) {
path = prefix + path;
}
ServletUriComponentsBuilder builder = new ServletUriComponentsBuilder(); ServletUriComponentsBuilder builder = new ServletUriComponentsBuilder();
builder.scheme(scheme); builder.scheme(scheme);
builder.host(host); builder.host(host);
if (scheme.equals("http") && port != 80 || scheme.equals("https") && port != 443) { if (scheme.equals("http") && port != 80 || scheme.equals("https") && port != 443) {
builder.port(port); builder.port(port);
} }
builder.pathFromRequest(request); builder.initPath(path);
builder.query(request.getQueryString()); builder.query(request.getQueryString());
return builder; return builder;
} }
@ -180,9 +185,9 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
return servletRequest; return servletRequest;
} }
private void pathFromRequest(HttpServletRequest request) { private void initPath(String path) {
this.servletRequestURI = request.getRequestURI(); this.originalPath = path;
replacePath(request.getRequestURI()); replacePath(path);
} }
/** /**
@ -190,27 +195,27 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder {
* requestURI}. This method must be invoked before any calls to {@link #path(String)} * requestURI}. This method must be invoked before any calls to {@link #path(String)}
* or {@link #pathSegment(String...)}. * or {@link #pathSegment(String...)}.
* <pre> * <pre>
* // GET http://foo.com/rest/books/6.json
* *
* ServletUriComponentsBuilder builder = ServletUriComponentsBuilder.fromRequestUri(this.request); * GET http://foo.com/rest/books/6.json
* String ext = builder.removePathExtension();
* String uri = builder.path("/pages/1.{ext}").buildAndExpand(ext).toUriString();
* *
* assertEquals("http://foo.com/rest/books/6/pages/1.json", result); * ServletUriComponentsBuilder builder = ServletUriComponentsBuilder.fromRequestUri(this.request);
* String ext = builder.removePathExtension();
* String uri = builder.path("/pages/1.{ext}").buildAndExpand(ext).toUriString();
* assertEquals("http://foo.com/rest/books/6/pages/1.json", result);
* </pre> * </pre>
* @return the removed path extension for possible re-use, or {@code null} * @return the removed path extension for possible re-use, or {@code null}
* @since 4.0 * @since 4.0
*/ */
public String removePathExtension() { public String removePathExtension() {
String extension = null; String extension = null;
if (this.servletRequestURI != null) { if (this.originalPath != null) {
String filename = WebUtils.extractFullFilenameFromUrlPath(this.servletRequestURI); String filename = WebUtils.extractFullFilenameFromUrlPath(this.originalPath);
extension = StringUtils.getFilenameExtension(filename); extension = StringUtils.getFilenameExtension(filename);
if (!StringUtils.isEmpty(extension)) { if (!StringUtils.isEmpty(extension)) {
int end = this.servletRequestURI.length() - (extension.length() + 1); int end = this.originalPath.length() - (extension.length() + 1);
replacePath(this.servletRequestURI.substring(0, end)); replacePath(this.originalPath.substring(0, end));
} }
this.servletRequestURI = null; this.originalPath = null;
} }
return extension; return extension;
} }

View File

@ -150,6 +150,24 @@ public class ServletUriComponentsBuilderTests {
assertEquals("should have used the default port of the forwarded request", -1, result.getPort()); assertEquals("should have used the default port of the forwarded request", -1, result.getPort());
} }
@Test
public void fromRequestWithForwardedPrefix() {
this.request.setRequestURI("/bar");
this.request.addHeader("X-Forwarded-Prefix", "/foo");
UriComponents result = ServletUriComponentsBuilder.fromRequest(request).build();
assertEquals("http://localhost/foo/bar", result.toUriString());
}
@Test
public void fromRequestWithForwardedPrefixTrailingSlash() {
this.request.setRequestURI("/bar");
this.request.addHeader("X-Forwarded-Prefix", "/foo/");
UriComponents result = ServletUriComponentsBuilder.fromRequest(request).build();
assertEquals("http://localhost/foo/bar", result.toUriString());
}
@Test @Test
public void fromContextPath() { public void fromContextPath() {
request.setRequestURI("/mvc-showcase/data/param"); request.setRequestURI("/mvc-showcase/data/param");