Invoke WebMvc.fn error handlers for async errors

This commit makes sure that any error handlers registered on the route
are also applied when an error occurs asynchronously. This commit
applies to asynchronous bodies with both CompletableFuture and Reactive
Streams, as well as completely asynchronous responses.

Closes gh-26831
This commit is contained in:
Arjen Poutsma 2021-04-30 11:38:37 +02:00
parent 4c7cc705de
commit 6f4fb08bf8
3 changed files with 47 additions and 16 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -118,7 +118,7 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple
public ModelAndView writeTo(HttpServletRequest request, HttpServletResponse response, Context context) public ModelAndView writeTo(HttpServletRequest request, HttpServletResponse response, Context context)
throws ServletException, IOException { throws ServletException, IOException {
writeAsync(request, response, createDeferredResult()); writeAsync(request, response, createDeferredResult(request));
return null; return null;
} }
@ -140,7 +140,7 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple
} }
private DeferredResult<ServerResponse> createDeferredResult() { private DeferredResult<ServerResponse> createDeferredResult(HttpServletRequest request) {
DeferredResult<ServerResponse> result; DeferredResult<ServerResponse> result;
if (this.timeout != null) { if (this.timeout != null) {
result = new DeferredResult<>(this.timeout.toMillis()); result = new DeferredResult<>(this.timeout.toMillis());
@ -153,7 +153,13 @@ final class DefaultAsyncServerResponse extends ErrorHandlingServerResponse imple
if (ex instanceof CompletionException && ex.getCause() != null) { if (ex instanceof CompletionException && ex.getCause() != null) {
ex = ex.getCause(); ex = ex.getCause();
} }
result.setErrorResult(ex); ServerResponse errorResponse = errorResponse(ex, request);
if (errorResponse != null) {
result.setResult(errorResponse);
}
else {
result.setErrorResult(ex);
}
} }
else { else {
result.setResult(value); result.setResult(value);

View File

@ -361,21 +361,27 @@ final class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T>
protected ModelAndView writeToInternal(HttpServletRequest servletRequest, HttpServletResponse servletResponse, protected ModelAndView writeToInternal(HttpServletRequest servletRequest, HttpServletResponse servletResponse,
Context context) throws ServletException, IOException { Context context) throws ServletException, IOException {
DeferredResult<?> deferredResult = createDeferredResult(servletRequest, servletResponse, context); DeferredResult<ServerResponse> deferredResult = createDeferredResult(servletRequest, servletResponse, context);
DefaultAsyncServerResponse.writeAsync(servletRequest, servletResponse, deferredResult); DefaultAsyncServerResponse.writeAsync(servletRequest, servletResponse, deferredResult);
return null; return null;
} }
private DeferredResult<?> createDeferredResult(HttpServletRequest request, HttpServletResponse response, private DeferredResult<ServerResponse> createDeferredResult(HttpServletRequest request, HttpServletResponse response,
Context context) { Context context) {
DeferredResult<?> result = new DeferredResult<>(); DeferredResult<ServerResponse> result = new DeferredResult<>();
entity().handle((value, ex) -> { entity().handle((value, ex) -> {
if (ex != null) { if (ex != null) {
if (ex instanceof CompletionException && ex.getCause() != null) { if (ex instanceof CompletionException && ex.getCause() != null) {
ex = ex.getCause(); ex = ex.getCause();
} }
result.setErrorResult(ex); ServerResponse errorResponse = errorResponse(ex, request);
if (errorResponse != null) {
result.setResult(errorResponse);
}
else {
result.setErrorResult(ex);
}
} }
else { else {
try { try {
@ -468,7 +474,12 @@ final class DefaultEntityResponseBuilder<T> implements EntityResponse.Builder<T>
@Override @Override
public void onError(Throwable t) { public void onError(Throwable t) {
this.deferredResult.setErrorResult(t); try {
handleError(t, this.servletRequest, this.servletResponse, this.context);
}
catch (ServletException | IOException handlingThrowable) {
this.deferredResult.setErrorResult(handlingThrowable);
}
} }
@Override @Override

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2020 the original author or authors. * Copyright 2002-2021 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -35,7 +35,6 @@ import org.springframework.web.servlet.ModelAndView;
/** /**
* Base class for {@link ServerResponse} implementations with error handling. * Base class for {@link ServerResponse} implementations with error handling.
*
* @author Arjen Poutsma * @author Arjen Poutsma
* @since 5.3 * @since 5.3
*/ */
@ -55,21 +54,36 @@ abstract class ErrorHandlingServerResponse implements ServerResponse {
} }
@Nullable @Nullable
protected ModelAndView handleError(Throwable t, HttpServletRequest servletRequest, protected final ModelAndView handleError(Throwable t, HttpServletRequest servletRequest,
HttpServletResponse servletResponse, Context context) throws ServletException, IOException { HttpServletResponse servletResponse, Context context) throws ServletException, IOException {
ServerResponse serverResponse = errorResponse(t, servletRequest);
if (serverResponse != null) {
return serverResponse.writeTo(servletRequest, servletResponse, context);
}
else if (t instanceof ServletException) {
throw (ServletException) t;
}
else if (t instanceof IOException) {
throw (IOException) t;
}
else {
throw new ServletException(t);
}
}
@Nullable
protected final ServerResponse errorResponse(Throwable t, HttpServletRequest servletRequest) {
for (ErrorHandler<?> errorHandler : this.errorHandlers) { for (ErrorHandler<?> errorHandler : this.errorHandlers) {
if (errorHandler.test(t)) { if (errorHandler.test(t)) {
ServerRequest serverRequest = (ServerRequest) ServerRequest serverRequest = (ServerRequest)
servletRequest.getAttribute(RouterFunctions.REQUEST_ATTRIBUTE); servletRequest.getAttribute(RouterFunctions.REQUEST_ATTRIBUTE);
ServerResponse serverResponse = errorHandler.handle(t, serverRequest); return errorHandler.handle(t, serverRequest);
return serverResponse.writeTo(servletRequest, servletResponse, context);
} }
} }
throw new ServletException(t); return null;
} }
private static class ErrorHandler<T extends ServerResponse> { private static class ErrorHandler<T extends ServerResponse> {
private final Predicate<Throwable> predicate; private final Predicate<Throwable> predicate;