diff --git a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java index 09eef5bf178..60eb76c3a35 100644 --- a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.springframework.http.HttpRequest; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -263,6 +264,55 @@ public class UriComponentsBuilder implements Cloneable { } } + /** + * Create a new {@code UriComponents} object from the URI associated with + * the given HttpRequest while also overlaying with values from the headers + * "X-Forwarded-Host", "X-Forwarded-Port", and "X-Forwarded-Proto" if present. + * + * @param request the source request + * @return the URI components of the UR + */ + public static UriComponentsBuilder fromHttpRequest(HttpRequest request) { + URI uri = request.getURI(); + UriComponentsBuilder builder = UriComponentsBuilder.fromUri(uri); + + String scheme = uri.getScheme(); + String host = uri.getHost(); + int port = uri.getPort(); + + String hostHeader = request.getHeaders().getFirst("X-Forwarded-Host"); + if (StringUtils.hasText(hostHeader)) { + String[] hosts = StringUtils.commaDelimitedListToStringArray(hostHeader); + String hostToUse = hosts[0]; + if (hostToUse.contains(":")) { + String[] hostAndPort = StringUtils.split(hostToUse, ":"); + host = hostAndPort[0]; + port = Integer.parseInt(hostAndPort[1]); + } + else { + host = hostToUse; + port = -1; + } + } + + String portHeader = request.getHeaders().getFirst("X-Forwarded-Port"); + if (StringUtils.hasText(portHeader)) { + port = Integer.parseInt(portHeader); + } + + String protocolHeader = request.getHeaders().getFirst("X-Forwarded-Proto"); + if (StringUtils.hasText(protocolHeader)) { + scheme = protocolHeader; + } + + builder.scheme(scheme); + builder.host(host); + if (scheme.equals("http") && port != 80 || scheme.equals("https") && port != 443) { + builder.port(port); + } + return builder; + } + // build methods diff --git a/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java b/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java index 85a65434359..6689f465dc4 100644 --- a/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,11 +31,16 @@ import java.util.HashMap; import java.util.Map; import org.junit.Test; + +import org.springframework.http.server.ServletServerHttpRequest; +import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; /** + * Unit tests for {@link org.springframework.web.util.UriComponentsBuilder}. + * * @author Arjen Poutsma * @author Phillip Webb * @author Oliver Gierke @@ -232,6 +237,25 @@ public class UriComponentsBuilderTests { assertEquals("bar@baz", result.getQueryParams().getFirst("foo")); } + // Also see X-Forwarded-* related tests in ServletUriComponentsBuilderTests + + @Test + public void fromHttpRequest() throws URISyntaxException { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setScheme("http"); + request.setServerName("localhost"); + request.setServerPort(-1); + request.setRequestURI("/path"); + request.setQueryString("a=1"); + + UriComponents result = UriComponentsBuilder.fromHttpRequest(new ServletServerHttpRequest(request)).build(); + assertEquals("http", result.getScheme()); + assertEquals("localhost", result.getHost()); + assertEquals(-1, result.getPort()); + assertEquals("/path", result.getPath()); + assertEquals("a=1", result.getQuery()); + } + @Test public void path() throws URISyntaxException { UriComponentsBuilder builder = UriComponentsBuilder.fromPath("/foo/bar"); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java index 1ddca18ec97..e49eff6e367 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/support/ServletUriComponentsBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,11 +18,14 @@ package org.springframework.web.servlet.support; import javax.servlet.http.HttpServletRequest; +import org.springframework.http.HttpRequest; +import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.context.request.RequestAttributes; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; +import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UrlPathHelper; import org.springframework.web.util.WebUtils; @@ -112,34 +115,11 @@ public class ServletUriComponentsBuilder extends UriComponentsBuilder { * Initialize a builder with a scheme, host,and port (but not path and query). */ private static ServletUriComponentsBuilder initFromRequest(HttpServletRequest request) { - String scheme = request.getScheme(); - String host = request.getServerName(); - int port = request.getServerPort(); - - String hostHeader = request.getHeader("X-Forwarded-Host"); - if (StringUtils.hasText(hostHeader)) { - String[] hosts = StringUtils.commaDelimitedListToStringArray(hostHeader); - String hostToUse = hosts[0]; - if (hostToUse.contains(":")) { - String[] hostAndPort = StringUtils.split(hostToUse, ":"); - host = hostAndPort[0]; - port = Integer.parseInt(hostAndPort[1]); - } - else { - host = hostToUse; - port = -1; - } - } - - String portHeader = request.getHeader("X-Forwarded-Port"); - if (StringUtils.hasText(portHeader)) { - port = Integer.parseInt(portHeader); - } - - String protocolHeader = request.getHeader("X-Forwarded-Proto"); - if (StringUtils.hasText(protocolHeader)) { - scheme = protocolHeader; - } + HttpRequest httpRequest = new ServletServerHttpRequest(request); + UriComponents uriComponents = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); + String scheme = uriComponents.getScheme(); + String host = uriComponents.getHost(); + int port = uriComponents.getPort(); ServletUriComponentsBuilder builder = new ServletUriComponentsBuilder(); builder.scheme(scheme); diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/support/ServletUriComponentsBuilderTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/support/ServletUriComponentsBuilderTests.java index a3879dc22f3..c9a87739472 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/support/ServletUriComponentsBuilderTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/support/ServletUriComponentsBuilderTests.java @@ -41,7 +41,7 @@ public class ServletUriComponentsBuilderTests { this.request = new MockHttpServletRequest(); this.request.setScheme("http"); this.request.setServerName("localhost"); - this.request.setServerPort(80); + this.request.setServerPort(-1); this.request.setContextPath("/mvc-showcase"); }