diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/AbstractMockServerSpec.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/AbstractMockServerSpec.java index 59c03049e6d..e2084f7a21d 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/AbstractMockServerSpec.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/AbstractMockServerSpec.java @@ -16,10 +16,14 @@ package org.springframework.test.web.reactive.server; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.function.UnaryOperator; -import org.springframework.http.server.reactive.HttpHandler; import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; /** @@ -33,6 +37,13 @@ abstract class AbstractMockServerSpec> private final ExchangeMutatorWebFilter exchangeMutatorFilter = new ExchangeMutatorWebFilter(); + private final List filters = new ArrayList<>(4); + + + AbstractMockServerSpec() { + this.filters.add(this.exchangeMutatorFilter); + } + @Override public T exchangeMutator(UnaryOperator mutator) { @@ -40,6 +51,12 @@ abstract class AbstractMockServerSpec> return self(); } + @Override + public T webFilter(WebFilter... filter) { + this.filters.addAll(Arrays.asList(filter)); + return self(); + } + @SuppressWarnings("unchecked") private T self() { return (T) this; @@ -48,12 +65,25 @@ abstract class AbstractMockServerSpec> @Override public WebTestClient.Builder configureClient() { - HttpHandler handler = initHttpHandlerBuilder().prependFilter(this.exchangeMutatorFilter).build(); - return new DefaultWebTestClientBuilder(handler, this.exchangeMutatorFilter); + WebHttpHandlerBuilder builder = initHttpHandlerBuilder(); + filtersInReverse().forEach(builder::prependFilter); + return new DefaultWebTestClientBuilder(builder.build(), this.exchangeMutatorFilter); } + /** + * Sub-classes to create the {@code WebHttpHandlerBuilder} to use. + */ protected abstract WebHttpHandlerBuilder initHttpHandlerBuilder(); + /** + * Return the filters in reverse order for pre-pending. + */ + private List filtersInReverse() { + List result = new ArrayList<>(this.filters); + Collections.reverse(result); + return result; + } + @Override public WebTestClient build() { return configureClient().build(); diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/WebTestClient.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/WebTestClient.java index a4af65b1d63..2fedf08077f 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/WebTestClient.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/WebTestClient.java @@ -50,6 +50,7 @@ import org.springframework.web.reactive.function.client.ExchangeStrategies; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunction; import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; import org.springframework.web.util.UriBuilder; import org.springframework.web.util.UriBuilderFactory; @@ -193,6 +194,12 @@ public interface WebTestClient { */ T exchangeMutator(UnaryOperator mutator); + /** + * Configure {@link WebFilter}'s for server request processing. + * @param filter one or more filters + */ + T webFilter(WebFilter... filter); + /** * Proceed to configure and build the test client. */ diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java index bad8a5e3c00..fca4e4406c6 100644 --- a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ApplicationContextTests.java @@ -32,9 +32,7 @@ import org.springframework.web.bind.annotation.RequestAttribute; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.config.EnableWebFlux; import org.springframework.web.server.ServerWebExchange; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import org.springframework.web.server.WebFilter; /** * Binding to server infrastructure declared in a Spring ApplicationContext. @@ -56,24 +54,16 @@ public class ApplicationContextTests { this.client = WebTestClient.bindToApplicationContext(context) .exchangeMutator(principal("Pablo")) + .webFilter(prefixFilter("Mr.")) .build(); } - private UnaryOperator principal(String userName) { - return exchange -> { - Principal user = mock(Principal.class); - when(user.getName()).thenReturn(userName); - return exchange.mutate().principal(Mono.just(user)).build(); - }; - } - - @Test public void basic() throws Exception { this.client.get().uri("/principal") .exchange() .expectStatus().isOk() - .expectBody(String.class).value().isEqualTo("Hello Pablo!"); + .expectBody(String.class).value().isEqualTo("Hello Mr. Pablo!"); } @Test @@ -82,7 +72,7 @@ public class ApplicationContextTests { .get().uri("/principal") .exchange() .expectStatus().isOk() - .expectBody(String.class).value().isEqualTo("Hello Giovanni!"); + .expectBody(String.class).value().isEqualTo("Hello Mr. Giovanni!"); } @Test @@ -96,6 +86,18 @@ public class ApplicationContextTests { .expectBody(String.class).value().isEqualTo("foo+bar"); } + + private UnaryOperator principal(String userName) { + return exchange -> exchange.mutate().principal(Mono.just(new TestUser(userName))).build(); + } + + private WebFilter prefixFilter(String prefix) { + return (exchange, chain) -> { + Mono user = exchange.getPrincipal().map(p -> new TestUser(prefix + " " + p.getName())); + return chain.filter(exchange.mutate().principal(user).build()); + }; + } + private UnaryOperator attribute(String attrName, String attrValue) { return exchange -> { exchange.getAttributes().put(attrName, attrValue); @@ -129,4 +131,18 @@ public class ApplicationContextTests { } } + private static class TestUser implements Principal { + + private final String name; + + TestUser(String name) { + this.name = name; + } + + @Override + public String getName() { + return this.name; + } + } + } diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java index 8b4b63cd108..6e6ab8dea87 100644 --- a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/bind/ControllerTests.java @@ -27,9 +27,7 @@ import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestAttribute; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.server.ServerWebExchange; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import org.springframework.web.server.WebFilter; /** * Bind to annotated controllers. @@ -39,27 +37,18 @@ import static org.mockito.Mockito.when; */ public class ControllerTests { - private final WebTestClient client = WebTestClient - .bindToController(new TestController()) + private final WebTestClient client = WebTestClient.bindToController(new TestController()) .exchangeMutator(principal("Pablo")) + .webFilter(prefixFilter("Mr.")) .build(); - private UnaryOperator principal(String userName) { - return exchange -> { - Principal user = mock(Principal.class); - when(user.getName()).thenReturn(userName); - return exchange.mutate().principal(Mono.just(user)).build(); - }; - } - - @Test public void basic() throws Exception { this.client.get().uri("/principal") .exchange() .expectStatus().isOk() - .expectBody(String.class).value().isEqualTo("Hello Pablo!"); + .expectBody(String.class).value().isEqualTo("Hello Mr. Pablo!"); } @Test @@ -68,7 +57,7 @@ public class ControllerTests { .get().uri("/principal") .exchange() .expectStatus().isOk() - .expectBody(String.class).value().isEqualTo("Hello Giovanni!"); + .expectBody(String.class).value().isEqualTo("Hello Mr. Giovanni!"); } @Test @@ -82,6 +71,18 @@ public class ControllerTests { .expectBody(String.class).value().isEqualTo("foo+bar"); } + + private UnaryOperator principal(String userName) { + return exchange -> exchange.mutate().principal(Mono.just(new TestUser(userName))).build(); + } + + private WebFilter prefixFilter(String prefix) { + return (exchange, chain) -> { + Mono user = exchange.getPrincipal().map(p -> new TestUser(prefix + " " + p.getName())); + return chain.filter(exchange.mutate().principal(user).build()); + }; + } + private UnaryOperator attribute(String attrName, String attrValue) { return exchange -> { exchange.getAttributes().put(attrName, attrValue); @@ -104,4 +105,18 @@ public class ControllerTests { } } + private static class TestUser implements Principal { + + private final String name; + + TestUser(String name) { + this.name = name; + } + + @Override + public String getName() { + return this.name; + } + } + }