Add state and response wrapping to StandardServletAsyncWebRequest

The wrapped response prevents use after AsyncListener onError or completion
to ensure compliance with Servlet Spec 2.3.3.4.

The wrapped response is applied in RequestMappingHandlerAdapter.

The wrapped response raises AsyncRequestNotUsableException that is now
handled in DefaultHandlerExceptionResolver.

See gh-32340
This commit is contained in:
rstoyanchev 2024-03-01 22:31:09 +00:00
parent 380184e85a
commit 379ffac508
8 changed files with 622 additions and 31 deletions

View File

@ -0,0 +1,44 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.context.request.async;
import java.io.IOException;
/**
* Raised when the response for an asynchronous request becomes unusable as
* indicated by a write failure, or a Servlet container error notification, or
* after the async request has completed.
*
* <p>The exception relies on response wrapping, and on {@code AsyncListener}
* notifications, managed by {@link StandardServletAsyncWebRequest}.
*
* @author Rossen Stoyanchev
* @since 5.3.33
*/
@SuppressWarnings("serial")
public class AsyncRequestNotUsableException extends IOException {
public AsyncRequestNotUsableException(String message) {
super(message);
}
public AsyncRequestNotUsableException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,16 +17,22 @@
package org.springframework.web.context.request.async;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.Locale;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.WriteListener;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpServletResponseWrapper;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
@ -45,8 +51,6 @@ import org.springframework.web.context.request.ServletWebRequest;
*/
public class StandardServletAsyncWebRequest extends ServletWebRequest implements AsyncWebRequest, AsyncListener {
private final AtomicBoolean asyncCompleted = new AtomicBoolean();
private final List<Runnable> timeoutHandlers = new ArrayList<>();
private final List<Consumer<Throwable>> exceptionHandlers = new ArrayList<>();
@ -59,6 +63,10 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Nullable
private AsyncContext asyncContext;
private State state;
private final ReentrantLock stateLock = new ReentrantLock();
/**
* Create a new instance for the given request/response pair.
@ -66,7 +74,26 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
* @param response current HTTP response
*/
public StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) {
super(request, response);
this(request, response, null);
}
/**
* Constructor to wrap the request and response for the current dispatch that
* also picks up the state of the last (probably the REQUEST) dispatch.
* @param request current HTTP request
* @param response current HTTP response
* @param previousRequest the existing request from the last dispatch
* @since 5.3.33
*/
StandardServletAsyncWebRequest(HttpServletRequest request, HttpServletResponse response,
@Nullable StandardServletAsyncWebRequest previousRequest) {
super(request, new LifecycleHttpServletResponse(response));
this.state = (previousRequest != null ? previousRequest.state : State.NEW);
//noinspection DataFlowIssue
((LifecycleHttpServletResponse) getResponse()).setAsyncWebRequest(this);
}
@ -107,7 +134,7 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
*/
@Override
public boolean isAsyncComplete() {
return this.asyncCompleted.get();
return (this.state == State.COMPLETED);
}
@Override
@ -117,11 +144,18 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
"in async request processing. This is done in Java code using the Servlet API " +
"or by adding \"<async-supported>true</async-supported>\" to servlet and " +
"filter declarations in web.xml.");
Assert.state(!isAsyncComplete(), "Async processing has already completed");
if (isAsyncStarted()) {
return;
}
if (this.state == State.NEW) {
this.state = State.ASYNC;
}
else {
Assert.state(this.state == State.ASYNC, "Cannot start async: [" + this.state + "]");
}
this.asyncContext = getRequest().startAsync(getRequest(), getResponse());
this.asyncContext.addListener(this);
if (this.timeout != null) {
@ -131,8 +165,10 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void dispatch() {
Assert.state(this.asyncContext != null, "Cannot dispatch without an AsyncContext");
this.asyncContext.dispatch();
Assert.state(this.asyncContext != null, "AsyncContext not yet initialized");
if (!this.isAsyncComplete()) {
this.asyncContext.dispatch();
}
}
@ -151,14 +187,478 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements
@Override
public void onError(AsyncEvent event) throws IOException {
this.exceptionHandlers.forEach(consumer -> consumer.accept(event.getThrowable()));
this.stateLock.lock();
try {
transitionToErrorState();
Throwable ex = event.getThrowable();
this.exceptionHandlers.forEach(consumer -> consumer.accept(ex));
}
finally {
this.stateLock.unlock();
}
}
private void transitionToErrorState() {
if (!isAsyncComplete()) {
this.state = State.ERROR;
}
}
@Override
public void onComplete(AsyncEvent event) throws IOException {
this.completionHandlers.forEach(Runnable::run);
this.asyncContext = null;
this.asyncCompleted.set(true);
this.stateLock.lock();
try {
this.completionHandlers.forEach(Runnable::run);
this.asyncContext = null;
this.state = State.COMPLETED;
}
finally {
this.stateLock.unlock();
}
}
/**
* Response wrapper to wrap the output stream with {@link LifecycleServletOutputStream}.
*/
private static final class LifecycleHttpServletResponse extends HttpServletResponseWrapper {
@Nullable
private StandardServletAsyncWebRequest asyncWebRequest;
@Nullable
private ServletOutputStream outputStream;
@Nullable
private PrintWriter writer;
public LifecycleHttpServletResponse(HttpServletResponse response) {
super(response);
}
public void setAsyncWebRequest(StandardServletAsyncWebRequest asyncWebRequest) {
this.asyncWebRequest = asyncWebRequest;
}
@Override
public ServletOutputStream getOutputStream() {
if (this.outputStream == null) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
this.outputStream = new LifecycleServletOutputStream(
(HttpServletResponse) getResponse(), this.asyncWebRequest);
}
return this.outputStream;
}
@Override
public PrintWriter getWriter() throws IOException {
if (this.writer == null) {
Assert.notNull(this.asyncWebRequest, "Not initialized");
this.writer = new LifecyclePrintWriter(getResponse().getWriter(), this.asyncWebRequest);
}
return this.writer;
}
}
/**
* Wraps a ServletOutputStream to prevent use after Servlet container onError
* notifications, and after async request completion.
*/
private static final class LifecycleServletOutputStream extends ServletOutputStream {
private final HttpServletResponse delegate;
private final StandardServletAsyncWebRequest asyncWebRequest;
private LifecycleServletOutputStream(
HttpServletResponse delegate, StandardServletAsyncWebRequest asyncWebRequest) {
this.delegate = delegate;
this.asyncWebRequest = asyncWebRequest;
}
@Override
public boolean isReady() {
return false;
}
@Override
public void setWriteListener(WriteListener writeListener) {
throw new UnsupportedOperationException();
}
@Override
public void write(int b) throws IOException {
obtainLockAndCheckState();
try {
this.delegate.getOutputStream().write(b);
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to write");
}
finally {
releaseLock();
}
}
public void write(byte[] buf, int offset, int len) throws IOException {
obtainLockAndCheckState();
try {
this.delegate.getOutputStream().write(buf, offset, len);
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to write");
}
finally {
releaseLock();
}
}
@Override
public void flush() throws IOException {
obtainLockAndCheckState();
try {
this.delegate.getOutputStream().flush();
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to flush");
}
finally {
releaseLock();
}
}
@Override
public void close() throws IOException {
obtainLockAndCheckState();
try {
this.delegate.getOutputStream().close();
}
catch (IOException ex) {
handleIOException(ex, "ServletOutputStream failed to close");
}
finally {
releaseLock();
}
}
private void obtainLockAndCheckState() throws AsyncRequestNotUsableException {
if (state() != State.NEW) {
stateLock().lock();
if (state() != State.ASYNC) {
stateLock().unlock();
throw new AsyncRequestNotUsableException("Response not usable after " +
(state() == State.COMPLETED ?
"async request completion" : "onError notification") + ".");
}
}
}
private void handleIOException(IOException ex, String msg) throws AsyncRequestNotUsableException {
this.asyncWebRequest.transitionToErrorState();
throw new AsyncRequestNotUsableException(msg, ex);
}
private void releaseLock() {
if (state() != State.NEW) {
stateLock().unlock();
}
}
private State state() {
return this.asyncWebRequest.state;
}
private Lock stateLock() {
return this.asyncWebRequest.stateLock;
}
}
/**
* Wraps a PrintWriter to prevent use after Servlet container onError
* notifications, and after async request completion.
*/
private static final class LifecyclePrintWriter extends PrintWriter {
private final PrintWriter delegate;
private final StandardServletAsyncWebRequest asyncWebRequest;
private LifecyclePrintWriter(PrintWriter delegate, StandardServletAsyncWebRequest asyncWebRequest) {
super(delegate);
this.delegate = delegate;
this.asyncWebRequest = asyncWebRequest;
}
@Override
public void flush() {
if (tryObtainLockAndCheckState()) {
try {
this.delegate.flush();
}
finally {
releaseLock();
}
}
}
@Override
public void close() {
if (tryObtainLockAndCheckState()) {
try {
this.delegate.close();
}
finally {
releaseLock();
}
}
}
@Override
public boolean checkError() {
return this.delegate.checkError();
}
@Override
public void write(int c) {
if (tryObtainLockAndCheckState()) {
try {
this.delegate.write(c);
}
finally {
releaseLock();
}
}
}
@Override
public void write(char[] buf, int off, int len) {
if (tryObtainLockAndCheckState()) {
try {
this.delegate.write(buf, off, len);
}
finally {
releaseLock();
}
}
}
@Override
public void write(char[] buf) {
this.delegate.write(buf);
}
@Override
public void write(String s, int off, int len) {
if (tryObtainLockAndCheckState()) {
try {
this.delegate.write(s, off, len);
}
finally {
releaseLock();
}
}
}
@Override
public void write(String s) {
this.delegate.write(s);
}
private boolean tryObtainLockAndCheckState() {
if (state() == State.NEW) {
return true;
}
if (stateLock().tryLock()) {
if (state() == State.ASYNC) {
return true;
}
stateLock().unlock();
}
return false;
}
private void releaseLock() {
if (state() != State.NEW) {
stateLock().unlock();
}
}
private State state() {
return this.asyncWebRequest.state;
}
private Lock stateLock() {
return this.asyncWebRequest.stateLock;
}
// Plain delegates
@Override
public void print(boolean b) {
this.delegate.print(b);
}
@Override
public void print(char c) {
this.delegate.print(c);
}
@Override
public void print(int i) {
this.delegate.print(i);
}
@Override
public void print(long l) {
this.delegate.print(l);
}
@Override
public void print(float f) {
this.delegate.print(f);
}
@Override
public void print(double d) {
this.delegate.print(d);
}
@Override
public void print(char[] s) {
this.delegate.print(s);
}
@Override
public void print(String s) {
this.delegate.print(s);
}
@Override
public void print(Object obj) {
this.delegate.print(obj);
}
@Override
public void println() {
this.delegate.println();
}
@Override
public void println(boolean x) {
this.delegate.println(x);
}
@Override
public void println(char x) {
this.delegate.println(x);
}
@Override
public void println(int x) {
this.delegate.println(x);
}
@Override
public void println(long x) {
this.delegate.println(x);
}
@Override
public void println(float x) {
this.delegate.println(x);
}
@Override
public void println(double x) {
this.delegate.println(x);
}
@Override
public void println(char[] x) {
this.delegate.println(x);
}
@Override
public void println(String x) {
this.delegate.println(x);
}
@Override
public void println(Object x) {
this.delegate.println(x);
}
@Override
public PrintWriter printf(String format, Object... args) {
return this.delegate.printf(format, args);
}
@Override
public PrintWriter printf(Locale l, String format, Object... args) {
return this.delegate.printf(l, format, args);
}
@Override
public PrintWriter format(String format, Object... args) {
return this.delegate.format(format, args);
}
@Override
public PrintWriter format(Locale l, String format, Object... args) {
return this.delegate.format(l, format, args);
}
@Override
public PrintWriter append(CharSequence csq) {
return this.delegate.append(csq);
}
@Override
public PrintWriter append(CharSequence csq, int start, int end) {
return this.delegate.append(csq, start, end);
}
@Override
public PrintWriter append(char c) {
return this.delegate.append(c);
}
}
/**
* Represents a state for {@link StandardServletAsyncWebRequest} to be in.
* <p><pre>
* NEW
* |
* v
* ASYNC----> +
* | |
* v |
* ERROR |
* | |
* v |
* COMPLETED <--+
* </pre>
* @since 5.3.33
*/
private enum State {
/** New request (thas may not do async handling). */
NEW,
/** Async handling has started. */
ASYNC,
/** onError notification received, or ServletOutputStream failed. */
ERROR,
/** onComplete notification received. */
COMPLETED
}
}

