Polish WebFlux ForwardedHeaderFilter and tests

Preparation for SPR-17072
This commit is contained in:
Rossen Stoyanchev 2018-07-24 14:37:55 -04:00
parent 02403f6a34
commit 41aa4218af
2 changed files with 89 additions and 66 deletions

View File

@ -44,7 +44,7 @@ import org.springframework.web.util.UriComponentsBuilder;
*/ */
public class ForwardedHeaderFilter implements WebFilter { public class ForwardedHeaderFilter implements WebFilter {
private static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5); static final Set<String> FORWARDED_HEADER_NAMES = new LinkedHashSet<>(5);
static { static {
FORWARDED_HEADER_NAMES.add("Forwarded"); FORWARDED_HEADER_NAMES.add("Forwarded");
@ -72,54 +72,58 @@ public class ForwardedHeaderFilter implements WebFilter {
@Override @Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
if (shouldNotFilter(exchange.getRequest())) { ServerHttpRequest request = exchange.getRequest();
if (!hasForwardedHeaders(request)) {
return chain.filter(exchange); return chain.filter(exchange);
} }
ServerWebExchange mutatedExchange; ServerWebExchange mutatedExchange;
if (this.removeOnly) { if (this.removeOnly) {
mutatedExchange = exchange.mutate().request(builder -> mutatedExchange = exchange.mutate().request(this::removeForwardedHeaders).build();
builder.headers(headers -> {
FORWARDED_HEADER_NAMES.forEach(headers::remove);
}))
.build();
} }
else { else {
URI uri = UriComponentsBuilder.fromHttpRequest(exchange.getRequest()).build().toUri(); mutatedExchange = exchange.mutate()
String prefix = getForwardedPrefix(exchange.getRequest().getHeaders()); .request(builder -> {
URI uri = UriComponentsBuilder.fromHttpRequest(request).build().toUri();
mutatedExchange = exchange.mutate().request(builder -> {
builder.uri(uri); builder.uri(uri);
String prefix = getForwardedPrefix(request);
if (prefix != null) { if (prefix != null) {
builder.path(prefix + uri.getPath()); builder.path(prefix + uri.getPath());
builder.contextPath(prefix); builder.contextPath(prefix);
} }
}).build(); })
.build();
} }
return chain.filter(mutatedExchange); return chain.filter(mutatedExchange);
} }
private boolean shouldNotFilter(ServerHttpRequest request) { private boolean hasForwardedHeaders(ServerHttpRequest request) {
HttpHeaders headers = request.getHeaders(); HttpHeaders headers = request.getHeaders();
for (String headerName : FORWARDED_HEADER_NAMES) { for (String headerName : FORWARDED_HEADER_NAMES) {
if (headers.containsKey(headerName)) { if (headers.containsKey(headerName)) {
return false;
}
}
return true; return true;
} }
}
return false;
}
@Nullable @Nullable
private static String getForwardedPrefix(HttpHeaders headers) { private static String getForwardedPrefix(ServerHttpRequest request) {
HttpHeaders headers = request.getHeaders();
String prefix = headers.getFirst("X-Forwarded-Prefix"); String prefix = headers.getFirst("X-Forwarded-Prefix");
if (prefix != null) { if (prefix != null) {
while (prefix.endsWith("/")) { int endIndex = prefix.length();
prefix = prefix.substring(0, prefix.length() - 1); while (endIndex > 1 && prefix.charAt(endIndex - 1) == '/') {
} endIndex--;
};
prefix = endIndex != prefix.length() ? prefix.substring(0, endIndex) : prefix;
} }
return prefix; return prefix;
} }
private ServerHttpRequest.Builder removeForwardedHeaders(ServerHttpRequest.Builder builder) {
return builder.headers(map -> FORWARDED_HEADER_NAMES.forEach(map::remove));
}
} }

View File

