diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java index 169664f233..cf6ce52f4a 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletHttpHandlerAdapter.java @@ -18,6 +18,8 @@ package org.springframework.http.server.reactive; import java.io.IOException; import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; import javax.servlet.Servlet; import javax.servlet.ServletConfig; import javax.servlet.ServletRequest; @@ -106,6 +108,8 @@ public class ServletHttpHandlerAdapter implements Servlet { ServerHttpRequest httpRequest = createRequest(((HttpServletRequest) request), asyncContext); ServerHttpResponse httpResponse = createResponse(((HttpServletResponse) response), asyncContext); + asyncContext.addListener(TIMEOUT_LISTENER); + HandlerResultSubscriber subscriber = new HandlerResultSubscriber(asyncContext); this.httpHandler.handle(httpRequest, httpResponse).subscribe(subscriber); } @@ -146,14 +150,40 @@ public class ServletHttpHandlerAdapter implements Servlet { } + private final static AsyncListener TIMEOUT_LISTENER = new AsyncListener() { + + @Override + public void onTimeout(AsyncEvent event) throws IOException { + event.getAsyncContext().complete(); + } + + @Override + public void onError(AsyncEvent event) throws IOException { + event.getAsyncContext().complete(); + } + + @Override + public void onStartAsync(AsyncEvent event) throws IOException { + // no-op + } + + @Override + public void onComplete(AsyncEvent event) throws IOException { + // no-op + } + }; + + private class HandlerResultSubscriber implements Subscriber { private final AsyncContext asyncContext; + public HandlerResultSubscriber(AsyncContext asyncContext) { this.asyncContext = asyncContext; } + @Override public void onSubscribe(Subscription subscription) { subscription.request(Long.MAX_VALUE); @@ -166,32 +196,33 @@ public class ServletHttpHandlerAdapter implements Servlet { @Override public void onError(Throwable ex) { - ServletRequest request = getRequest(); - if (request != null && request.isAsyncStarted()) { + runIfAsyncNotComplete(() -> { logger.error("Could not complete request", ex); HttpServletResponse response = (HttpServletResponse) this.asyncContext.getResponse(); response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); this.asyncContext.complete(); - } + }); } @Override public void onComplete() { - ServletRequest request = getRequest(); - if (request != null && request.isAsyncStarted()) { + runIfAsyncNotComplete(() -> { logger.debug("Successfully completed request"); this.asyncContext.complete(); - } + }); } - private ServletRequest getRequest() { - ServletRequest request = null; + private void runIfAsyncNotComplete(Runnable task) { try { - request = this.asyncContext.getRequest(); - } catch (IllegalStateException ignore) { - // AsyncContext has been recycled and should not be used + if (this.asyncContext.getRequest().isAsyncStarted()) { + task.run(); + } + } + catch (IllegalStateException ex) { + // Ignore: + // AsyncContext recycled and should not be used + // e.g. TIMEOUT_LISTENER (above) may have completed the AsyncContext } - return request; } } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index e248a0eb81..63854ef3dc 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -180,13 +180,11 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons Throwable ex = event.getThrowable(); ex = (ex != null ? ex : new IllegalStateException("Async operation timeout.")); handleError(ex); - event.getAsyncContext().complete(); } @Override public void onError(AsyncEvent event) { handleError(event.getThrowable()); - event.getAsyncContext().complete(); } void handleError(Throwable ex) {