diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java index d73e44b2df..51fb5b7239 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java @@ -20,6 +20,7 @@ import java.net.InetSocketAddress; import java.net.URI; import java.net.URISyntaxException; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.cookie.Cookie; import reactor.core.publisher.Flux; import reactor.ipc.netty.http.server.HttpServerRequest; @@ -55,16 +56,39 @@ public class ReactorServerHttpRequest extends AbstractServerHttpRequest { this.bufferFactory = bufferFactory; } - private static URI initUri(HttpServerRequest channel) throws URISyntaxException { - Assert.notNull(channel, "'channel' must not be null"); - InetSocketAddress address = channel.remoteAddress(); - String requestUri = channel.uri(); - return (address != null ? createUrl(address, requestUri) : new URI(requestUri)); + private static URI initUri(HttpServerRequest request) throws URISyntaxException { + Assert.notNull(request, "'request' must not be null"); + URI baseUri = resolveBaseUrl(request); + String requestUri = request.uri(); + return (baseUri != null ? new URI(baseUri.toString() + requestUri) : new URI(requestUri)); } - private static URI createUrl(InetSocketAddress address, String requestUri) throws URISyntaxException { - URI baseUrl = new URI(null, null, address.getHostString(), address.getPort(), null, null, null); - return new URI(baseUrl.toString() + requestUri); + private static URI resolveBaseUrl(HttpServerRequest request) throws URISyntaxException { + String header = request.requestHeaders().get(HttpHeaderNames.HOST); + if (header != null) { + final int portIndex; + if (header.startsWith("[")) { + portIndex = header.indexOf(':', header.indexOf(']')); + } else { + portIndex = header.indexOf(':'); + } + if (portIndex != -1) { + try { + return new URI(null, null, header.substring(0, portIndex), + Integer.parseInt(header.substring(portIndex + 1)), null, null, null); + } catch (NumberFormatException ignore) { + throw new URISyntaxException(header, "Unable to parse port", portIndex); + } + } + else { + return new URI(null, header, null, null); + } + } + else { + InetSocketAddress localAddress = (InetSocketAddress) request.context().channel().localAddress(); + return new URI(null, null, localAddress.getHostString(), + localAddress.getPort(), null, null, null); + } } private static HttpHeaders initHeaders(HttpServerRequest channel) {