Add support for WebFilter and WebExceptionHandler
This commit adds support for configuring `WebFilter` and `WebExceptionHandler` instances in HandlerStrategies. It also drops the "native" support for `ResponseStatusException`s, in favor of the `ResponseStatusExceptionHandler`, which is registered by default. Issue: SPR-15518
This commit is contained in:
parent
ad9cf99420
commit
f4cf55cb2b
|
|
@ -32,6 +32,9 @@ import org.springframework.http.codec.HttpMessageWriter;
|
|||
import org.springframework.http.codec.ServerCodecConfigurer;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.reactive.result.view.ViewResolver;
|
||||
import org.springframework.web.server.WebExceptionHandler;
|
||||
import org.springframework.web.server.WebFilter;
|
||||
import org.springframework.web.server.handler.ResponseStatusExceptionHandler;
|
||||
|
||||
/**
|
||||
* Default implementation of {@link HandlerStrategies.Builder}.
|
||||
|
|
@ -53,6 +56,9 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder {
|
|||
|
||||
private Function<ServerRequest, Optional<Locale>> localeResolver;
|
||||
|
||||
private final List<WebFilter> webFilters = new ArrayList<>();
|
||||
|
||||
private final List<WebExceptionHandler> exceptionHandlers = new ArrayList<>();
|
||||
|
||||
|
||||
public DefaultHandlerStrategiesBuilder() {
|
||||
|
|
@ -62,6 +68,7 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder {
|
|||
public void defaultConfiguration() {
|
||||
this.codecConfigurer.registerDefaults(true);
|
||||
localeResolver(DEFAULT_LOCALE_RESOLVER);
|
||||
exceptionHandler(new ResponseStatusExceptionHandler());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -94,10 +101,25 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder {
|
|||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public HandlerStrategies.Builder webFilter(WebFilter filter) {
|
||||
Assert.notNull(filter, "'filter' must not be null");
|
||||
this.webFilters.add(filter);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public HandlerStrategies.Builder exceptionHandler(WebExceptionHandler exceptionHandler) {
|
||||
Assert.notNull(exceptionHandler, "'exceptionHandler' must not be null");
|
||||
this.exceptionHandlers.add(exceptionHandler);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public HandlerStrategies build() {
|
||||
return new DefaultHandlerStrategies(this.codecConfigurer.getReaders(),
|
||||
this.codecConfigurer.getWriters(), this.viewResolvers, this.localeResolver);
|
||||
this.codecConfigurer.getWriters(), this.viewResolvers, this.localeResolver,
|
||||
this.webFilters, this.exceptionHandlers);
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -111,16 +133,24 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder {
|
|||
|
||||
private final Function<ServerRequest, Optional<Locale>> localeResolver;
|
||||
|
||||
private final List<WebFilter> webFilters;
|
||||
|
||||
private final List<WebExceptionHandler> exceptionHandlers;
|
||||
|
||||
public DefaultHandlerStrategies(
|
||||
List<HttpMessageReader<?>> messageReaders,
|
||||
List<HttpMessageWriter<?>> messageWriters,
|
||||
List<ViewResolver> viewResolvers,
|
||||
Function<ServerRequest, Optional<Locale>> localeResolver) {
|
||||
Function<ServerRequest, Optional<Locale>> localeResolver,
|
||||
List<WebFilter> webFilters,
|
||||
List<WebExceptionHandler> exceptionHandlers) {
|
||||
|
||||
this.messageReaders = unmodifiableCopy(messageReaders);
|
||||
this.messageWriters = unmodifiableCopy(messageWriters);
|
||||
this.viewResolvers = unmodifiableCopy(viewResolvers);
|
||||
this.localeResolver = localeResolver;
|
||||
this.webFilters = unmodifiableCopy(webFilters);
|
||||
this.exceptionHandlers = unmodifiableCopy(exceptionHandlers);
|
||||
}
|
||||
|
||||
private static <T> List<T> unmodifiableCopy(List<? extends T> list) {
|
||||
|
|
@ -146,6 +176,16 @@ class DefaultHandlerStrategiesBuilder implements HandlerStrategies.Builder {
|
|||
public Supplier<Function<ServerRequest, Optional<Locale>>> localeResolver() {
|
||||
return () -> this.localeResolver;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Supplier<Stream<WebFilter>> webFilters() {
|
||||
return this.webFilters::stream;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Supplier<Stream<WebExceptionHandler>> exceptionHandlers() {
|
||||
return this.exceptionHandlers::stream;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -65,12 +65,13 @@ class DefaultServerRequest implements ServerRequest {
|
|||
|
||||
private final Headers headers;
|
||||
|
||||
private final HandlerStrategies strategies;
|
||||
private final Supplier<Stream<HttpMessageReader<?>>> messageReaders;
|
||||
|
||||
|
||||
DefaultServerRequest(ServerWebExchange exchange, HandlerStrategies strategies) {
|
||||
DefaultServerRequest(ServerWebExchange exchange,
|
||||
Supplier<Stream<HttpMessageReader<?>>> messageReaders) {
|
||||
this.exchange = exchange;
|
||||
this.strategies = strategies;
|
||||
this.messageReaders = messageReaders;
|
||||
this.headers = new DefaultHeaders();
|
||||
}
|
||||
|
||||
|
|
@ -102,7 +103,7 @@ class DefaultServerRequest implements ServerRequest {
|
|||
new BodyExtractor.Context() {
|
||||
@Override
|
||||
public Supplier<Stream<HttpMessageReader<?>>> messageReaders() {
|
||||
return DefaultServerRequest.this.strategies.messageReaders();
|
||||
return DefaultServerRequest.this.messageReaders;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ import org.springframework.http.codec.HttpMessageReader;
|
|||
import org.springframework.http.codec.HttpMessageWriter;
|
||||
import org.springframework.http.codec.ServerCodecConfigurer;
|
||||
import org.springframework.web.reactive.result.view.ViewResolver;
|
||||
import org.springframework.web.server.WebExceptionHandler;
|
||||
import org.springframework.web.server.WebFilter;
|
||||
|
||||
/**
|
||||
* Defines the strategies to be used for processing {@link HandlerFunction}s. An instance of
|
||||
|
|
@ -71,6 +73,20 @@ public interface HandlerStrategies {
|
|||
*/
|
||||
Supplier<Function<ServerRequest, Optional<Locale>>> localeResolver();
|
||||
|
||||
/**
|
||||
* Supply a {@linkplain Stream stream} of {@link WebFilter}s to be used for filtering the
|
||||
* request and response.
|
||||
* @return the stream of web filters
|
||||
*/
|
||||
Supplier<Stream<WebFilter>> webFilters();
|
||||
|
||||
/**
|
||||
* Supply a {@linkplain Stream stream} of {@link WebExceptionHandler}s to be used for handling
|
||||
* exceptions.
|
||||
* @return the stream of exception handlers
|
||||
*/
|
||||
Supplier<Stream<WebExceptionHandler>> exceptionHandlers();
|
||||
|
||||
|
||||
// Static methods
|
||||
|
||||
|
|
@ -138,6 +154,20 @@ public interface HandlerStrategies {
|
|||
*/
|
||||
Builder localeResolver(Function<ServerRequest, Optional<Locale>> localeResolver);
|
||||
|
||||
/**
|
||||
* Add the given web filter to this builder.
|
||||
* @param filter the filter to add
|
||||
* @return this builder
|
||||
*/
|
||||
Builder webFilter(WebFilter filter);
|
||||
|
||||
/**
|
||||
* Add the given exception handler to this builder.
|
||||
* @param exceptionHandler the exception handler to add
|
||||
* @return this builder
|
||||
*/
|
||||
Builder exceptionHandler(WebExceptionHandler exceptionHandler);
|
||||
|
||||
/**
|
||||
* Builds the {@link HandlerStrategies}.
|
||||
* @return the built strategies
|
||||
|
|
|
|||
|
|
@ -30,10 +30,8 @@ import org.springframework.util.Assert;
|
|||
import org.springframework.web.reactive.HandlerMapping;
|
||||
import org.springframework.web.reactive.function.server.support.HandlerFunctionAdapter;
|
||||
import org.springframework.web.reactive.function.server.support.ServerResponseResultHandler;
|
||||
import org.springframework.web.server.ResponseStatusException;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
import org.springframework.web.server.WebHandler;
|
||||
import org.springframework.web.server.adapter.HttpWebHandlerAdapter;
|
||||
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
|
||||
|
||||
/**
|
||||
|
|
@ -197,7 +195,7 @@ public abstract class RouterFunctions {
|
|||
* @param routerFunction the router function to convert
|
||||
* @return an http handler that handles HTTP request using the given router function
|
||||
*/
|
||||
public static HttpWebHandlerAdapter toHttpHandler(RouterFunction<?> routerFunction) {
|
||||
public static HttpHandler toHttpHandler(RouterFunction<?> routerFunction) {
|
||||
return toHttpHandler(routerFunction, HandlerStrategies.withDefaults());
|
||||
}
|
||||
|
||||
|
|
@ -213,32 +211,27 @@ public abstract class RouterFunctions {
|
|||
* <li>Undertow using the
|
||||
* {@link org.springframework.http.server.reactive.UndertowHttpHandlerAdapter}.</li>
|
||||
* </ul>
|
||||
* <p>Note that {@code HttpWebHandlerAdapter} also implements {@link WebHandler}, allowing
|
||||
* for additional filter and exception handler registration through
|
||||
* @param routerFunction the router function to convert
|
||||
* @param strategies the strategies to use
|
||||
* @return an http handler that handles HTTP request using the given router function
|
||||
*/
|
||||
public static HttpWebHandlerAdapter toHttpHandler(RouterFunction<?> routerFunction, HandlerStrategies strategies) {
|
||||
public static HttpHandler toHttpHandler(RouterFunction<?> routerFunction, HandlerStrategies strategies) {
|
||||
Assert.notNull(routerFunction, "RouterFunction must not be null");
|
||||
Assert.notNull(strategies, "HandlerStrategies must not be null");
|
||||
|
||||
return new HttpWebHandlerAdapter(exchange -> {
|
||||
ServerRequest request = new DefaultServerRequest(exchange, strategies);
|
||||
WebHandler webHandler = exchange -> {
|
||||
ServerRequest request = new DefaultServerRequest(exchange, strategies.messageReaders());
|
||||
addAttributes(exchange, request);
|
||||
return routerFunction.route(request)
|
||||
.defaultIfEmpty(notFound())
|
||||
.flatMap(handlerFunction -> wrapException(() -> handlerFunction.handle(request)))
|
||||
.flatMap(response -> wrapException(() -> response.writeTo(exchange, strategies)))
|
||||
.onErrorResume(ResponseStatusException.class,
|
||||
ex -> {
|
||||
exchange.getResponse().setStatusCode(ex.getStatus());
|
||||
if (ex.getMessage() != null) {
|
||||
logger.error(ex.getMessage());
|
||||
}
|
||||
return Mono.empty();
|
||||
});
|
||||
});
|
||||
.flatMap(response -> wrapException(() -> response.writeTo(exchange, strategies)));
|
||||
};
|
||||
|
||||
WebHttpHandlerBuilder handlerBuilder = WebHttpHandlerBuilder.webHandler(webHandler);
|
||||
strategies.webFilters().get().forEach(handlerBuilder::filter);
|
||||
strategies.exceptionHandlers().get().forEach(handlerBuilder::exceptionHandler);
|
||||
return handlerBuilder.build();
|
||||
}
|
||||
|
||||
private static <T> Mono<T> wrapException(Supplier<Mono<T>> supplier) {
|
||||
|
|
@ -280,7 +273,7 @@ public abstract class RouterFunctions {
|
|||
Assert.notNull(strategies, "HandlerStrategies must not be null");
|
||||
|
||||
return exchange -> {
|
||||
ServerRequest request = new DefaultServerRequest(exchange, strategies);
|
||||
ServerRequest request = new DefaultServerRequest(exchange, strategies.messageReaders());
|
||||
addAttributes(exchange, request);
|
||||
return routerFunction.route(request).map(handlerFunction -> (Object)handlerFunction);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.OptionalLong;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
|
@ -52,9 +53,8 @@ import org.springframework.web.server.ServerWebExchange;
|
|||
import org.springframework.web.server.UnsupportedMediaTypeStatusException;
|
||||
import org.springframework.web.server.WebSession;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.junit.Assert.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.springframework.web.reactive.function.BodyExtractors.toMono;
|
||||
|
||||
/**
|
||||
|
|
@ -64,11 +64,9 @@ public class DefaultServerRequestTests {
|
|||
|
||||
private ServerHttpRequest mockRequest;
|
||||
|
||||
private ServerHttpResponse mockResponse;
|
||||
|
||||
private ServerWebExchange mockExchange;
|
||||
|
||||
private HandlerStrategies mockHandlerStrategies;
|
||||
Supplier<Stream<HttpMessageReader<?>>> messageReaders;
|
||||
|
||||
private DefaultServerRequest defaultRequest;
|
||||
|
||||
|
|
@ -76,14 +74,15 @@ public class DefaultServerRequestTests {
|
|||
@Before
|
||||
public void createMocks() {
|
||||
mockRequest = mock(ServerHttpRequest.class);
|
||||
mockResponse = mock(ServerHttpResponse.class);
|
||||
ServerHttpResponse mockResponse = mock(ServerHttpResponse.class);
|
||||
|
||||
mockExchange = mock(ServerWebExchange.class);
|
||||
when(mockExchange.getRequest()).thenReturn(mockRequest);
|
||||
when(mockExchange.getResponse()).thenReturn(mockResponse);
|
||||
mockHandlerStrategies = mock(HandlerStrategies.class);
|
||||
|
||||
defaultRequest = new DefaultServerRequest(mockExchange, mockHandlerStrategies);
|
||||
this.messageReaders = Collections.<HttpMessageReader<?>>singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true)))::stream;
|
||||
|
||||
defaultRequest = new DefaultServerRequest(mockExchange, messageReaders);
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -190,10 +189,6 @@ public class DefaultServerRequestTests {
|
|||
when(mockRequest.getHeaders()).thenReturn(httpHeaders);
|
||||
when(mockRequest.getBody()).thenReturn(body);
|
||||
|
||||
Set<HttpMessageReader<?>> messageReaders = Collections
|
||||
.singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true)));
|
||||
when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream);
|
||||
|
||||
Mono<String> resultMono = defaultRequest.body(toMono(String.class));
|
||||
assertEquals("foo", resultMono.block());
|
||||
}
|
||||
|
|
@ -210,10 +205,6 @@ public class DefaultServerRequestTests {
|
|||
when(mockRequest.getHeaders()).thenReturn(httpHeaders);
|
||||
when(mockRequest.getBody()).thenReturn(body);
|
||||
|
||||
Set<HttpMessageReader<?>> messageReaders = Collections
|
||||
.singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true)));
|
||||
when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream);
|
||||
|
||||
Mono<String> resultMono = defaultRequest.bodyToMono(String.class);
|
||||
assertEquals("foo", resultMono.block());
|
||||
}
|
||||
|
|
@ -230,10 +221,6 @@ public class DefaultServerRequestTests {
|
|||
when(mockRequest.getHeaders()).thenReturn(httpHeaders);
|
||||
when(mockRequest.getBody()).thenReturn(body);
|
||||
|
||||
Set<HttpMessageReader<?>> messageReaders = Collections
|
||||
.singleton(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true)));
|
||||
when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream);
|
||||
|
||||
Flux<String> resultFlux = defaultRequest.bodyToFlux(String.class);
|
||||
Mono<List<String>> result = resultFlux.collectList();
|
||||
assertEquals(Collections.singletonList("foo"), result.block());
|
||||
|
|
@ -251,8 +238,8 @@ public class DefaultServerRequestTests {
|
|||
when(mockRequest.getHeaders()).thenReturn(httpHeaders);
|
||||
when(mockRequest.getBody()).thenReturn(body);
|
||||
|
||||
Set<HttpMessageReader<?>> messageReaders = Collections.emptySet();
|
||||
when(mockHandlerStrategies.messageReaders()).thenReturn(messageReaders::stream);
|
||||
this.messageReaders = Collections.<HttpMessageReader<?>>emptySet()::stream;
|
||||
this.defaultRequest = new DefaultServerRequest(mockExchange, messageReaders);
|
||||
|
||||
Flux<String> resultFlux = defaultRequest.bodyToFlux(String.class);
|
||||
StepVerifier.create(resultFlux)
|
||||
|
|
|
|||
|
|
@ -33,8 +33,7 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
|
|||
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
|
||||
import org.springframework.mock.http.server.reactive.test.MockServerWebExchange;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* @author Arjen Poutsma
|
||||
|
|
@ -51,7 +50,7 @@ public class ResourceHandlerFunctionTests {
|
|||
MockServerWebExchange exchange = MockServerHttpRequest.get("http://localhost").toExchange();
|
||||
MockServerHttpResponse mockResponse = exchange.getResponse();
|
||||
|
||||
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults());
|
||||
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults().messageReaders());
|
||||
|
||||
Mono<ServerResponse> responseMono = this.handlerFunction.handle(request);
|
||||
|
||||
|
|
@ -86,7 +85,7 @@ public class ResourceHandlerFunctionTests {
|
|||
MockServerWebExchange exchange = MockServerHttpRequest.head("http://localhost").toExchange();
|
||||
MockServerHttpResponse mockResponse = exchange.getResponse();
|
||||
|
||||
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults());
|
||||
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults().messageReaders());
|
||||
|
||||
Mono<ServerResponse> responseMono = this.handlerFunction.handle(request);
|
||||
|
||||
|
|
@ -110,7 +109,7 @@ public class ResourceHandlerFunctionTests {
|
|||
MockServerWebExchange exchange = MockServerHttpRequest.options("http://localhost").toExchange();
|
||||
MockServerHttpResponse mockResponse = exchange.getResponse();
|
||||
|
||||
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults());
|
||||
ServerRequest request = new DefaultServerRequest(exchange, HandlerStrategies.withDefaults().messageReaders());
|
||||
|
||||
Mono<ServerResponse> responseMono = this.handlerFunction.handle(request);
|
||||
Mono<Void> result = responseMono.flatMap(response -> {
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
package org.springframework.web.reactive.function.server;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import org.junit.Test;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
|
@ -29,6 +30,8 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
|
|||
import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
|
||||
import org.springframework.web.server.ResponseStatusException;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
import org.springframework.web.server.WebFilter;
|
||||
import org.springframework.web.server.WebFilterChain;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
|
@ -245,4 +248,35 @@ public class RouterFunctionsTests {
|
|||
assertEquals(HttpStatus.NOT_FOUND, httpResponse.getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void toHttpHandlerWebFilter() throws Exception {
|
||||
AtomicBoolean filterInvoked = new AtomicBoolean();
|
||||
|
||||
WebFilter webFilter = new WebFilter() {
|
||||
@Override
|
||||
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
|
||||
filterInvoked.set(true);
|
||||
return chain.filter(exchange);
|
||||
}
|
||||
};
|
||||
|
||||
HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.accepted().build();
|
||||
RouterFunction<ServerResponse> routerFunction =
|
||||
RouterFunctions.route(RequestPredicates.all(), handlerFunction);
|
||||
|
||||
HandlerStrategies handlerStrategies = HandlerStrategies.builder()
|
||||
.webFilter(webFilter).build();
|
||||
|
||||
HttpHandler result = RouterFunctions.toHttpHandler(routerFunction, handlerStrategies);
|
||||
assertNotNull(result);
|
||||
|
||||
MockServerHttpRequest httpRequest = MockServerHttpRequest.get("http://localhost").build();
|
||||
MockServerHttpResponse httpResponse = new MockServerHttpResponse();
|
||||
result.handle(httpRequest, httpResponse).block();
|
||||
assertEquals(HttpStatus.ACCEPTED, httpResponse.getStatusCode());
|
||||
|
||||
assertTrue(filterInvoked.get());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue