diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java index 882f3efea58..2fee42abd1f 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java @@ -18,6 +18,7 @@ package org.springframework.web.reactive.function.server; import java.util.Map; import java.util.function.Function; +import java.util.function.Supplier; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -27,8 +28,7 @@ import org.springframework.core.io.Resource; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.util.Assert; import org.springframework.web.reactive.HandlerMapping; -import org.springframework.web.reactive.function.server.support.HandlerFunctionAdapter; -import org.springframework.web.reactive.function.server.support.ServerResponseResultHandler; +import org.springframework.web.reactive.function.server.support.*; import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.WebHandler; @@ -233,23 +233,25 @@ public abstract class RouterFunctions { addAttributes(exchange, request); return routerFunction.route(request) .defaultIfEmpty(notFound()) - .then(handlerFunction -> invokeHandler(handlerFunction, request)) - .otherwise(ResponseStatusException.class, RouterFunctions::responseStatusFallback) - .then(response -> response.writeTo(exchange, strategies)); + .then(handlerFunction -> wrapException(() -> handlerFunction.handle(request))) + .then(response -> wrapException(() -> response.writeTo(exchange, strategies))) + .otherwise(ResponseStatusException.class, + ex -> { + exchange.getResponse().setStatusCode(ex.getStatus()); + return Mono.empty(); + }); }); } - private static Mono invokeHandler(HandlerFunction handlerFunction, - ServerRequest request) { + private static Mono wrapException(Supplier> supplier) { try { - return handlerFunction.handle(request); + return supplier.get(); } catch (Throwable t) { return Mono.error(t); } } - /** * Convert the given {@code RouterFunction} into a {@code HandlerMapping}. * This conversion uses {@linkplain HandlerStrategies#builder() default strategies}. @@ -286,7 +288,6 @@ public abstract class RouterFunctions { }; } - private static void addAttributes(ServerWebExchange exchange, ServerRequest request) { Map attributes = exchange.getAttributes(); attributes.put(REQUEST_ATTRIBUTE, request); @@ -297,11 +298,6 @@ public abstract class RouterFunctions { return (HandlerFunction) NOT_FOUND_HANDLER; } - @SuppressWarnings("unchecked") - private static Mono responseStatusFallback(ResponseStatusException ex) { - return (Mono) ServerResponse.status(ex.getStatus()).build(); - } - @SuppressWarnings("unchecked") static HandlerFunction cast(HandlerFunction handlerFunction) { return (HandlerFunction) handlerFunction; diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/PublisherHandlerFunctionIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/PublisherHandlerFunctionIntegrationTests.java index b3b3b14a210..84f95d6698e 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/PublisherHandlerFunctionIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/PublisherHandlerFunctionIntegrationTests.java @@ -16,11 +16,9 @@ package org.springframework.web.reactive.function.server; -import java.io.IOException; import java.net.URI; import java.util.List; -import org.junit.Before; import org.junit.Test; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -30,41 +28,20 @@ import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; -import org.springframework.http.client.ClientHttpResponse; -import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestTemplate; -import org.springframework.web.server.ResponseStatusException; -import static org.junit.Assert.assertEquals; -import static org.springframework.web.reactive.function.BodyExtractors.toMono; -import static org.springframework.web.reactive.function.BodyInserters.fromPublisher; -import static org.springframework.web.reactive.function.server.RequestPredicates.GET; -import static org.springframework.web.reactive.function.server.RequestPredicates.POST; -import static org.springframework.web.reactive.function.server.RouterFunctions.route; +import static org.junit.Assert.*; +import static org.springframework.web.reactive.function.BodyExtractors.*; +import static org.springframework.web.reactive.function.BodyInserters.*; +import static org.springframework.web.reactive.function.server.RequestPredicates.*; +import static org.springframework.web.reactive.function.server.RouterFunctions.*; /** * @author Arjen Poutsma */ public class PublisherHandlerFunctionIntegrationTests extends AbstractRouterFunctionIntegrationTests { - private RestTemplate restTemplate; - - @Before - public void createRestTemplate() { - restTemplate = new RestTemplate(); - restTemplate.setErrorHandler(new ResponseErrorHandler() { - @Override - public boolean hasError(ClientHttpResponse response) throws IOException { - return false; - } - - @Override - public void handleError(ClientHttpResponse response) throws IOException { - - } - }); - - } + private final RestTemplate restTemplate = new RestTemplate(); @Override @@ -72,9 +49,7 @@ public class PublisherHandlerFunctionIntegrationTests extends AbstractRouterFunc PersonHandler personHandler = new PersonHandler(); return route(GET("/mono"), personHandler::mono) .and(route(POST("/mono"), personHandler::postMono)) - .and(route(GET("/flux"), personHandler::flux)) - .and(route(GET("/throwRSE"), personHandler::throwResponseStatusException)) - .and(route(GET("/returnRSE"), personHandler::returnResponseStatusException)); + .and(route(GET("/flux"), personHandler::flux)); } @@ -111,19 +86,6 @@ public class PublisherHandlerFunctionIntegrationTests extends AbstractRouterFunc assertEquals("Jack", result.getBody().getName()); } - @Test - public void responseStatusException() { - ResponseEntity result = - restTemplate.getForEntity("http://localhost:" + port + "/throwRSE", String.class); - - assertEquals(HttpStatus.BAD_REQUEST, result.getStatusCode()); - - result = restTemplate.getForEntity("http://localhost:" + port + "/returnRSE", String.class); - - assertEquals(HttpStatus.BAD_REQUEST, result.getStatusCode()); - } - - private static class PersonHandler { @@ -143,14 +105,6 @@ public class PublisherHandlerFunctionIntegrationTests extends AbstractRouterFunc return ServerResponse.ok().body( fromPublisher(Flux.just(person1, person2), Person.class)); } - - public Mono throwResponseStatusException(ServerRequest request) { - throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Bad Request"); - } - - public Mono returnResponseStatusException(ServerRequest request) { - return Mono.error(new ResponseStatusException(HttpStatus.BAD_REQUEST, "Bad Request")); - } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionsTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionsTests.java index 146105d890d..e2f62ec14cb 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionsTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RouterFunctionsTests.java @@ -16,18 +16,16 @@ package org.springframework.web.reactive.function.server; -import java.util.stream.Stream; - import org.junit.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import org.springframework.http.codec.HttpMessageReader; -import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse; -import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ServerWebExchange; import static org.junit.Assert.*; @@ -114,34 +112,132 @@ public class RouterFunctionsTests { } @Test - public void toHttpHandler() throws Exception { - HandlerStrategies strategies = mock(HandlerStrategies.class); - when(strategies.messageReaders()).thenReturn( - Stream::>empty); - when(strategies.messageWriters()).thenReturn( - Stream::>empty); - when(strategies.viewResolvers()).thenReturn( - Stream::empty); + public void toHttpHandlerNormal() throws Exception { + HandlerFunction handlerFunction = request -> ServerResponse.accepted().build(); + RouterFunction routerFunction = + RouterFunctions.route(RequestPredicates.all(), handlerFunction); - ServerRequest request = mock(ServerRequest.class); - ServerResponse response = mock(ServerResponse.class); - when(response.writeTo(any(ServerWebExchange.class), eq(strategies))).thenReturn(Mono.empty()); - - HandlerFunction handlerFunction = mock(HandlerFunction.class); - when(handlerFunction.handle(any(ServerRequest.class))).thenReturn(Mono.just(response)); - - RouterFunction routerFunction = mock(RouterFunction.class); - when(routerFunction.route(any(ServerRequest.class))).thenReturn(Mono.just(handlerFunction)); - - RequestPredicate requestPredicate = mock(RequestPredicate.class); - when(requestPredicate.test(request)).thenReturn(false); - - HttpHandler result = RouterFunctions.toHttpHandler(routerFunction, strategies); + HttpHandler result = RouterFunctions.toHttpHandler(routerFunction); assertNotNull(result); MockServerHttpRequest httpRequest = MockServerHttpRequest.get("http://localhost").build(); - MockServerHttpResponse serverHttpResponse = new MockServerHttpResponse(); - result.handle(httpRequest, serverHttpResponse); + MockServerHttpResponse httpResponse = new MockServerHttpResponse(); + result.handle(httpRequest, httpResponse).block(); + assertEquals(HttpStatus.ACCEPTED, httpResponse.getStatusCode()); + } + + @Test + public void toHttpHandlerHandlerThrowsException() throws Exception { + HandlerFunction handlerFunction = + request -> { + throw new IllegalStateException(); + }; + RouterFunction routerFunction = + RouterFunctions.route(RequestPredicates.all(), handlerFunction); + + HttpHandler result = RouterFunctions.toHttpHandler(routerFunction); + assertNotNull(result); + + MockServerHttpRequest httpRequest = MockServerHttpRequest.get("http://localhost").build(); + MockServerHttpResponse httpResponse = new MockServerHttpResponse(); + result.handle(httpRequest, httpResponse).block(); + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, httpResponse.getStatusCode()); + } + + @Test + public void toHttpHandlerHandlerReturnsException() throws Exception { + HandlerFunction handlerFunction = + request -> Mono.error(new IllegalStateException()); + RouterFunction routerFunction = + RouterFunctions.route(RequestPredicates.all(), handlerFunction); + + HttpHandler result = RouterFunctions.toHttpHandler(routerFunction); + assertNotNull(result); + + MockServerHttpRequest httpRequest = MockServerHttpRequest.get("http://localhost").build(); + MockServerHttpResponse httpResponse = new MockServerHttpResponse(); + result.handle(httpRequest, httpResponse).block(); + assertEquals(HttpStatus.INTERNAL_SERVER_ERROR, httpResponse.getStatusCode()); + } + + @Test + public void toHttpHandlerHandlerResponseStatusException() throws Exception { + HandlerFunction handlerFunction = + request -> Mono.error(new ResponseStatusException(HttpStatus.NOT_FOUND, "Not found")); + RouterFunction routerFunction = + RouterFunctions.route(RequestPredicates.all(), handlerFunction); + + HttpHandler result = RouterFunctions.toHttpHandler(routerFunction); + assertNotNull(result); + + MockServerHttpRequest httpRequest = MockServerHttpRequest.get("http://localhost").build(); + MockServerHttpResponse httpResponse = new MockServerHttpResponse(); + result.handle(httpRequest, httpResponse).block(); + assertEquals(HttpStatus.NOT_FOUND, httpResponse.getStatusCode()); + } + + @Test + public void toHttpHandlerHandlerReturnResponseStatusExceptionInResponseWriteTo() throws Exception { + HandlerFunction handlerFunction = + request -> Mono.just(new ServerResponse() { + @Override + public HttpStatus statusCode() { + return HttpStatus.OK; + } + + @Override + public HttpHeaders headers() { + return new HttpHeaders(); + } + + @Override + public Mono writeTo(ServerWebExchange exchange, + HandlerStrategies strategies) { + return Mono.error(new ResponseStatusException(HttpStatus.NOT_FOUND, "Not found")); + } + }); + RouterFunction routerFunction = + RouterFunctions.route(RequestPredicates.all(), handlerFunction); + + HttpHandler result = RouterFunctions.toHttpHandler(routerFunction); + assertNotNull(result); + + MockServerHttpRequest httpRequest = MockServerHttpRequest.get("http://localhost").build(); + MockServerHttpResponse httpResponse = new MockServerHttpResponse(); + result.handle(httpRequest, httpResponse).block(); + assertEquals(HttpStatus.NOT_FOUND, httpResponse.getStatusCode()); + } + + @Test + public void toHttpHandlerHandlerThrowResponseStatusExceptionInResponseWriteTo() throws Exception { + HandlerFunction handlerFunction = + request -> Mono.just(new ServerResponse() { + @Override + public HttpStatus statusCode() { + return HttpStatus.OK; + } + + @Override + public HttpHeaders headers() { + return new HttpHeaders(); + } + + @Override + public Mono writeTo(ServerWebExchange exchange, + HandlerStrategies strategies) { + throw new ResponseStatusException(HttpStatus.NOT_FOUND, "Not found"); + } + }); + RouterFunction routerFunction = + RouterFunctions.route(RequestPredicates.all(), handlerFunction); + + HttpHandler result = RouterFunctions.toHttpHandler(routerFunction); + assertNotNull(result); + + MockServerHttpRequest httpRequest = MockServerHttpRequest.get("http://localhost").build(); + MockServerHttpResponse httpResponse = new MockServerHttpResponse(); + result.handle(httpRequest, httpResponse).block(); + assertEquals(HttpStatus.NOT_FOUND, httpResponse.getStatusCode()); } }