Introduce ForwardedHeaderFilter for WebFlux

This commit introduces a ForwardedHeaderFilter for WebFlux, similar to
the existing Servlet version. As part of this the
DefaultServerHttpRequestBuilder had to be changed to no longer use
delegation, but instead use a deep copy at the point of mutate().
Otherwise, headers could not be removed.

Issue: SPR-15954
This commit is contained in:
Arjen Poutsma 2017-09-14 16:26:35 +02:00
parent 69af698ceb
commit e70210a1da
5 changed files with 388 additions and 87 deletions

View File

@ -16,14 +16,24 @@
package org.springframework.http.server.reactive;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import reactor.core.publisher.Flux;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.RequestPath;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* Package-private default implementation of {@link ServerHttpRequest.Builder}.
@ -34,36 +44,66 @@ import org.springframework.util.Assert;
*/
class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder {
private final ServerHttpRequest delegate;
private URI uri;
private HttpHeaders httpHeaders;
private String httpMethodValue;
private final MultiValueMap<String, HttpCookie> cookies;
@Nullable
private HttpMethod httpMethod;
private final InetSocketAddress remoteAddress;
@Nullable
private String path;
private String uriPath;
@Nullable
private String contextPath;
@Nullable
private HttpHeaders httpHeaders;
private Flux<DataBuffer> body;
public DefaultServerHttpRequestBuilder(ServerHttpRequest original) {
Assert.notNull(original, "ServerHttpRequest is required");
public DefaultServerHttpRequestBuilder(ServerHttpRequest delegate) {
Assert.notNull(delegate, "ServerHttpRequest delegate is required");
this.delegate = delegate;
this.uri = original.getURI();
this.httpMethodValue = original.getMethodValue();
this.remoteAddress = original.getRemoteAddress();
this.body = original.getBody();
this.httpHeaders = new HttpHeaders();
copyMultiValueMap(original.getHeaders(), this.httpHeaders);
this.cookies = new LinkedMultiValueMap<>(original.getCookies().size());
copyMultiValueMap(original.getCookies(), this.cookies);
}
private static <K, V> void copyMultiValueMap(MultiValueMap<K,V> source,
MultiValueMap<K,V> destination) {
for (Map.Entry<K, List<V>> entry : source.entrySet()) {
K key = entry.getKey();
List<V> values = new LinkedList<>(entry.getValue());
destination.put(key, values);
}
}
@Override
public ServerHttpRequest.Builder method(HttpMethod httpMethod) {
this.httpMethod = httpMethod;
this.httpMethodValue = httpMethod.name();
return this;
}
@Override
public ServerHttpRequest.Builder uri(URI uri) {
this.uri = uri;
return this;
}
@Override
public ServerHttpRequest.Builder path(String path) {
this.path = path;
this.uriPath = path;
return this;
}
@ -75,111 +115,79 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder {
@Override
public ServerHttpRequest.Builder header(String key, String value) {
if (this.httpHeaders == null) {
this.httpHeaders = new HttpHeaders();
}
this.httpHeaders.add(key, value);
return this;
}
@Override
public ServerHttpRequest.Builder headers(Consumer<HttpHeaders> headersConsumer) {
Assert.notNull(headersConsumer, "'headersConsumer' must not be null");
headersConsumer.accept(this.httpHeaders);
return this;
}
@Override
public ServerHttpRequest build() {
URI uriToUse = getUriToUse();
RequestPath path = getRequestPathToUse(uriToUse);
HttpHeaders headers = getHeadersToUse();
return new MutativeDecorator(this.delegate, this.httpMethod, uriToUse, path, headers);
return new DefaultServerHttpRequest(uriToUse, this.contextPath, this.httpHeaders,
this.httpMethodValue, this.cookies, this.remoteAddress, this.body);
}
@Nullable
private URI getUriToUse() {
if (this.path == null) {
return null;
if (this.uriPath == null) {
return this.uri;
}
URI uri = this.delegate.getURI();
try {
return new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort(),
this.path, uri.getQuery(), uri.getFragment());
return new URI(this.uri.getScheme(), this.uri.getUserInfo(), uri.getHost(), uri.getPort(),
uriPath, uri.getQuery(), uri.getFragment());
}
catch (URISyntaxException ex) {
throw new IllegalStateException("Invalid URI path: \"" + this.path + "\"");
throw new IllegalStateException("Invalid URI path: \"" + this.uriPath + "\"");
}
}
@Nullable
private RequestPath getRequestPathToUse(@Nullable URI uriToUse) {
if (uriToUse == null && this.contextPath == null) {
return null;
}
else if (uriToUse == null) {
return this.delegate.getPath().modifyContextPath(this.contextPath);
}
else {
return RequestPath.parse(uriToUse, this.contextPath);
}
}
private static class DefaultServerHttpRequest extends AbstractServerHttpRequest {
@Nullable
private HttpHeaders getHeadersToUse() {
if (this.httpHeaders != null) {
HttpHeaders headers = new HttpHeaders();
headers.putAll(this.delegate.getHeaders());
headers.putAll(this.httpHeaders);
return headers;
}
else {
return null;
}
}
private final String methodValue;
/**
* An immutable wrapper of a request returning property overrides -- given
* to the constructor -- or original values otherwise.
*/
private static class MutativeDecorator extends ServerHttpRequestDecorator {
private final MultiValueMap<String, HttpCookie> cookies;
@Nullable
private final HttpMethod httpMethod;
private final InetSocketAddress remoteAddress;
@Nullable
private final URI uri;
private final Flux<DataBuffer> body;
@Nullable
private final RequestPath requestPath;
@Nullable
private final HttpHeaders httpHeaders;
public MutativeDecorator(ServerHttpRequest delegate, @Nullable HttpMethod method,
@Nullable URI uri, @Nullable RequestPath requestPath, @Nullable HttpHeaders httpHeaders) {
super(delegate);
this.httpMethod = method;
this.uri = uri;
this.requestPath = requestPath;
this.httpHeaders = httpHeaders;
public DefaultServerHttpRequest(URI uri, @Nullable String contextPath,
HttpHeaders headers, String methodValue,
MultiValueMap<String, HttpCookie> cookies, @Nullable InetSocketAddress remoteAddress,
Flux<DataBuffer> body) {
super(uri, contextPath, headers);
this.methodValue = methodValue;
this.cookies = cookies;
this.remoteAddress = remoteAddress;
this.body = body;
}
@Override
public String getMethodValue() {
return this.methodValue;
}
@Override
protected MultiValueMap<String, HttpCookie> initCookies() {
return this.cookies;
}
@Nullable
public HttpMethod getMethod() {
return (this.httpMethod != null ? this.httpMethod : super.getMethod());
@Override
public InetSocketAddress getRemoteAddress() {
return this.remoteAddress;
}
@Override
public URI getURI() {
return (this.uri != null ? this.uri : super.getURI());
}
@Override
public RequestPath getPath() {
return (this.requestPath != null ? this.requestPath : super.getPath());
}
@Override
public HttpHeaders getHeaders() {
return (this.httpHeaders != null ? this.httpHeaders : super.getHeaders());
public Flux<DataBuffer> getBody() {
return this.body;
}
}

View File

@ -17,8 +17,11 @@
package org.springframework.http.server.reactive;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.function.Consumer;
import org.springframework.http.HttpCookie;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRequest;
import org.springframework.http.ReactiveHttpInputMessage;
@ -79,6 +82,11 @@ public interface ServerHttpRequest extends HttpRequest, ReactiveHttpInputMessage
*/
Builder method(HttpMethod httpMethod);
/**
* Set the URI to return.
*/
Builder uri(URI uri);
/**
* Set the path to use instead of the {@code "rawPath"} of
* {@link ServerHttpRequest#getURI()}.
@ -95,6 +103,17 @@ public interface ServerHttpRequest extends HttpRequest, ReactiveHttpInputMessage
*/
Builder header(String key, String value);
/**
* Manipulate this request's headers with the given consumer. The
* headers provided to the consumer are "live", so that the consumer can be used to
* {@linkplain HttpHeaders#set(String, String) overwrite} existing header values,
* {@linkplain HttpHeaders#remove(Object) remove} values, or use any of the other
* {@link HttpHeaders} methods.
* @param headersConsumer a function that consumes the {@code HttpHeaders}
* @return this builder
*/
Builder headers(Consumer<HttpHeaders> headersConsumer);
/**
* Build a {@link ServerHttpRequest} decorator with the mutated properties.
*/

View File

@ -0,0 +1,128 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.filter.reactive;
import java.net.URI;
import java.util.Collections;
import java.util.Locale;
import java.util.Set;
import javax.servlet.http.HttpServletRequest;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.LinkedCaseInsensitiveMap;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.util.UriComponentsBuilder;
/**
* Extract values from "Forwarded" and "X-Forwarded-*" headers in order to wrap
* and override the following from the request and response:
* {@link HttpServletRequest#getServerName() getServerName()},
* {@link HttpServletRequest#getServerPort() getServerPort()},
* {@link HttpServletRequest#getScheme() getScheme()},
* {@link HttpServletRequest#isSecure() isSecure()}, and
* {@link HttpServletResponse#sendRedirect(String) sendRedirect(String)}.
* In effect the wrapped request and response reflect the client-originated
* protocol and address.
*
* <p><strong>Note:</strong> This filter can also be used in a
* {@link #setRemoveOnly removeOnly} mode where "Forwarded" and "X-Forwarded-*"
* headers are only eliminated without being used.
* @author Arjen Poutsma
* @see <a href="https://tools.ietf.org/html/rfc7239">https://tools.ietf.org/html/rfc7239</a>
* @since 5.0
*/
public class ForwardedHeaderFilter implements WebFilter {
private static final Set<String> FORWARDED_HEADER_NAMES =
Collections.newSetFromMap(new LinkedCaseInsensitiveMap<>(5, Locale.ENGLISH));
static {
FORWARDED_HEADER_NAMES.add("Forwarded");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Host");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Port");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto");
FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix");
}
private boolean removeOnly;
/**
* Enables mode in which any "Forwarded" or "X-Forwarded-*" headers are
* removed only and the information in them ignored.
* @param removeOnly whether to discard and ignore forwarded headers
*/
public void setRemoveOnly(boolean removeOnly) {
this.removeOnly = removeOnly;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
if (shouldNotFilter(exchange.getRequest())) {
return chain.filter(exchange);
}
if (this.removeOnly) {
ServerWebExchange withoutForwardHeaders = exchange.mutate()
.request(builder -> builder.headers(
headers -> {
for (String headerName : FORWARDED_HEADER_NAMES) {
headers.remove(headerName);
}
})).build();
return chain.filter(withoutForwardHeaders);
}
else {
URI uri = UriComponentsBuilder.fromHttpRequest(exchange.getRequest()).build().toUri();
String prefix = getForwardedPrefix(exchange.getRequest().getHeaders());
ServerWebExchange withChangedUri = exchange.mutate()
.request(builder -> {
builder.uri(uri);
if (prefix != null) {
builder.path(prefix + uri.getPath());
builder.contextPath(prefix);
}
}).build();
return chain.filter(withChangedUri);
}
}
private boolean shouldNotFilter(ServerHttpRequest request) {
return request.getHeaders().keySet().stream()
.noneMatch(FORWARDED_HEADER_NAMES::contains);
}
@Nullable
private static String getForwardedPrefix(HttpHeaders headers) {
String prefix = headers.getFirst("X-Forwarded-Prefix");
if (prefix != null) {
while (prefix.endsWith("/")) {
prefix = prefix.substring(0, prefix.length() - 1);
}
}
return prefix;
}
}

View File

@ -81,7 +81,7 @@ public class HiddenHttpMethodFilter implements WebFilter {
String method = formData.getFirst(this.methodParamName);
return StringUtils.hasLength(method) ? mapExchange(exchange, method) : exchange;
})
.flatMap((exchange1) -> chain.filter(exchange1));
.flatMap(chain::filter);
}
private ServerWebExchange mapExchange(ServerWebExchange exchange, String methodParamValue) {

View File

@ -0,0 +1,146 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.filter.reactive;
import java.net.URI;
import java.time.Duration;
import org.junit.Test;
import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.Nullable;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerWebExchange;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import static org.junit.Assert.*;
/**
* @author Arjen Poutsma
*/
public class ForwardedHeaderFilterTests {
private final ForwardedHeaderFilter filter = new ForwardedHeaderFilter();
private final TestWebFilterChain filterChain = new TestWebFilterChain();
@Test
public void removeOnly() {
MockServerWebExchange exchange = MockServerHttpRequest.get("/")
.header("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43")
.header("X-Forwarded-Host", "example.com")
.header("X-Forwarded-Port", "8080")
.header("X-Forwarded-Proto", "http")
.header("X-Forwarded-Prefix", "prefix")
.toExchange();
this.filter.setRemoveOnly(true);
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO);
HttpHeaders result = this.filterChain.getHeaders();
assertNotNull(result);
assertFalse(result.containsKey("Forwarded"));
assertFalse(result.containsKey("X-Forwarded-Host"));
assertFalse(result.containsKey("X-Forwarded-Port"));
assertFalse(result.containsKey("X-Forwarded-Proto"));
assertFalse(result.containsKey("X-Forwarded-Prefix"));
}
@Test
public void xForwardedRequest() throws Exception {
MockServerWebExchange exchange = MockServerHttpRequest.get("http://example.com/path")
.header("X-Forwarded-Host", "84.198.58.199")
.header("X-Forwarded-Port", "443")
.header("X-Forwarded-Proto", "https")
.toExchange();
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO);
URI uri = this.filterChain.uri;
assertEquals(new URI("https://84.198.58.199/path"), uri);
}
@Test
public void forwardedRequest() throws Exception {
MockServerWebExchange exchange = MockServerHttpRequest.get("http://example.com/path")
.header("Forwarded", "host=84.198.58.199;proto=https")
.toExchange();
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO);
URI uri = this.filterChain.uri;
assertEquals(new URI("https://84.198.58.199/path"), uri);
}
@Test
public void requestUriWithForwardedPrefix() throws Exception {
MockServerWebExchange exchange = MockServerHttpRequest.get("http://example.com/path")
.header("X-Forwarded-Prefix", "/prefix")
.toExchange();
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO);
URI uri = this.filterChain.uri;
assertEquals(new URI("http://example.com/prefix/path"), uri);
}
@Test
public void requestUriWithForwardedPrefixTrailingSlash() throws Exception {
MockServerWebExchange exchange = MockServerHttpRequest.get("http://example.com/path")
.header("X-Forwarded-Prefix", "/prefix/")
.toExchange();
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO);
URI uri = this.filterChain.uri;
assertEquals(new URI("http://example.com/prefix/path"), uri);
}
private static class TestWebFilterChain implements WebFilterChain {
@Nullable
private HttpHeaders httpHeaders;
@Nullable
private URI uri;
@Nullable
public HttpHeaders getHeaders() {
return this.httpHeaders;
}
@Nullable
public URI getUri() {
return this.uri;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange) {
this.httpHeaders = exchange.getRequest().getHeaders();
this.uri = exchange.getRequest().getURI();
return Mono.empty();
}
}
}