Revert errorhandler order in RouterFunctionBuilder

Prior to this commit, error handlers in the WebMvc.fn and WebFlux.fn
router function builders had to be registered in an unintuitive, reverse
order, due to the filter chain composition model used.
This commit reverses the error handler order, so that more specific
error handlers can come before generic ones.

Closes gh-25541
This commit is contained in:
Arjen Poutsma 2020-09-15 15:43:25 +02:00
parent 200b33b26a
commit 392895e256
4 changed files with 82 additions and 37 deletions

View File

@ -23,6 +23,7 @@ import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Stream;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@ -43,6 +44,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
private final List<HandlerFilterFunction<ServerResponse, ServerResponse>> filterFunctions = new ArrayList<>();
private final List<HandlerFilterFunction<ServerResponse, ServerResponse>> errorHandlers = new ArrayList<>();
@Override
public RouterFunctions.Builder add(RouterFunction<ServerResponse> routerFunction) {
@ -310,8 +313,9 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
Assert.notNull(predicate, "Predicate must not be null");
Assert.notNull(responseProvider, "ResponseProvider must not be null");
return filter((request, next) -> next.handle(request)
this.errorHandlers.add(0, (request, next) -> next.handle(request)
.onErrorResume(predicate, t -> responseProvider.apply(t, request)));
return this;
}
@Override
@ -321,8 +325,9 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
Assert.notNull(exceptionType, "ExceptionType must not be null");
Assert.notNull(responseProvider, "ResponseProvider must not be null");
return filter((request, next) -> next.handle(request)
this.errorHandlers.add(0, (request, next) -> next.handle(request)
.onErrorResume(exceptionType, t -> responseProvider.apply(t, request)));
return this;
}
@Override
@ -332,12 +337,12 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
}
RouterFunction<ServerResponse> result = new BuiltRouterFunction(this.routerFunctions);
if (this.filterFunctions.isEmpty()) {
if (this.filterFunctions.isEmpty() && this.errorHandlers.isEmpty()) {
return result;
}
else {
HandlerFilterFunction<ServerResponse, ServerResponse> filter =
this.filterFunctions.stream()
Stream.concat(this.filterFunctions.stream(), this.errorHandlers.stream())
.reduce(HandlerFilterFunction::andThen)
.orElseThrow(IllegalStateException::new);

View File

@ -16,6 +16,7 @@
package org.springframework.web.reactive.function.server;
import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.atomic.AtomicInteger;
@ -210,4 +211,26 @@ public class RouterFunctionBuilderTests {
.verifyComplete();
}
@Test
public void multipleOnErrors() {
RouterFunction<ServerResponse> route = RouterFunctions.route()
.GET("/error", request -> Mono.error(new IOException()))
.onError(IOException.class, (t, r) -> ServerResponse.status(200).build())
.onError(Exception.class, (t, r) -> ServerResponse.status(201).build())
.build();
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com/error").build();
ServerRequest serverRequest = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
Mono<HttpStatus> responseStatus = route.route(serverRequest)
.flatMap(handlerFunction -> handlerFunction.handle(serverRequest))
.map(ServerResponse::statusCode);
StepVerifier.create(responseStatus)
.assertNext(status -> assertThat(status).isEqualTo(HttpStatus.OK))
.verifyComplete();
}
}

View File

@ -24,6 +24,7 @@ import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpMethod;
@ -41,6 +42,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
private final List<HandlerFilterFunction<ServerResponse, ServerResponse>> filterFunctions = new ArrayList<>();
private final List<HandlerFilterFunction<ServerResponse, ServerResponse>> errorHandlers = new ArrayList<>();
@Override
public RouterFunctions.Builder add(RouterFunction<ServerResponse> routerFunction) {
@ -307,7 +310,8 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
Assert.notNull(predicate, "Predicate must not be null");
Assert.notNull(responseProvider, "ResponseProvider must not be null");
return filter(HandlerFilterFunction.ofErrorHandler(predicate, responseProvider));
this.errorHandlers.add(0, HandlerFilterFunction.ofErrorHandler(predicate, responseProvider));
return this;
}
@Override
@ -316,8 +320,7 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
Assert.notNull(exceptionType, "ExceptionType must not be null");
Assert.notNull(responseProvider, "ResponseProvider must not be null");
return filter(HandlerFilterFunction.ofErrorHandler(exceptionType::isInstance,
responseProvider));
return onError(exceptionType::isInstance, responseProvider);
}
@Override
@ -327,12 +330,12 @@ class RouterFunctionBuilder implements RouterFunctions.Builder {
}
RouterFunction<ServerResponse> result = new BuiltRouterFunction(this.routerFunctions);
if (this.filterFunctions.isEmpty()) {
if (this.filterFunctions.isEmpty() && this.errorHandlers.isEmpty()) {
return result;
}
else {
HandlerFilterFunction<ServerResponse, ServerResponse> filter =
this.filterFunctions.stream()
Stream.concat(this.filterFunctions.stream(), this.errorHandlers.stream())
.reduce(HandlerFilterFunction::andThen)
.orElseThrow(IllegalStateException::new);

View File

@ -16,6 +16,7 @@
package org.springframework.web.servlet.function;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
@ -52,36 +53,32 @@ class RouterFunctionBuilderTests {
ServerRequest getFooRequest = initRequest("GET", "/foo");
Optional<Integer> responseStatus = route.route(getFooRequest)
Optional<HttpStatus> responseStatus = route.route(getFooRequest)
.map(handlerFunction -> handle(handlerFunction, getFooRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(200);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.OK);
ServerRequest headFooRequest = initRequest("HEAD", "/foo");
responseStatus = route.route(headFooRequest)
.map(handlerFunction -> handle(handlerFunction, getFooRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(202);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.ACCEPTED);
ServerRequest barRequest = initRequest("POST", "/", req -> req.setContentType("text/plain"));
responseStatus = route.route(barRequest)
.map(handlerFunction -> handle(handlerFunction, barRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(204);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.NO_CONTENT);
ServerRequest invalidRequest = initRequest("POST", "/");
responseStatus = route.route(invalidRequest)
.map(handlerFunction -> handle(handlerFunction, invalidRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
.map(ServerResponse::statusCode);
assertThat(responseStatus.isPresent()).isFalse();
assertThat(responseStatus).isEmpty();
}
private static ServerResponse handle(HandlerFunction<ServerResponse> handlerFunction,
@ -105,19 +102,17 @@ class RouterFunctionBuilderTests {
ServerRequest resourceRequest = initRequest("GET", "/resources/response.txt");
Optional<Integer> responseStatus = route.route(resourceRequest)
Optional<HttpStatus> responseStatus = route.route(resourceRequest)
.map(handlerFunction -> handle(handlerFunction, resourceRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(200);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.OK);
ServerRequest invalidRequest = initRequest("POST", "/resources/foo.txt");
responseStatus = route.route(invalidRequest)
.map(handlerFunction -> handle(handlerFunction, invalidRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.isPresent()).isFalse();
.map(ServerResponse::statusCode);
assertThat(responseStatus).isEmpty();
}
@Test
@ -132,11 +127,10 @@ class RouterFunctionBuilderTests {
ServerRequest fooRequest = initRequest("GET", "/foo/bar/baz");
Optional<Integer> responseStatus = route.route(fooRequest)
Optional<HttpStatus> responseStatus = route.route(fooRequest)
.map(handlerFunction -> handle(handlerFunction, fooRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(200);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.OK);
}
@Test
@ -181,13 +175,33 @@ class RouterFunctionBuilderTests {
ServerRequest barRequest = initRequest("GET", "/bar");
Optional<Integer> responseStatus = route.route(barRequest)
Optional<HttpStatus> responseStatus = route.route(barRequest)
.map(handlerFunction -> handle(handlerFunction, barRequest))
.map(ServerResponse::statusCode)
.map(HttpStatus::value);
assertThat(responseStatus.get().intValue()).isEqualTo(500);
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.INTERNAL_SERVER_ERROR);
}
@Test
public void multipleOnErrors() {
RouterFunction<ServerResponse> route = RouterFunctions.route()
.GET("/error", request -> {
throw new IOException();
})
.onError(IOException.class, (t, r) -> ServerResponse.status(200).build())
.onError(Exception.class, (t, r) -> ServerResponse.status(201).build())
.build();
MockHttpServletRequest servletRequest = new MockHttpServletRequest("GET", "/error");
ServerRequest serverRequest = new DefaultServerRequest(servletRequest, emptyList());
Optional<HttpStatus> responseStatus = route.route(serverRequest)
.map(handlerFunction -> handle(handlerFunction, serverRequest))
.map(ServerResponse::statusCode);
assertThat(responseStatus).contains(HttpStatus.OK);
}
private ServerRequest initRequest(String httpMethod, String requestUri) {
return initRequest(httpMethod, requestUri, null);