View File

@ -132,6 +132,15 @@ public final class WebAsyncManager {
WebAsyncUtils.WEB_ASYNC_MANAGER_ATTRIBUTE, RequestAttributes.SCOPE_REQUEST));
}
/**
* Return the current {@link AsyncWebRequest}.
* @since 5.3.33
*/
@Nullable
public AsyncWebRequest getAsyncWebRequest() {
return this.asyncWebRequest;
}
/**
* Configure an AsyncTaskExecutor for use with concurrent processing via
* {@link #startCallableProcessing(Callable, Object...)}.

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -82,7 +82,10 @@ public abstract class WebAsyncUtils {
* @return an AsyncWebRequest instance (never {@code null})
*/
public static AsyncWebRequest createAsyncWebRequest(HttpServletRequest request, HttpServletResponse response) {
return new StandardServletAsyncWebRequest(request, response);
AsyncWebRequest prev = getAsyncManager(request).getAsyncWebRequest();
return (prev instanceof StandardServletAsyncWebRequest standardRequest ?
new StandardServletAsyncWebRequest(request, response, standardRequest) :
new StandardServletAsyncWebRequest(request, response));
}
}

View File

@ -94,9 +94,8 @@ class StandardServletAsyncWebRequestTests {
@Test
void startAsyncAfterCompleted() throws Exception {
this.asyncRequest.onComplete(new AsyncEvent(new MockAsyncContext(this.request, this.response)));
assertThatIllegalStateException().isThrownBy(
this.asyncRequest::startAsync)
.withMessage("Async processing has already completed");
assertThatIllegalStateException().isThrownBy(this.asyncRequest::startAsync)
.withMessage("Cannot start async: [COMPLETED]");
}
@Test

View File

@ -875,7 +875,21 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter
protected ModelAndView invokeHandlerMethod(HttpServletRequest request,
HttpServletResponse response, HandlerMethod handlerMethod) throws Exception {
ServletWebRequest webRequest = new ServletWebRequest(request, response);
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
AsyncWebRequest asyncWebRequest = WebAsyncUtils.createAsyncWebRequest(request, response);
asyncWebRequest.setTimeout(this.asyncRequestTimeout);
asyncManager.setTaskExecutor(this.taskExecutor);
asyncManager.setAsyncWebRequest(asyncWebRequest);
asyncManager.registerCallableInterceptors(this.callableInterceptors);
asyncManager.registerDeferredResultInterceptors(this.deferredResultInterceptors);
// Obtain wrapped response to enforce lifecycle rule from Servlet spec, section 2.3.3.4
response = asyncWebRequest.getNativeResponse(HttpServletResponse.class);
ServletWebRequest webRequest = (asyncWebRequest instanceof ServletWebRequest ?
(ServletWebRequest) asyncWebRequest : new ServletWebRequest(request, response));
WebDataBinderFactory binderFactory = getDataBinderFactory(handlerMethod);
ModelFactory modelFactory = getModelFactory(handlerMethod, binderFactory);
@ -895,15 +909,6 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter
modelFactory.initModel(webRequest, mavContainer, invocableMethod);
mavContainer.setIgnoreDefaultModelOnRedirect(this.ignoreDefaultModelOnRedirect);
AsyncWebRequest asyncWebRequest = WebAsyncUtils.createAsyncWebRequest(request, response);
asyncWebRequest.setTimeout(this.asyncRequestTimeout);
WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
asyncManager.setTaskExecutor(this.taskExecutor);
asyncManager.setAsyncWebRequest(asyncWebRequest);
asyncManager.registerCallableInterceptors(this.callableInterceptors);
asyncManager.registerDeferredResultInterceptors(this.deferredResultInterceptors);
if (asyncManager.hasConcurrentResult()) {
Object result = asyncManager.getConcurrentResult();
Object[] resultContext = asyncManager.getConcurrentResultContext();

View File

@ -45,6 +45,7 @@ import org.springframework.web.bind.ServletRequestBindingException;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.context.request.async.AsyncRequestNotUsableException;
import org.springframework.web.context.request.async.AsyncRequestTimeoutException;
import org.springframework.web.method.annotation.HandlerMethodValidationException;
import org.springframework.web.multipart.MultipartFile;
@ -121,6 +122,10 @@ import org.springframework.web.util.WebUtils;
* <td><div class="block">400 (SC_BAD_REQUEST)</div></td>
* </tr>
* <tr class="odd-row-color">
* <td><div class="block">{@link MethodValidationException}</div></td>
* <td><div class="block">500 (SC_INTERNAL_SERVER_ERROR)</div></td>
* </tr>
* <tr class="odd-row-color">
* <td><div class="block">{@link HandlerMethodValidationException}</div></td>
* <td><div class="block">400 (SC_BAD_REQUEST)</div></td>
* </tr>
@ -136,9 +141,9 @@ import org.springframework.web.util.WebUtils;
* <td><div class="block">AsyncRequestTimeoutException</div></td>
* <td><div class="block">503 (SC_SERVICE_UNAVAILABLE)</div></td>
* </tr>
* <tr class="odd-row-color">
* <td><div class="block">{@link MethodValidationException}</div></td>
* <td><div class="block">500 (SC_INTERNAL_SERVER_ERROR)</div></td>
* <tr class="even-row-color">
* <td><div class="block">AsyncRequestNotUsableException</div></td>
* <td><div class="block">Not applicable</div></td>
* </tr>
* </tbody>
* </table>
@ -243,6 +248,10 @@ public class DefaultHandlerExceptionResolver extends AbstractHandlerExceptionRes
else if (ex instanceof BindException theEx) {
return handleBindException(theEx, request, response, handler);
}
else if (ex instanceof AsyncRequestNotUsableException) {
return handleAsyncRequestNotUsableException(
(AsyncRequestNotUsableException) ex, request, response, handler);
}
}
catch (Exception handlerEx) {
if (logger.isWarnEnabled()) {
@ -494,6 +503,24 @@ public class DefaultHandlerExceptionResolver extends AbstractHandlerExceptionRes
return null;
}
/**
* Handle the case of an I/O failure from the ServletOutputStream.
* <p>By default, do nothing since the response is not usable.
* @param ex the {@link AsyncRequestTimeoutException} to be handled
* @param request current HTTP request
* @param response current HTTP response
* @param handler the executed handler, or {@code null} if none chosen
* at the time of the exception (for example, if multipart resolution failed)
* @return an empty ModelAndView indicating the exception was handled
* @throws IOException potentially thrown from {@link HttpServletResponse#sendError}
* @since 5.3.33
*/
protected ModelAndView handleAsyncRequestNotUsableException(AsyncRequestNotUsableException ex,
HttpServletRequest request, HttpServletResponse response, @Nullable Object handler) {
return new ModelAndView();
}
/**
* Handle an {@link ErrorResponse} exception.
* <p>The default implementation sets status and the headers of the response

View File

@ -108,6 +108,10 @@ class ResponseEntityExceptionHandlerTests {
.filter(method -> method.getName().startsWith("handle") && (method.getParameterCount() == 4))
.filter(method -> !method.getName().equals("handleErrorResponse"))
.map(method -> method.getParameterTypes()[0])
.filter(exceptionType -> {
String name = exceptionType.getSimpleName();
return !name.equals("AsyncRequestNotUsableException");
})
.forEach(exceptionType -> assertThat(annotation.value())
.as("@ExceptionHandler is missing declaration for " + exceptionType.getName())
.contains((Class<Exception>) exceptionType));