From 6f4fb08bf8c1bf6f783e380f2c604f415fc0dfe6 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Fri, 30 Apr 2021 11:38:37 +0200 Subject: [PATCH] Invoke WebMvc.fn error handlers for async errors This commit makes sure that any error handlers registered on the route are also applied when an error occurs asynchronously. This commit applies to asynchronous bodies with both CompletableFuture and Reactive Streams, as well as completely asynchronous responses. Closes gh-26831 --- .../function/DefaultAsyncServerResponse.java | 14 +++++++--- .../DefaultEntityResponseBuilder.java | 21 ++++++++++---- .../function/ErrorHandlingServerResponse.java | 28 ++++++++++++++----- 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java index 0fd28344543..e720174b37e 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -118,7 +118,7 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple public ModelAndView writeTo(HttpServletRequest request, HttpServletResponse response, Context context) throws ServletException, IOException { - writeAsync(request, response, createDeferredResult()); + writeAsync(request, response, createDeferredResult(request)); return null; } @@ -140,7 +140,7 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple } - private DeferredResult createDeferredResult() { + private DeferredResult createDeferredResult(HttpServletRequest request) { DeferredResult result; if (this.timeout != null) { result = new DeferredResult<>(this.timeout.toMillis()); @@ -153,7 +153,13 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple if (ex instanceof CompletionException && ex.getCause() != null) { ex = ex.getCause(); } - result.setErrorResult(ex); + ServerResponse errorResponse = errorResponse(ex, request); + if (errorResponse != null) { + result.setResult(errorResponse); + } + else { + result.setErrorResult(ex); + } } else { result.setResult(value); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java index 44b721e72a2..fedfe2d4a40 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java @@ -361,21 +361,27 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder protected ModelAndView writeToInternal(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Context context) throws ServletException, IOException { - DeferredResult deferredResult = createDeferredResult(servletRequest, servletResponse, context); + DeferredResult deferredResult = createDeferredResult(servletRequest, servletResponse, context); DefaultAsyncServerResponse.writeAsync(servletRequest, servletResponse, deferredResult); return null; } - private DeferredResult createDeferredResult(HttpServletRequest request, HttpServletResponse response, + private DeferredResult createDeferredResult(HttpServletRequest request, HttpServletResponse response, Context context) { - DeferredResult result = new DeferredResult<>(); + DeferredResult result = new DeferredResult<>(); entity().handle((value, ex) -> { if (ex != null) { if (ex instanceof CompletionException && ex.getCause() != null) { ex = ex.getCause(); } - result.setErrorResult(ex); + ServerResponse errorResponse = errorResponse(ex, request); + if (errorResponse != null) { + result.setResult(errorResponse); + } + else { + result.setErrorResult(ex); + } } else { try { @@ -468,7 +474,12 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder @Override public void onError(Throwable t) { - this.deferredResult.setErrorResult(t); + try { + handleError(t, this.servletRequest, this.servletResponse, this.context); + } + catch (ServletException | IOException handlingThrowable) { + this.deferredResult.setErrorResult(handlingThrowable); + } } @Override diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ErrorHandlingServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ErrorHandlingServerResponse.java index 09785c5cf92..9ae67ec1023 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ErrorHandlingServerResponse.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ErrorHandlingServerResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,6 @@ import org.springframework.web.servlet.ModelAndView; /** * Base class for {@link ServerResponse} implementations with error handling. - * * @author Arjen Poutsma * @since 5.3 */ @@ -55,21 +54,36 @@ abstract class ErrorHandlingServerResponse implements ServerResponse { } @Nullable - protected ModelAndView handleError(Throwable t, HttpServletRequest servletRequest, + protected final ModelAndView handleError(Throwable t, HttpServletRequest servletRequest, HttpServletResponse servletResponse, Context context) throws ServletException, IOException { + ServerResponse serverResponse = errorResponse(t, servletRequest); + if (serverResponse != null) { + return serverResponse.writeTo(servletRequest, servletResponse, context); + } + else if (t instanceof ServletException) { + throw (ServletException) t; + } + else if (t instanceof IOException) { + throw (IOException) t; + } + else { + throw new ServletException(t); + } + } + + @Nullable + protected final ServerResponse errorResponse(Throwable t, HttpServletRequest servletRequest) { for (ErrorHandler errorHandler : this.errorHandlers) { if (errorHandler.test(t)) { ServerRequest serverRequest = (ServerRequest) servletRequest.getAttribute(RouterFunctions.REQUEST_ATTRIBUTE); - ServerResponse serverResponse = errorHandler.handle(t, serverRequest); - return serverResponse.writeTo(servletRequest, servletResponse, context); + return errorHandler.handle(t, serverRequest); } } - throw new ServletException(t); + return null; } - private static class ErrorHandler { private final Predicate predicate;