diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java index a7e79d6a13d..7c8d814e936 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RouterFunctions.java @@ -77,8 +77,6 @@ public abstract class RouterFunctions { RouterFunctions.class.getName() + ".matchingPattern"; - private static final HandlerFunction NOT_FOUND_HANDLER = - request -> ServerResponse.notFound().build(); /** @@ -253,43 +251,9 @@ public abstract class RouterFunctions { Assert.notNull(routerFunction, "RouterFunction must not be null"); Assert.notNull(strategies, "HandlerStrategies must not be null"); - return exchange -> { - ServerRequest request = new DefaultServerRequest(exchange, strategies.messageReaders()); - addAttributes(exchange, request); - return routerFunction.route(request) - .defaultIfEmpty(notFound()) - .flatMap(handlerFunction -> wrapException(() -> handlerFunction.handle(request))) - .flatMap(response -> wrapException(() -> response.writeTo(exchange, - new HandlerStrategiesResponseContext(strategies)))); - }; + return new RouterFunctionWebHandler(strategies, routerFunction); } - - private static Mono wrapException(Supplier> supplier) { - try { - return supplier.get(); - } - catch (Throwable ex) { - return Mono.error(ex); - } - } - - private static void addAttributes(ServerWebExchange exchange, ServerRequest request) { - Map attributes = exchange.getAttributes(); - attributes.put(REQUEST_ATTRIBUTE, request); - } - - @SuppressWarnings("unchecked") - private static HandlerFunction notFound() { - return (HandlerFunction) NOT_FOUND_HANDLER; - } - - @SuppressWarnings("unchecked") - static HandlerFunction cast(HandlerFunction handlerFunction) { - return (HandlerFunction) handlerFunction; - } - - /** * Represents a discoverable builder for router functions. * Obtained via {@link RouterFunctions#route()}. @@ -846,8 +810,13 @@ public abstract class RouterFunctions { @Override public Mono> route(ServerRequest request) { return this.first.route(request) - .map(RouterFunctions::cast) - .switchIfEmpty(Mono.defer(() -> this.second.route(request).map(RouterFunctions::cast))); + .map(this::cast) + .switchIfEmpty(Mono.defer(() -> this.second.route(request).map(this::cast))); + } + + @SuppressWarnings("unchecked") + private HandlerFunction cast(HandlerFunction handlerFunction) { + return (HandlerFunction) handlerFunction; } @Override @@ -1012,4 +981,51 @@ public abstract class RouterFunctions { } } + + private static class RouterFunctionWebHandler implements WebHandler { + + private static final HandlerFunction NOT_FOUND_HANDLER = + request -> ServerResponse.notFound().build(); + + private final HandlerStrategies strategies; + + private final RouterFunction routerFunction; + + public RouterFunctionWebHandler(HandlerStrategies strategies, RouterFunction routerFunction) { + this.strategies = strategies; + this.routerFunction = routerFunction; + } + + @Override + public Mono handle(ServerWebExchange exchange) { + return Mono.defer(() -> { + ServerRequest request = new DefaultServerRequest(exchange, this.strategies.messageReaders()); + addAttributes(exchange, request); + return this.routerFunction.route(request) + .defaultIfEmpty(notFound()) + .flatMap(handlerFunction -> wrapException(() -> handlerFunction.handle(request))) + .flatMap(response -> wrapException(() -> response.writeTo(exchange, + new HandlerStrategiesResponseContext(this.strategies)))); + }); + } + + private void addAttributes(ServerWebExchange exchange, ServerRequest request) { + Map attributes = exchange.getAttributes(); + attributes.put(REQUEST_ATTRIBUTE, request); + } + + @SuppressWarnings("unchecked") + private static HandlerFunction notFound() { + return (HandlerFunction) NOT_FOUND_HANDLER; + } + + private static Mono wrapException(Supplier> supplier) { + try { + return supplier.get(); + } + catch (Throwable ex) { + return Mono.error(ex); + } + } + } }