Add support for WebFilter and WebExceptionHandler

This commit adds support for configuring `WebFilter` and
`WebExceptionHandler` instances in HandlerStrategies. It also drops the
"native" support for `ResponseStatusException`s, in favor of the
`ResponseStatusExceptionHandler`, which is registered by default.

Issue: SPR-15518
This commit is contained in:
Arjen Poutsma 2017-05-10 14:35:32 +02:00
parent ad9cf99420
commit f4cf55cb2b
7 changed files with 138 additions and 54 deletions

View File

@ -32,6 +32,9 @@ import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.util.Assert;
import org.springframework.web.reactive.result.view.ViewResolver;
import org.springframework.web.server.WebExceptionHandler;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.handler.ResponseStatusExceptionHandler;
/**
* Default implementation of {@link HandlerStrategies.Builder}.
@ -53,6 +56,9 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder {
private Function<ServerRequest, Optional<Locale>> localeResolver;
private final List<WebFilter> webFilters = new ArrayList<>();
private final List<WebExceptionHandler> exceptionHandlers = new ArrayList<>();
public DefaultHandlerStrategiesBuilder() {
@ -62,6 +68,7 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder {
public void defaultConfiguration() {
this.codecConfigurer.registerDefaults(true);
localeResolver(DEFAULT_LOCALE_RESOLVER);
exceptionHandler(new ResponseStatusExceptionHandler());
}
@Override
@ -94,10 +101,25 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder {
return this;
}
@Override
public HandlerStrategies.Builder webFilter(WebFilter filter) {
Assert.notNull(filter, "'filter' must not be null");
this.webFilters.add(filter);
return this;
}
@Override
public HandlerStrategies.Builder exceptionHandler(WebExceptionHandler exceptionHandler) {
Assert.notNull(exceptionHandler, "'exceptionHandler' must not be null");
this.exceptionHandlers.add(exceptionHandler);
return this;
}
@Override
public HandlerStrategies build() {
return new DefaultHandlerStrategies(this.codecConfigurer.getReaders(),
this.codecConfigurer.getWriters(), this.viewResolvers, this.localeResolver);
this.codecConfigurer.getWriters(), this.viewResolvers, this.localeResolver,
this.webFilters, this.exceptionHandlers);
}
@ -111,16 +133,24 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder {
private final Function<ServerRequest, Optional<Locale>> localeResolver;
private final List<WebFilter> webFilters;
private final List<WebExceptionHandler> exceptionHandlers;
public DefaultHandlerStrategies(
List<HttpMessageReader<?>> messageReaders,
List<HttpMessageWriter<?>> messageWriters,
List<ViewResolver> viewResolvers,
Function<ServerRequest, Optional<Locale>> localeResolver) {
Function<ServerRequest, Optional<Locale>> localeResolver,
List<WebFilter> webFilters,
List<WebExceptionHandler> exceptionHandlers) {
this.messageReaders = unmodifiableCopy(messageReaders);
this.messageWriters = unmodifiableCopy(messageWriters);
this.viewResolvers = unmodifiableCopy(viewResolvers);
this.localeResolver = localeResolver;
this.webFilters = unmodifiableCopy(webFilters);
this.exceptionHandlers = unmodifiableCopy(exceptionHandlers);
}
private static <T> List<T> unmodifiableCopy(List<? extends T> list) {
@ -146,6 +176,16 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder {
public Supplier<Function<ServerRequest, Optional<Locale>>> localeResolver() {
return () -> this.localeResolver;
}
@Override
public Supplier<Stream<WebFilter>> webFilters() {
return this.webFilters::stream;
}
@Override
public Supplier<Stream<WebExceptionHandler>> exceptionHandlers() {
return this.exceptionHandlers::stream;
}
}
}

View File

@ -65,12 +65,13 @@ class DefaultServerRequest implements ServerRequest {
private final Headers headers;
private final HandlerStrategies strategies;
private final Supplier<Stream<HttpMessageReader<?>>> messageReaders;
DefaultServerRequest(ServerWebExchange exchange, HandlerStrategies strategies) {
DefaultServerRequest(ServerWebExchange exchange,
Supplier<Stream<HttpMessageReader<?>>> messageReaders) {
this.exchange = exchange;
this.strategies = strategies;
this.messageReaders = messageReaders;
this.headers = new DefaultHeaders();
}
@ -102,7 +103,7 @@ class DefaultServerRequest implements ServerRequest {
new BodyExtractor.Context() {
@Override
public Supplier<Stream<HttpMessageReader<?>>> messageReaders() {
return DefaultServerRequest.this.strategies.messageReaders();
return DefaultServerRequest.this.messageReaders;
}
@Override

View File

@ -27,6 +27,8 @@ import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.web.reactive.result.view.ViewResolver;
import org.springframework.web.server.WebExceptionHandler;
import org.springframework.web.server.WebFilter;
/**
* Defines the strategies to be used for processing {@link HandlerFunction}s. An instance of
@ -71,6 +73,20 @@ public interface HandlerStrategies {
*/
Supplier<Function<ServerRequest, Optional<Locale>>> localeResolver();
/**
* Supply a {@linkplain Stream stream} of {@link WebFilter}s to be used for filtering the
* request and response.
* @return the stream of web filters
*/
Supplier<Stream<WebFilter>> webFilters();
/**
* Supply a {@linkplain Stream stream} of {@link WebExceptionHandler}s to be used for handling
* exceptions.
* @return the stream of exception handlers
*/
Supplier<Stream<WebExceptionHandler>> exceptionHandlers();
// Static methods
@ -138,6 +154,20 @@ public interface HandlerStrategies {
*/
Builder localeResolver(Function<ServerRequest, Optional<Locale>> localeResolver);
/**
* Add the given web filter to this builder.
* @param filter the filter to add
* @return this builder
*/
Builder webFilter(WebFilter filter);
/**
* Add the given exception handler to this builder.
* @param exceptionHandler the exception handler to add
* @return this builder
*/
Builder exceptionHandler(WebExceptionHandler exceptionHandler);
/**
* Builds the {@link HandlerStrategies}.
* @return the built strategies

View File

@ -30,10 +30,8 @@ import org.springframework.util.Assert;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.function.server.support.HandlerFunctionAdapter;
import org.springframework.web.reactive.function.server.support.ServerResponseResultHandler;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebHandler;
import org.springframework.web.server.adapter.HttpWebHandlerAdapter;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
/**
@ -197,7 +195,7 @@ public abstract class RouterFunctions {
* @param routerFunction the router function to convert
* @return an http handler that handles HTTP request using the given router function
*/
public static HttpWebHandlerAdapter toHttpHandler(RouterFunction<?> routerFunction) {
public static HttpHandler toHttpHandler(RouterFunction<?> routerFunction) {
return toHttpHandler(routerFunction, HandlerStrategies.withDefaults());
}
@ -213,32 +211,27 @@ public abstract class RouterFunctions {
* <li>Undertow using the
* {@link org.springframework.http.server.reactive.UndertowHttpHandlerAdapter}.</li>
* </ul>
* <p>Note that {@code HttpWebHandlerAdapter} also implements {@link WebHandler}, allowing
* for additional filter and exception handler registration through
* @param routerFunction the router function to convert
* @param strategies the strategies to use
* @return an http handler that handles HTTP request using the given router function
*/
public static HttpWebHandlerAdapter toHttpHandler(RouterFunction<?> routerFunction, HandlerStrategies strategies) {
public static HttpHandler toHttpHandler(RouterFunction<?> routerFunction, HandlerStrategies strategies) {
Assert.notNull(routerFunction, "RouterFunction must not be null");
Assert.notNull(strategies, "HandlerStrategies must not be null");
return new HttpWebHandlerAdapter(exchange -> {
ServerRequest request = new DefaultServerRequest(exchange, strategies);
WebHandler webHandler = exchange -> {
ServerRequest request = new DefaultServerRequest(exchange, strategies.messageReaders());
addAttributes(exchange, request);
return routerFunction.route(request)
.defaultIfEmpty(notFound())
.flatMap(handlerFunction -> wrapException(() -> handlerFunction.handle(request)))
.flatMap(response -> wrapException(() -> response.writeTo(exchange, strategies)))
.onErrorResume(ResponseStatusException.class,
ex -> {
exchange.getResponse().setStatusCode(ex.getStatus());
if (ex.getMessage() != null) {
logger.error(ex.getMessage());
}
return Mono.empty();
});
});
.flatMap(response -> wrapException(() -> response.writeTo(exchange, strategies)));
};
WebHttpHandlerBuilder handlerBuilder = WebHttpHandlerBuilder.webHandler(webHandler);
strategies.webFilters().get().forEach(handlerBuilder::filter);
strategies.exceptionHandlers().get().forEach(handlerBuilder::exceptionHandler);
return handlerBuilder.build();
}
private static <T> Mono<T> wrapException(Supplier<Mono<T>> supplier) {
@ -280,7 +273,7 @@ public abstract class RouterFunctions {
Assert.notNull(strategies, "HandlerStrategies must not be null");
return exchange -> {
ServerRequest request = new DefaultServerRequest(exchange, strategies);
ServerRequest request = new DefaultServerRequest(exchange, strategies.messageReaders());
addAttributes(exchange, request);
return routerFunction.route(request).map(handlerFunction -> (Object)handlerFunction);
};

View File

@ -26,7 +26,8 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.junit.Before;
import org.junit.Test;
@ -52,9 +53,8 @@ import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.UnsupportedMediaTypeStatusException;
import org.springframework.web.server.WebSession;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import static org.springframework.web.reactive.function.BodyExtractors.toMono;
/**
@ -64,11 +64,9 @@ public class DefaultServerRequestTests {
private ServerHttpRequest mockRequest;
private ServerHttpResponse mockResponse;
private ServerWebExchange mockExchange;
private HandlerStrategies mockHandlerStrategies;
Supplier<Stream<HttpMessageReader<?>>> messageReaders;
private DefaultServerRequest defaultRequest;
@ -76,14 +74,15 @@ public class DefaultServerRequestTests {
@Before
public void createMocks() {
mockRequest = mock(ServerHttpRequest.class);
mockResponse = mock(ServerHttpResponse.class);
ServerHttpResponse mockResponse = mock(ServerHttpResponse.class);
mockExchange = mock(ServerWebExchange.class);
when(mockExchange.getRequest()).thenReturn(mockRequest);
when(mockExchange.getResponse()).thenReturn(mockResponse);
mockHandlerStrategies = mock(HandlerStrategies.class);
defaultRequest = new DefaultServerRequest(mockExchange, mockHandlerStrategies);
this.messageReaders = Collections.<HttpMessageReader<?>>singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true)))::stream;
defaultRequest = new DefaultServerRequest(mockExchange, messageReaders);
}
@ -190,10 +189,6 @@ public class DefaultServerRequestTests {
when(mockRequest.getHeaders()).thenReturn(httpHeaders);
when(mockRequest.getBody()).thenReturn(body);
Set<HttpMessageReader<?>> messageReaders = Collections
.singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true)));
when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream);
Mono<String> resultMono = defaultRequest.body(toMono(String.class));
assertEquals("foo", resultMono.block());
}
@ -210,10 +205,6 @@ public class DefaultServerRequestTests {
when(mockRequest.getHeaders()).thenReturn(httpHeaders);
when(mockRequest.getBody()).thenReturn(body);
Set<HttpMessageReader<?>> messageReaders = Collections
.singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true)));
when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream);
Mono<String> resultMono = defaultRequest.bodyToMono(String.class);
assertEquals("foo", resultMono.block());
}
@ -230,10 +221,6 @@ public class DefaultServerRequestTests {
when(mockRequest.getHeaders()).thenReturn(httpHeaders);
when(mockRequest.getBody()).thenReturn(body);
Set<HttpMessageReader<?>> messageReaders = Collections
.singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true)));
when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream);
Flux<String> resultFlux = defaultRequest.bodyToFlux(String.class);
Mono<List<String>> result = resultFlux.collectList();
assertEquals(Collections.singletonList("foo"), result.block());
@ -251,8 +238,8 @@ public class DefaultServerRequestTests {
when(mockRequest.getHeaders()).thenReturn(httpHeaders);
when(mockRequest.getBody()).thenReturn(body);
Set<HttpMessageReader<?>> messageReaders = Collections.emptySet();
when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream);
this.messageReaders = Collections.<HttpMessageReader<?>>emptySet()::stream;
this.defaultRequest = new DefaultServerRequest(mockExchange, messageReaders);
Flux<String> resultFlux = defaultRequest.bodyToFlux(String.class);
StepVerifier.create(resultFlux)

View File

@ -33,8 +33,7 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.mock.http.server.reactive.test.MockServerWebExchange;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.*;
/**
* @author Arjen Poutsma
@ -51,7 +50,7 @@ public class ResourceHandlerFunctionTests {
MockServerWebExchange exchange = MockServerHttpRequest.get("http://localhost").toExchange();
MockServerHttpResponse mockResponse = exchange.getResponse();
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults());
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults().messageReaders());
Mono<ServerResponse> responseMono = this.handlerFunction.handle(request);
@ -86,7 +85,7 @@ public class ResourceHandlerFunctionTests {
MockServerWebExchange exchange = MockServerHttpRequest.head("http://localhost").toExchange();
MockServerHttpResponse mockResponse = exchange.getResponse();
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults());
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults().messageReaders());
Mono<ServerResponse> responseMono = this.handlerFunction.handle(request);
@ -110,7 +109,7 @@ public class ResourceHandlerFunctionTests {
MockServerWebExchange exchange = MockServerHttpRequest.options("http://localhost").toExchange();
MockServerHttpResponse mockResponse = exchange.getResponse();
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults());
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults().messageReaders());
Mono<ServerResponse> responseMono = this.handlerFunction.handle(request);
Mono<Void> result = responseMono.flatMap(response -> {

View File

@ -17,6 +17,7 @@
package org.springframework.web.reactive.function.server;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Test;
import reactor.core.publisher.Mono;
@ -29,6 +30,8 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
@ -245,4 +248,35 @@ public class RouterFunctionsTests {
assertEquals(HttpStatus.NOT_FOUND, httpResponse.getStatusCode());
}
@Test
public void toHttpHandlerWebFilter() throws Exception {
AtomicBoolean filterInvoked = new AtomicBoolean();
WebFilter webFilter = new WebFilter() {
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
filterInvoked.set(true);
return chain.filter(exchange);
}
};
HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.accepted().build();
RouterFunction<ServerResponse> routerFunction =
RouterFunctions.route(RequestPredicates.all(), handlerFunction);
HandlerStrategies handlerStrategies = HandlerStrategies.builder()
.webFilter(webFilter).build();
HttpHandler result = RouterFunctions.toHttpHandler(routerFunction, handlerStrategies);
assertNotNull(result);
MockServerHttpRequest httpRequest = MockServerHttpRequest.get("http://localhost").build();
MockServerHttpResponse httpResponse = new MockServerHttpResponse();
result.handle(httpRequest, httpResponse).block();
assertEquals(HttpStatus.ACCEPTED, httpResponse.getStatusCode());
assertTrue(filterInvoked.get());
}
}