diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java b/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java index fa39fb1e410..75e247d418a 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/DefaultServerHttpRequestBuilder.java @@ -16,14 +16,24 @@ package org.springframework.http.server.reactive; +import java.net.InetSocketAddress; import java.net.URI; import java.net.URISyntaxException; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import reactor.core.publisher.Flux; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpCookie; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; -import org.springframework.http.server.RequestPath; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; /** * Package-private default implementation of {@link ServerHttpRequest.Builder}. @@ -34,36 +44,66 @@ import org.springframework.util.Assert; */ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder { - private final ServerHttpRequest delegate; + private URI uri; + + private HttpHeaders httpHeaders; + + private String httpMethodValue; + + private final MultiValueMap cookies; @Nullable - private HttpMethod httpMethod; + private final InetSocketAddress remoteAddress; @Nullable - private String path; + private String uriPath; @Nullable private String contextPath; - @Nullable - private HttpHeaders httpHeaders; + private Flux body; + public DefaultServerHttpRequestBuilder(ServerHttpRequest original) { + Assert.notNull(original, "ServerHttpRequest is required"); - public DefaultServerHttpRequestBuilder(ServerHttpRequest delegate) { - Assert.notNull(delegate, "ServerHttpRequest delegate is required"); - this.delegate = delegate; + this.uri = original.getURI(); + this.httpMethodValue = original.getMethodValue(); + this.remoteAddress = original.getRemoteAddress(); + this.body = original.getBody(); + + this.httpHeaders = new HttpHeaders(); + copyMultiValueMap(original.getHeaders(), this.httpHeaders); + + this.cookies = new LinkedMultiValueMap<>(original.getCookies().size()); + copyMultiValueMap(original.getCookies(), this.cookies); + } + + private static void copyMultiValueMap(MultiValueMap source, + MultiValueMap destination) { + + for (Map.Entry> entry : source.entrySet()) { + K key = entry.getKey(); + List values = new LinkedList<>(entry.getValue()); + destination.put(key, values); + } } @Override public ServerHttpRequest.Builder method(HttpMethod httpMethod) { - this.httpMethod = httpMethod; + this.httpMethodValue = httpMethod.name(); + return this; + } + + @Override + public ServerHttpRequest.Builder uri(URI uri) { + this.uri = uri; return this; } @Override public ServerHttpRequest.Builder path(String path) { - this.path = path; + this.uriPath = path; return this; } @@ -75,111 +115,79 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder { @Override public ServerHttpRequest.Builder header(String key, String value) { - if (this.httpHeaders == null) { - this.httpHeaders = new HttpHeaders(); - } this.httpHeaders.add(key, value); return this; } + @Override + public ServerHttpRequest.Builder headers(Consumer headersConsumer) { + Assert.notNull(headersConsumer, "'headersConsumer' must not be null"); + headersConsumer.accept(this.httpHeaders); + return this; + } + @Override public ServerHttpRequest build() { URI uriToUse = getUriToUse(); - RequestPath path = getRequestPathToUse(uriToUse); - HttpHeaders headers = getHeadersToUse(); - return new MutativeDecorator(this.delegate, this.httpMethod, uriToUse, path, headers); + return new DefaultServerHttpRequest(uriToUse, this.contextPath, this.httpHeaders, + this.httpMethodValue, this.cookies, this.remoteAddress, this.body); + } - @Nullable private URI getUriToUse() { - if (this.path == null) { - return null; + if (this.uriPath == null) { + return this.uri; } - URI uri = this.delegate.getURI(); try { - return new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort(), - this.path, uri.getQuery(), uri.getFragment()); + return new URI(this.uri.getScheme(), this.uri.getUserInfo(), uri.getHost(), uri.getPort(), + uriPath, uri.getQuery(), uri.getFragment()); } catch (URISyntaxException ex) { - throw new IllegalStateException("Invalid URI path: \"" + this.path + "\""); + throw new IllegalStateException("Invalid URI path: \"" + this.uriPath + "\""); } } - @Nullable - private RequestPath getRequestPathToUse(@Nullable URI uriToUse) { - if (uriToUse == null && this.contextPath == null) { - return null; - } - else if (uriToUse == null) { - return this.delegate.getPath().modifyContextPath(this.contextPath); - } - else { - return RequestPath.parse(uriToUse, this.contextPath); - } - } + private static class DefaultServerHttpRequest extends AbstractServerHttpRequest { - @Nullable - private HttpHeaders getHeadersToUse() { - if (this.httpHeaders != null) { - HttpHeaders headers = new HttpHeaders(); - headers.putAll(this.delegate.getHeaders()); - headers.putAll(this.httpHeaders); - return headers; - } - else { - return null; - } - } + private final String methodValue; - - /** - * An immutable wrapper of a request returning property overrides -- given - * to the constructor -- or original values otherwise. - */ - private static class MutativeDecorator extends ServerHttpRequestDecorator { + private final MultiValueMap cookies; @Nullable - private final HttpMethod httpMethod; + private final InetSocketAddress remoteAddress; - @Nullable - private final URI uri; + private final Flux body; - @Nullable - private final RequestPath requestPath; - - @Nullable - private final HttpHeaders httpHeaders; - - - public MutativeDecorator(ServerHttpRequest delegate, @Nullable HttpMethod method, - @Nullable URI uri, @Nullable RequestPath requestPath, @Nullable HttpHeaders httpHeaders) { - - super(delegate); - this.httpMethod = method; - this.uri = uri; - this.requestPath = requestPath; - this.httpHeaders = httpHeaders; + public DefaultServerHttpRequest(URI uri, @Nullable String contextPath, + HttpHeaders headers, String methodValue, + MultiValueMap cookies, @Nullable InetSocketAddress remoteAddress, + Flux body) { + super(uri, contextPath, headers); + this.methodValue = methodValue; + this.cookies = cookies; + this.remoteAddress = remoteAddress; + this.body = body; } @Override + public String getMethodValue() { + return this.methodValue; + } + + @Override + protected MultiValueMap initCookies() { + return this.cookies; + } + @Nullable - public HttpMethod getMethod() { - return (this.httpMethod != null ? this.httpMethod : super.getMethod()); + @Override + public InetSocketAddress getRemoteAddress() { + return this.remoteAddress; } @Override - public URI getURI() { - return (this.uri != null ? this.uri : super.getURI()); - } - - @Override - public RequestPath getPath() { - return (this.requestPath != null ? this.requestPath : super.getPath()); - } - - @Override - public HttpHeaders getHeaders() { - return (this.httpHeaders != null ? this.httpHeaders : super.getHeaders()); + public Flux getBody() { + return this.body; } } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java index f10703dfb07..7a5ec8d228e 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServerHttpRequest.java @@ -17,8 +17,11 @@ package org.springframework.http.server.reactive; import java.net.InetSocketAddress; +import java.net.URI; +import java.util.function.Consumer; import org.springframework.http.HttpCookie; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpRequest; import org.springframework.http.ReactiveHttpInputMessage; @@ -79,6 +82,11 @@ public interface ServerHttpRequest extends HttpRequest, ReactiveHttpInputMessage */ Builder method(HttpMethod httpMethod); + /** + * Set the URI to return. + */ + Builder uri(URI uri); + /** * Set the path to use instead of the {@code "rawPath"} of * {@link ServerHttpRequest#getURI()}. @@ -95,6 +103,17 @@ public interface ServerHttpRequest extends HttpRequest, ReactiveHttpInputMessage */ Builder header(String key, String value); + /** + * Manipulate this request's headers with the given consumer. The + * headers provided to the consumer are "live", so that the consumer can be used to + * {@linkplain HttpHeaders#set(String, String) overwrite} existing header values, + * {@linkplain HttpHeaders#remove(Object) remove} values, or use any of the other + * {@link HttpHeaders} methods. + * @param headersConsumer a function that consumes the {@code HttpHeaders} + * @return this builder + */ + Builder headers(Consumer headersConsumer); + /** * Build a {@link ServerHttpRequest} decorator with the mutated properties. */ diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java new file mode 100644 index 00000000000..b7680934f26 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/ForwardedHeaderFilter.java @@ -0,0 +1,128 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter.reactive; + +import java.net.URI; +import java.util.Collections; +import java.util.Locale; +import java.util.Set; +import javax.servlet.http.HttpServletRequest; + +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * Extract values from "Forwarded" and "X-Forwarded-*" headers in order to wrap + * and override the following from the request and response: + * {@link HttpServletRequest#getServerName() getServerName()}, + * {@link HttpServletRequest#getServerPort() getServerPort()}, + * {@link HttpServletRequest#getScheme() getScheme()}, + * {@link HttpServletRequest#isSecure() isSecure()}, and + * {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}. + * In effect the wrapped request and response reflect the client-originated + * protocol and address. + * + *

Note: This filter can also be used in a + * {@link #setRemoveOnly removeOnly} mode where "Forwarded" and "X-Forwarded-*" + * headers are only eliminated without being used. + * @author Arjen Poutsma + * @see https://tools.ietf.org/html/rfc7239 + * @since 5.0 + */ +public class ForwardedHeaderFilter implements WebFilter { + + private static final Set FORWARDED_HEADER_NAMES = + Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(5, Locale.ENGLISH)); + + static { + FORWARDED_HEADER_NAMES.add("Forwarded"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Host"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Port"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix"); + } + + private boolean removeOnly; + + /** + * Enables mode in which any "Forwarded" or "X-Forwarded-*" headers are + * removed only and the information in them ignored. + * @param removeOnly whether to discard and ignore forwarded headers + */ + public void setRemoveOnly(boolean removeOnly) { + this.removeOnly = removeOnly; + } + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + + if (shouldNotFilter(exchange.getRequest())) { + return chain.filter(exchange); + } + + if (this.removeOnly) { + ServerWebExchange withoutForwardHeaders = exchange.mutate() + .request(builder -> builder.headers( + headers -> { + for (String headerName : FORWARDED_HEADER_NAMES) { + headers.remove(headerName); + } + })).build(); + return chain.filter(withoutForwardHeaders); + } + else { + URI uri = UriComponentsBuilder.fromHttpRequest(exchange.getRequest()).build().toUri(); + String prefix = getForwardedPrefix(exchange.getRequest().getHeaders()); + + ServerWebExchange withChangedUri = exchange.mutate() + .request(builder -> { + builder.uri(uri); + if (prefix != null) { + builder.path(prefix + uri.getPath()); + builder.contextPath(prefix); + } + }).build(); + return chain.filter(withChangedUri); + } + + } + + private boolean shouldNotFilter(ServerHttpRequest request) { + return request.getHeaders().keySet().stream() + .noneMatch(FORWARDED_HEADER_NAMES::contains); + } + + @Nullable + private static String getForwardedPrefix(HttpHeaders headers) { + String prefix = headers.getFirst("X-Forwarded-Prefix"); + if (prefix != null) { + while (prefix.endsWith("/")) { + prefix = prefix.substring(0, prefix.length() - 1); + } + } + return prefix; + } + +} diff --git a/spring-web/src/main/java/org/springframework/web/filter/reactive/HiddenHttpMethodFilter.java b/spring-web/src/main/java/org/springframework/web/filter/reactive/HiddenHttpMethodFilter.java index 176869e0c8e..bd0021b3101 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/reactive/HiddenHttpMethodFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/reactive/HiddenHttpMethodFilter.java @@ -81,7 +81,7 @@ public class HiddenHttpMethodFilter implements WebFilter { String method = formData.getFirst(this.methodParamName); return StringUtils.hasLength(method) ? mapExchange(exchange, method) : exchange; }) - .flatMap((exchange1) -> chain.filter(exchange1)); + .flatMap(chain::filter); } private ServerWebExchange mapExchange(ServerWebExchange exchange, String methodParamValue) { diff --git a/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java new file mode 100644 index 00000000000..04b0f3a8fb6 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/web/filter/reactive/ForwardedHeaderFilterTests.java @@ -0,0 +1,146 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.filter.reactive; + +import java.net.URI; +import java.time.Duration; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.http.HttpHeaders; +import org.springframework.lang.Nullable; +import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.mock.http.server.reactive.test.MockServerWebExchange; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilterChain; + +import static org.junit.Assert.*; + +/** + * @author Arjen Poutsma + */ +public class ForwardedHeaderFilterTests { + + private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter(); + + private final TestWebFilterChain filterChain = new TestWebFilterChain(); + + + @Test + public void removeOnly() { + MockServerWebExchange exchange = MockServerHttpRequest.get("/") + .header("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43") + .header("X-Forwarded-Host", "example.com") + .header("X-Forwarded-Port", "8080") + .header("X-Forwarded-Proto", "http") + .header("X-Forwarded-Prefix", "prefix") + .toExchange(); + + this.filter.setRemoveOnly(true); + this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); + + HttpHeaders result = this.filterChain.getHeaders(); + assertNotNull(result); + assertFalse(result.containsKey("Forwarded")); + assertFalse(result.containsKey("X-Forwarded-Host")); + assertFalse(result.containsKey("X-Forwarded-Port")); + assertFalse(result.containsKey("X-Forwarded-Proto")); + assertFalse(result.containsKey("X-Forwarded-Prefix")); + } + + @Test + public void xForwardedRequest() throws Exception { + MockServerWebExchange exchange = MockServerHttpRequest.get("http://example.com/path") + .header("X-Forwarded-Host", "84.198.58.199") + .header("X-Forwarded-Port", "443") + .header("X-Forwarded-Proto", "https") + .toExchange(); + + this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); + + URI uri = this.filterChain.uri; + assertEquals(new URI("https://84.198.58.199/path"), uri); + } + + @Test + public void forwardedRequest() throws Exception { + MockServerWebExchange exchange = MockServerHttpRequest.get("http://example.com/path") + .header("Forwarded", "host=84.198.58.199;proto=https") + + .toExchange(); + + this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); + + URI uri = this.filterChain.uri; + assertEquals(new URI("https://84.198.58.199/path"), uri); + } + + @Test + public void requestUriWithForwardedPrefix() throws Exception { + MockServerWebExchange exchange = MockServerHttpRequest.get("http://example.com/path") + .header("X-Forwarded-Prefix", "/prefix") + .toExchange(); + + this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); + + URI uri = this.filterChain.uri; + assertEquals(new URI("http://example.com/prefix/path"), uri); + } + + @Test + public void requestUriWithForwardedPrefixTrailingSlash() throws Exception { + MockServerWebExchange exchange = MockServerHttpRequest.get("http://example.com/path") + .header("X-Forwarded-Prefix", "/prefix/") + .toExchange(); + + this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); + + URI uri = this.filterChain.uri; + assertEquals(new URI("http://example.com/prefix/path"), uri); + } + + + private static class TestWebFilterChain implements WebFilterChain { + + @Nullable + private HttpHeaders httpHeaders; + + @Nullable + private URI uri; + + @Nullable + public HttpHeaders getHeaders() { + return this.httpHeaders; + } + + @Nullable + public URI getUri() { + return this.uri; + } + + @Override + public Mono filter(ServerWebExchange exchange) { + this.httpHeaders = exchange.getRequest().getHeaders(); + this.uri = exchange.getRequest().getURI(); + return Mono.empty(); + } + } + + + +} \ No newline at end of file