@ -23,16 +23,19 @@ import org.junit.Test;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.web.test.server.MockServerWebExchange; import org.springframework.mock.web.test.server.MockServerWebExchange;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import static org.springframework.mock.http.server.reactive.test.MockServerHttpRequest.*;
/** /**
* Unit tests for {@link ForwardedHeaderFilter}.
* @author Arjen Poutsma * @author Arjen Poutsma
* @author Rossen Stoyanchev
*/ */
public class ForwardedHeaderFilterTests { public class ForwardedHeaderFilterTests {
@ -46,65 +49,65 @@ public class ForwardedHeaderFilterTests {
@Test @Test
public void removeOnly() { public void removeOnly() {
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL)
.header("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43")
.header("X-Forwarded-Host", "example.com")
.header("X-Forwarded-Port", "8080")
.header("X-Forwarded-Proto", "http")
.header("X-Forwarded-Prefix", "prefix")
.header("X-Forwarded-Ssl", "on"));
this.filter.setRemoveOnly(true); this.filter.setRemoveOnly(true);
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO);
HttpHeaders result = this.filterChain.getHeaders(); HttpHeaders headers = new HttpHeaders();
assertNotNull(result); headers.add("Forwarded", "for=192.0.2.60;proto=http;by=203.0.113.43");
assertFalse(result.containsKey("Forwarded")); headers.add("X-Forwarded-Host", "example.com");
assertFalse(result.containsKey("X-Forwarded-Host")); headers.add("X-Forwarded-Port", "8080");
assertFalse(result.containsKey("X-Forwarded-Port")); headers.add("X-Forwarded-Proto", "http");
assertFalse(result.containsKey("X-Forwarded-Proto")); headers.add("X-Forwarded-Prefix", "prefix");
assertFalse(result.containsKey("X-Forwarded-Prefix")); headers.add("X-Forwarded-Ssl", "on");
assertFalse(result.containsKey("X-Forwarded-Ssl")); this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
this.filterChain.assertForwardedHeadersRemoved();
} }
@Test @Test
public void xForwardedRequest() throws Exception { public void xForwardedHeaders() throws Exception {
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) HttpHeaders headers = new HttpHeaders();
.header("X-Forwarded-Host", "84.198.58.199") headers.add("X-Forwarded-Host", "84.198.58.199");
.header("X-Forwarded-Port", "443") headers.add("X-Forwarded-Port", "443");
.header("X-Forwarded-Proto", "https")); headers.add("X-Forwarded-Proto", "https");
headers.add("foo", "bar");
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
assertEquals(new URI("https://84.198.58.199/path"), filterAndGetUri(exchange)); assertEquals(new URI("https://84.198.58.199/path"), this.filterChain.uri);
} }
@Test @Test
public void forwardedRequest() throws Exception { public void forwardedHeader() throws Exception {
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) HttpHeaders headers = new HttpHeaders();
.header("Forwarded", "host=84.198.58.199;proto=https")); headers.add("Forwarded", "host=84.198.58.199;proto=https");
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
assertEquals(new URI("https://84.198.58.199/path"), filterAndGetUri(exchange)); assertEquals(new URI("https://84.198.58.199/path"), this.filterChain.uri);
} }
@Test @Test
public void requestUriWithForwardedPrefix() throws Exception { public void xForwardedPrefix() throws Exception {
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) HttpHeaders headers = new HttpHeaders();
.header("X-Forwarded-Prefix", "/prefix")); headers.add("X-Forwarded-Prefix", "/prefix");
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
assertEquals(new URI("http://example.com/prefix/path"), filterAndGetUri(exchange)); assertEquals(new URI("http://example.com/prefix/path"), this.filterChain.uri);
assertEquals("/prefix/path", this.filterChain.requestPathValue);
} }
@Test @Test
public void requestUriWithForwardedPrefixTrailingSlash() throws Exception { public void xForwardedPrefixTrailingSlash() throws Exception {
ServerWebExchange exchange = MockServerWebExchange.from(get(BASE_URL) HttpHeaders headers = new HttpHeaders();
.header("X-Forwarded-Prefix", "/prefix/")); headers.add("X-Forwarded-Prefix", "/prefix////");
this.filter.filter(getExchange(headers), this.filterChain).block(Duration.ZERO);
assertEquals(new URI("http://example.com/prefix/path"), filterAndGetUri(exchange)); assertEquals(new URI("http://example.com/prefix/path"), this.filterChain.uri);
assertEquals("/prefix/path", this.filterChain.requestPathValue);
} }
@Nullable private MockServerWebExchange getExchange(HttpHeaders headers) {
private URI filterAndGetUri(ServerWebExchange exchange) { MockServerHttpRequest request = MockServerHttpRequest.get(BASE_URL).headers(headers).build();
this.filter.filter(exchange, this.filterChain).block(Duration.ZERO); return MockServerWebExchange.from(request);
return this.filterChain.uri;
} }
@ -116,12 +119,26 @@ public class ForwardedHeaderFilterTests {
@Nullable @Nullable
private URI uri; private URI uri;
@Nullable String requestPathValue;
@Nullable @Nullable
public HttpHeaders getHeaders() { public HttpHeaders getHeaders() {
return this.headers; return this.headers;
} }
@Nullable
public String getHeader(String name) {
assertNotNull(this.headers);
return this.headers.getFirst(name);
}
public void assertForwardedHeadersRemoved() {
assertNotNull(this.headers);
ForwardedHeaderFilter.FORWARDED_HEADER_NAMES
.forEach(name -> assertFalse(this.headers.containsKey(name)));
}
@Nullable @Nullable
public URI getUri() { public URI getUri() {
return this.uri; return this.uri;
@ -129,8 +146,10 @@ public class ForwardedHeaderFilterTests {
@Override @Override
public Mono<Void> filter(ServerWebExchange exchange) { public Mono<Void> filter(ServerWebExchange exchange) {
this.headers = exchange.getRequest().getHeaders(); ServerHttpRequest request = exchange.getRequest();
this.uri = exchange.getRequest().getURI(); this.headers = request.getHeaders();
this.uri = request.getURI();
this.requestPathValue = request.getPath().value();
return Mono.empty(); return Mono.empty();
} }
} }