diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java index 0fd28344543..e720174b37e 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultAsyncServerResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -118,7 +118,7 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple public ModelAndView writeTo(HttpServletRequest request, HttpServletResponse response, Context context) throws ServletException, IOException { - writeAsync(request, response, createDeferredResult()); + writeAsync(request, response, createDeferredResult(request)); return null; } @@ -140,7 +140,7 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple } - private DeferredResult createDeferredResult() { + private DeferredResult createDeferredResult(HttpServletRequest request) { DeferredResult result; if (this.timeout != null) { result = new DeferredResult<>(this.timeout.toMillis()); @@ -153,7 +153,13 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple if (ex instanceof CompletionException && ex.getCause() != null) { ex = ex.getCause(); } - result.setErrorResult(ex); + ServerResponse errorResponse = errorResponse(ex, request); + if (errorResponse != null) { + result.setResult(errorResponse); + } + else { + result.setErrorResult(ex); + } } else { result.setResult(value); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java index 44b721e72a2..fedfe2d4a40 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/DefaultEntityResponseBuilder.java @@ -361,21 +361,27 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder protected ModelAndView writeToInternal(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Context context) throws ServletException, IOException { - DeferredResult deferredResult = createDeferredResult(servletRequest, servletResponse, context); + DeferredResult deferredResult = createDeferredResult(servletRequest, servletResponse, context); DefaultAsyncServerResponse.writeAsync(servletRequest, servletResponse, deferredResult); return null; } - private DeferredResult createDeferredResult(HttpServletRequest request, HttpServletResponse response, + private DeferredResult createDeferredResult(HttpServletRequest request, HttpServletResponse response, Context context) { - DeferredResult result = new DeferredResult<>(); + DeferredResult result = new DeferredResult<>(); entity().handle((value, ex) -> { if (ex != null) { if (ex instanceof CompletionException && ex.getCause() != null) { ex = ex.getCause(); } - result.setErrorResult(ex); + ServerResponse errorResponse = errorResponse(ex, request); + if (errorResponse != null) { + result.setResult(errorResponse); + } + else { + result.setErrorResult(ex); + } } else { try { @@ -468,7 +474,12 @@ final class DefaultEntityResponseBuilder implements EntityResponse.Builder @Override public void onError(Throwable t) { - this.deferredResult.setErrorResult(t); + try { + handleError(t, this.servletRequest, this.servletResponse, this.context); + } + catch (ServletException | IOException handlingThrowable) { + this.deferredResult.setErrorResult(handlingThrowable); + } } @Override diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ErrorHandlingServerResponse.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ErrorHandlingServerResponse.java index 09785c5cf92..9ae67ec1023 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ErrorHandlingServerResponse.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/ErrorHandlingServerResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,6 @@ import org.springframework.web.servlet.ModelAndView; /** * Base class for {@link ServerResponse} implementations with error handling. - * * @author Arjen Poutsma * @since 5.3 */ @@ -55,21 +54,36 @@ abstract class ErrorHandlingServerResponse implements ServerResponse { } @Nullable - protected ModelAndView handleError(Throwable t, HttpServletRequest servletRequest, + protected final ModelAndView handleError(Throwable t, HttpServletRequest servletRequest, HttpServletResponse servletResponse, Context context) throws ServletException, IOException { + ServerResponse serverResponse = errorResponse(t, servletRequest); + if (serverResponse != null) { + return serverResponse.writeTo(servletRequest, servletResponse, context); + } + else if (t instanceof ServletException) { + throw (ServletException) t; + } + else if (t instanceof IOException) { + throw (IOException) t; + } + else { + throw new ServletException(t); + } + } + + @Nullable + protected final ServerResponse errorResponse(Throwable t, HttpServletRequest servletRequest) { for (ErrorHandler errorHandler : this.errorHandlers) { if (errorHandler.test(t)) { ServerRequest serverRequest = (ServerRequest) servletRequest.getAttribute(RouterFunctions.REQUEST_ATTRIBUTE); - ServerResponse serverResponse = errorHandler.handle(t, serverRequest); - return serverResponse.writeTo(servletRequest, servletResponse, context); + return errorHandler.handle(t, serverRequest); } } - throw new ServletException(t); + return null; } - private static class ErrorHandler { private final Predicate predicate;