From 6cc512b51cfd0652fec982c5352905d5cd060e83 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 20 Jul 2012 10:54:58 -0400 Subject: [PATCH] Ensure async Callables are in sync with the call stack After this change each call stack level pushes and pops an async Callable to ensure the AsyncExecutionChain is in sync with the call stack. Before this change, a controller returning a "forward:" prefixed string caused the AsyncExecutionChain to contain a extra Callables that did not match the actual call stack. Issue: SPR-9611 --- .../support/OpenSessionInViewFilter.java | 9 +- .../support/OpenSessionInViewInterceptor.java | 2 +- .../support/OpenSessionInViewFilter.java | 6 +- .../support/OpenSessionInViewInterceptor.java | 2 +- .../support/OpenSessionInViewTests.java | 6 +- .../async/AbstractDelegatingCallable.java | 29 ++-- .../request/async/AsyncExecutionChain.java | 153 +++++++++--------- .../async/AsyncExecutionChainRunnable.java | 2 +- .../StaleAsyncRequestCheckingCallable.java | 2 +- .../async/StandardServletAsyncWebRequest.java | 6 +- .../filter/AbstractRequestLoggingFilter.java | 6 +- .../web/filter/OncePerRequestFilter.java | 6 +- .../web/filter/RequestContextFilter.java | 6 +- .../web/filter/ShallowEtagHeaderFilter.java | 6 +- .../async/AsyncExecutionChainTests.java | 35 ++-- ...taleAsyncRequestCheckingCallableTests.java | 2 +- .../StandardServletAsyncWebRequestTests.java | 2 +- .../web/servlet/AsyncHandlerInterceptor.java | 16 +- .../web/servlet/DispatcherServlet.java | 40 ++--- .../web/servlet/FrameworkServlet.java | 6 +- .../web/servlet/HandlerExecutionChain.java | 85 +++++----- .../WebRequestHandlerInterceptorAdapter.java | 2 +- .../AsyncMethodReturnValueHandler.java | 4 +- .../RequestMappingHandlerAdapter.java | 15 +- .../ServletInvocableHandlerMethod.java | 15 +- .../servlet/HandlerExecutionChainTests.java | 15 +- src/dist/changelog.txt | 1 + 27 files changed, 240 insertions(+), 239 deletions(-) diff --git a/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewFilter.java b/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewFilter.java index 30cdc55f53..6c43e78f95 100644 --- a/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewFilter.java +++ b/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewFilter.java @@ -187,7 +187,7 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { SessionHolder sessionHolder = new SessionHolder(session); TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); - chain.addDelegatingCallable(getAsyncCallable(request, sessionFactory, sessionHolder)); + chain.push(getAsyncCallable(request, sessionFactory, sessionHolder)); } } else { @@ -204,21 +204,20 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { try { filterChain.doFilter(request, response); } - finally { if (!participate) { if (isSingleSession()) { // single session mode SessionHolder sessionHolder = (SessionHolder) TransactionSynchronizationManager.unbindResource(sessionFactory); - if (chain.isAsyncStarted()) { + if (!chain.pop()) { return; } logger.debug("Closing single Hibernate Session in OpenSessionInViewFilter"); closeSession(sessionHolder.getSession(), sessionFactory); } else { - if (chain.isAsyncStarted()) { + if (!chain.pop()) { throw new IllegalStateException("Deferred close is not supported with async requests."); } // deferred close mode @@ -303,7 +302,7 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { public Object call() throws Exception { TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); try { - getNextCallable().call(); + getNext().call(); } finally { SessionHolder sessionHolder = diff --git a/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewInterceptor.java b/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewInterceptor.java index 5eccff677b..253442389e 100644 --- a/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewInterceptor.java +++ b/spring-orm/src/main/java/org/springframework/orm/hibernate3/support/OpenSessionInViewInterceptor.java @@ -181,7 +181,7 @@ public class OpenSessionInViewInterceptor extends HibernateAccessor implements A return new AbstractDelegatingCallable() { public Object call() throws Exception { TransactionSynchronizationManager.bindResource(getSessionFactory(), sessionHolder); - getNextCallable().call(); + getNext().call(); return null; } }; diff --git a/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewFilter.java b/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewFilter.java index c8dd223406..9f1f0b10de 100644 --- a/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewFilter.java +++ b/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewFilter.java @@ -119,7 +119,7 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { SessionHolder sessionHolder = new SessionHolder(session); TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); - chain.addDelegatingCallable(getAsyncCallable(request, sessionFactory, sessionHolder)); + chain.push(getAsyncCallable(request, sessionFactory, sessionHolder)); } try { @@ -130,7 +130,7 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { if (!participate) { SessionHolder sessionHolder = (SessionHolder) TransactionSynchronizationManager.unbindResource(sessionFactory); - if (chain.isAsyncStarted()) { + if (!chain.pop()) { return; } logger.debug("Closing Hibernate Session in OpenSessionInViewFilter"); @@ -198,7 +198,7 @@ public class OpenSessionInViewFilter extends OncePerRequestFilter { public Object call() throws Exception { TransactionSynchronizationManager.bindResource(sessionFactory, sessionHolder); try { - getNextCallable().call(); + getNext().call(); } finally { SessionHolder sessionHolder = diff --git a/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewInterceptor.java b/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewInterceptor.java index 66ceffaaf7..7bd43ac939 100644 --- a/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewInterceptor.java +++ b/spring-orm/src/main/java/org/springframework/orm/hibernate4/support/OpenSessionInViewInterceptor.java @@ -137,7 +137,7 @@ public class OpenSessionInViewInterceptor implements AsyncWebRequestInterceptor return new AbstractDelegatingCallable() { public Object call() throws Exception { TransactionSynchronizationManager.bindResource(getSessionFactory(), sessionHolder); - getNextCallable().call(); + getNext().call(); return null; } }; diff --git a/spring-orm/src/test/java/org/springframework/orm/hibernate3/support/OpenSessionInViewTests.java b/spring-orm/src/test/java/org/springframework/orm/hibernate3/support/OpenSessionInViewTests.java index 82779fc09b..b31747558d 100644 --- a/spring-orm/src/test/java/org/springframework/orm/hibernate3/support/OpenSessionInViewTests.java +++ b/spring-orm/src/test/java/org/springframework/orm/hibernate3/support/OpenSessionInViewTests.java @@ -176,7 +176,7 @@ public class OpenSessionInViewTests { verify(sf); verify(session); - asyncCallable.setNextCallable(new Callable() { + asyncCallable.setNext(new Callable() { public Object call() { return null; } @@ -484,7 +484,7 @@ public class OpenSessionInViewTests { verify(asyncWebRequest); chain.setTaskExecutor(new SyncTaskExecutor()); - chain.setCallable(new Callable() { + chain.setLastCallable(new Callable() { public Object call() { assertTrue(TransactionSynchronizationManager.hasResource(sf)); return null; @@ -503,7 +503,7 @@ public class OpenSessionInViewTests { replay(sf); replay(session); - chain.startCallableChainProcessing(); + chain.startCallableProcessing(); assertFalse(TransactionSynchronizationManager.hasResource(sf)); verify(sf); diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AbstractDelegatingCallable.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AbstractDelegatingCallable.java index 352616c722..265f722c0f 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/AbstractDelegatingCallable.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/AbstractDelegatingCallable.java @@ -19,20 +19,11 @@ package org.springframework.web.context.request.async; import java.util.concurrent.Callable; /** - * A base class for a Callable that can be used in a chain of Callable instances. - * - *

Typical use for async request processing scenarios involves: - *

    - *
  • Create an instance of this type and register it via - * {@link AsyncExecutionChain#addDelegatingCallable(AbstractDelegatingCallable)} - * (internally the nodes of the chain will be linked so no need to set up "next"). - *
  • Provide an implementation of {@link Callable#call()} that contains the - * logic needed to complete request processing outside the main processing thread. - *
  • In the implementation, delegate to the next Callable to obtain - * its result, e.g. ModelAndView, and then do some post-processing, e.g. view - * resolution. In some cases both pre- and post-processing might be - * appropriate, e.g. setting up {@link ThreadLocal} storage. - *
+ * A base class for a Callable used to form a chain of Callable instances. + * Instances of this class are typically registered via + * {@link AsyncExecutionChain#push(AbstractDelegatingCallable)} in which case + * there is no need to set the next Callable. Implementations can simply use + * {@link #getNext()} to delegate to the next Callable and assume it will be set. * * @author Rossen Stoyanchev * @since 3.2 @@ -43,12 +34,12 @@ public abstract class AbstractDelegatingCallable implements Callable { private Callable next; - public void setNextCallable(Callable nextCallable) { - this.next = nextCallable; - } - - protected Callable getNextCallable() { + protected Callable getNext() { return this.next; } + public void setNext(Callable callable) { + this.next = callable; + } + } diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChain.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChain.java index 887a40e7dc..48d0a165cc 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChain.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChain.java @@ -16,8 +16,8 @@ package org.springframework.web.context.request.async; -import java.util.ArrayList; -import java.util.List; +import java.util.ArrayDeque; +import java.util.Deque; import java.util.concurrent.Callable; import javax.servlet.ServletRequest; @@ -31,18 +31,20 @@ import org.springframework.web.context.request.async.DeferredResult.DeferredResu /** * The central class for managing async request processing, mainly intended as - * an SPI and typically not by non-framework classes. + * an SPI and not typically used directly by application classes. * - *

An async execution chain consists of a sequence of Callable instances and - * represents the work required to complete request processing in a separate - * thread. To construct the chain, each layer in the call stack of a normal - * request (e.g. filter, servlet) may contribute an - * {@link AbstractDelegatingCallable} when a request is being processed. - * For example the DispatcherServlet might contribute a Callable that - * performs view resolution while a HandlerAdapter might contribute a Callable - * that returns the ModelAndView, etc. The last Callable is the one that - * actually produces an application-specific value, for example the Callable - * returned by an {@code @RequestMapping} method. + *

An async execution chain consists of a sequence of Callable instances that + * represent the work required to complete request processing in a separate thread. + * To construct the chain, each level of the call stack pushes an + * {@link AbstractDelegatingCallable} during the course of a normal request and + * pops (removes) it on the way out. If async processing has not started, the pop + * operation succeeds and the processing continues as normal, or otherwise if async + * processing has begun, the main processing thread must be exited. + * + *

For example the DispatcherServlet might contribute a Callable that completes + * view resolution or the HandlerAdapter might contribute a Callable that prepares a + * ModelAndView while the last Callable in the chain is usually associated with the + * application, e.g. the return value of an {@code @RequestMapping} method. * * @author Rossen Stoyanchev * @since 3.2 @@ -51,13 +53,13 @@ public final class AsyncExecutionChain { public static final String CALLABLE_CHAIN_ATTRIBUTE = AsyncExecutionChain.class.getName() + ".CALLABLE_CHAIN"; - private final List delegatingCallables = new ArrayList(); + private final Deque callables = new ArrayDeque(); - private Callable callable; + private Callable lastCallable; private AsyncWebRequest asyncWebRequest; - private AsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("AsyncExecutionChain"); + private AsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor("MvcAsync"); /** * Private constructor @@ -68,7 +70,7 @@ public final class AsyncExecutionChain { /** * Obtain the AsyncExecutionChain for the current request. - * Or if not found, create an instance and associate it with the request. + * Or if not found, create it and associate it with the request. */ public static AsyncExecutionChain getForCurrentRequest(ServletRequest request) { AsyncExecutionChain chain = (AsyncExecutionChain) request.getAttribute(CALLABLE_CHAIN_ATTRIBUTE); @@ -81,7 +83,7 @@ public final class AsyncExecutionChain { /** * Obtain the AsyncExecutionChain for the current request. - * Or if not found, create an instance and associate it with the request. + * Or if not found, create it and associate it with the request. */ public static AsyncExecutionChain getForCurrentRequest(WebRequest request) { int scope = RequestAttributes.SCOPE_REQUEST; @@ -94,105 +96,106 @@ public final class AsyncExecutionChain { } /** - * Provide an instance of an AsyncWebRequest. - * This property must be set before async request processing can begin. + * Provide an instance of an AsyncWebRequest -- required for async processing. */ public void setAsyncWebRequest(AsyncWebRequest asyncRequest) { + Assert.state(!isAsyncStarted(), "Cannot set AsyncWebRequest after the start of async processing."); this.asyncWebRequest = asyncRequest; } /** - * Provide an AsyncTaskExecutor to use when - * {@link #startCallableChainProcessing()} is invoked, for example when a - * controller method returns a Callable. - *

By default a {@link SimpleAsyncTaskExecutor} instance is used. + * Provide an AsyncTaskExecutor for use with {@link #startCallableProcessing()}. + *

By default a {@link SimpleAsyncTaskExecutor} instance is used. Applications are + * advised to provide a TaskExecutor configured for production use. + * @see org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter#setAsyncTaskExecutor */ public void setTaskExecutor(AsyncTaskExecutor taskExecutor) { this.taskExecutor = taskExecutor; } /** - * Whether async request processing has started through one of: - *

    - *
  • {@link #startCallableChainProcessing()} - *
  • {@link #startDeferredResultProcessing(DeferredResult)} - *
+ * Push an async Callable for the current stack level. This method should be + * invoked before delegating to the next level of the stack where async + * processing may start. + */ + public void push(AbstractDelegatingCallable callable) { + Assert.notNull(callable, "Async Callable is required"); + this.callables.addFirst(callable); + } + + /** + * Pop the Callable of the current stack level. Ensure this method is invoked + * after delegation to the next level of the stack where async processing may + * start. The pop operation succeeds if async processing did not start. + * @return {@code true} if the Callable was removed, or {@code false} + * otherwise (i.e. async started). + */ + public boolean pop() { + if (isAsyncStarted()) { + return false; + } + else { + this.callables.removeFirst(); + return true; + } + } + + /** + * Whether async request processing has started. */ public boolean isAsyncStarted() { return ((this.asyncWebRequest != null) && this.asyncWebRequest.isAsyncStarted()); } /** - * Add a Callable with logic required to complete request processing in a - * separate thread. See {@link AbstractDelegatingCallable} for details. + * Set the last Callable, e.g. the one returned by the controller. */ - public void addDelegatingCallable(AbstractDelegatingCallable callable) { + public AsyncExecutionChain setLastCallable(Callable callable) { Assert.notNull(callable, "Callable required"); - this.delegatingCallables.add(callable); - } - - /** - * Add the last Callable, for example the one returned by the controller. - * This property must be set prior to invoking - * {@link #startCallableChainProcessing()}. - */ - public AsyncExecutionChain setCallable(Callable callable) { - Assert.notNull(callable, "Callable required"); - this.callable = callable; + this.lastCallable = callable; return this; } /** - * Start the async execution chain by submitting an - * {@link AsyncExecutionChainRunnable} instance to the TaskExecutor provided via - * {@link #setTaskExecutor(AsyncTaskExecutor)} and returning immediately. - * @see AsyncExecutionChainRunnable + * Start async processing and execute the async chain with an AsyncTaskExecutor. + * This method returns immediately. */ - public void startCallableChainProcessing() { - startAsync(); + public void startCallableProcessing() { + Assert.state(this.asyncWebRequest != null, "AsyncWebRequest was not set"); + this.asyncWebRequest.startAsync(); this.taskExecutor.execute(new AsyncExecutionChainRunnable(this.asyncWebRequest, buildChain())); } - private void startAsync() { - Assert.state(this.asyncWebRequest != null, "An AsyncWebRequest is required to start async processing"); - this.asyncWebRequest.startAsync(); - } - private Callable buildChain() { - Assert.state(this.callable != null, "The last callable is required to build the async chain"); - this.delegatingCallables.add(new StaleAsyncRequestCheckingCallable(asyncWebRequest)); - Callable result = this.callable; - for (int i = this.delegatingCallables.size() - 1; i >= 0; i--) { - AbstractDelegatingCallable callable = this.delegatingCallables.get(i); - callable.setNextCallable(result); - result = callable; + Assert.state(this.lastCallable != null, "The last Callable was not set"); + AbstractDelegatingCallable head = new StaleAsyncRequestCheckingCallable(this.asyncWebRequest); + head.setNext(this.lastCallable); + for (AbstractDelegatingCallable callable : this.callables) { + callable.setNext(head); + head = callable; } - return result; + return head; } /** - * Mark the start of async request processing accepting the provided - * DeferredResult and initializing it such that if - * {@link DeferredResult#set(Object)} is called (from another thread), - * the set Object value will be processed with the execution chain by - * invoking {@link AsyncExecutionChainRunnable}. - *

The resulting processing from this method is identical to - * {@link #startCallableChainProcessing()}. The main difference is in - * the threading model, i.e. whether a TaskExecutor is used. - * @see DeferredResult + * Start async processing and initialize the given DeferredResult so when + * its value is set, the async chain is executed with an AsyncTaskExecutor. */ public void startDeferredResultProcessing(final DeferredResult deferredResult) { Assert.notNull(deferredResult, "DeferredResult is required"); - startAsync(); + Assert.state(this.asyncWebRequest != null, "AsyncWebRequest was not set"); + this.asyncWebRequest.startAsync(); + deferredResult.init(new DeferredResultHandler() { public void handle(Object result) { if (asyncWebRequest.isAsyncCompleted()) { - throw new StaleAsyncWebRequestException("Async request processing already completed"); + throw new StaleAsyncWebRequestException("Too late to set DeferredResult: " + result); } - setCallable(new PassThroughCallable(result)); - new AsyncExecutionChainRunnable(asyncWebRequest, buildChain()).run(); + setLastCallable(new PassThroughCallable(result)); + taskExecutor.execute(new AsyncExecutionChainRunnable(asyncWebRequest, buildChain())); } }); + this.asyncWebRequest.setTimeoutHandler(deferredResult.getTimeoutHandler()); } diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChainRunnable.java b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChainRunnable.java index 03e86b2426..14d4d4b1b3 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChainRunnable.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/AsyncExecutionChainRunnable.java @@ -30,7 +30,7 @@ import org.springframework.util.Assert; * @author Rossen Stoyanchev * @since 3.2 * - * @see AsyncExecutionChain#startCallableChainProcessing() + * @see AsyncExecutionChain#startCallableProcessing() * @see AsyncExecutionChain#startDeferredResultProcessing(DeferredResult) */ public class AsyncExecutionChainRunnable implements Runnable { diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallable.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallable.java index ee229349bd..42a6a4d5ec 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallable.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallable.java @@ -39,7 +39,7 @@ public class StaleAsyncRequestCheckingCallable extends AbstractDelegatingCallabl } public Object call() throws Exception { - Object result = getNextCallable().call(); + Object result = getNext().call(); if (this.asyncWebRequest.isAsyncCompleted()) { throw new StaleAsyncWebRequestException( "Async request no longer available due to a timeout or a (client) error"); diff --git a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java index 3558ae6044..3b2c9de7c2 100644 --- a/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java +++ b/spring-web/src/main/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequest.java @@ -76,8 +76,8 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements "in async request processing. This is done in Java code using the Servlet API " + "or by adding \"true\" to servlet and " + "filter declarations in web.xml."); - assertNotStale(); Assert.state(!isAsyncStarted(), "Async processing already started"); + Assert.state(!isAsyncCompleted(), "Cannot use async request that has completed"); this.asyncContext = getRequest().startAsync(getRequest(), getResponse()); this.asyncContext.addListener(this); if (this.timeout != null) { @@ -108,10 +108,6 @@ public class StandardServletAsyncWebRequest extends ServletWebRequest implements } } - private void assertNotStale() { - Assert.state(!isAsyncCompleted(), "Cannot use async request after completion"); - } - // --------------------------------------------------------------------- // Implementation of AsyncListener methods // --------------------------------------------------------------------- diff --git a/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java b/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java index 0f6ef8c341..b1ad7a9e09 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/AbstractRequestLoggingFilter.java @@ -196,13 +196,13 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter beforeRequest(request, getBeforeMessage(request)); AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.addDelegatingCallable(getAsyncCallable(request)); + chain.push(getAsyncCallable(request)); try { filterChain.doFilter(request, response); } finally { - if (chain.isAsyncStarted()) { + if (!chain.pop()) { return; } afterRequest(request, getAfterMessage(request)); @@ -296,7 +296,7 @@ public abstract class AbstractRequestLoggingFilter extends OncePerRequestFilter private AbstractDelegatingCallable getAsyncCallable(final HttpServletRequest request) { return new AbstractDelegatingCallable() { public Object call() throws Exception { - getNextCallable().call(); + getNext().call(); afterRequest(request, getAfterMessage(request)); return null; } diff --git a/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java b/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java index 3b8378d771..29cc896f13 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/OncePerRequestFilter.java @@ -75,7 +75,7 @@ public abstract class OncePerRequestFilter extends GenericFilterBean { } else { AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.addDelegatingCallable(getAsyncCallable(request, alreadyFilteredAttributeName)); + chain.push(getAsyncCallable(request, alreadyFilteredAttributeName)); // Do invoke this filter... request.setAttribute(alreadyFilteredAttributeName, Boolean.TRUE); @@ -83,7 +83,7 @@ public abstract class OncePerRequestFilter extends GenericFilterBean { doFilterInternal(httpRequest, httpResponse, filterChain); } finally { - if (chain.isAsyncStarted()) { + if (!chain.pop()) { return; } // Remove the "already filtered" request attribute for this request. @@ -129,7 +129,7 @@ public abstract class OncePerRequestFilter extends GenericFilterBean { return new AbstractDelegatingCallable() { public Object call() throws Exception { - getNextCallable().call(); + getNext().call(); request.removeAttribute(alreadyFilteredAttributeName); return null; } diff --git a/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java b/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java index 3ce5e4e18a..3c640fa4e1 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/RequestContextFilter.java @@ -81,14 +81,14 @@ public class RequestContextFilter extends OncePerRequestFilter { initContextHolders(request, attributes); AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.addDelegatingCallable(getChainedCallable(request, attributes)); + chain.push(getChainedCallable(request, attributes)); try { filterChain.doFilter(request, response); } finally { resetContextHolders(); - if (chain.isAsyncStarted()) { + if (!chain.pop()) { return; } attributes.requestCompleted(); @@ -121,7 +121,7 @@ public class RequestContextFilter extends OncePerRequestFilter { public Object call() throws Exception { initContextHolders(request, requestAttributes); try { - getNextCallable().call(); + getNext().call(); } finally { resetContextHolders(); diff --git a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java index e88948cce3..79d704b7d2 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ShallowEtagHeaderFilter.java @@ -61,11 +61,11 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { ShallowEtagResponseWrapper responseWrapper = new ShallowEtagResponseWrapper(response); AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.addDelegatingCallable(getAsyncCallable(request, response, responseWrapper)); + chain.push(getAsyncCallable(request, response, responseWrapper)); filterChain.doFilter(request, responseWrapper); - if (chain.isAsyncStarted()) { + if (!chain.pop()) { return; } @@ -80,7 +80,7 @@ public class ShallowEtagHeaderFilter extends OncePerRequestFilter { return new AbstractDelegatingCallable() { public Object call() throws Exception { - getNextCallable().call(); + getNext().call(); updateResponse(request, response, responseWrapper); return null; } diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncExecutionChainTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncExecutionChainTests.java index 2e7d83877d..ebbb404810 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncExecutionChainTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/AsyncExecutionChainTests.java @@ -63,7 +63,7 @@ public class AsyncExecutionChainTests { this.chain = AsyncExecutionChain.getForCurrentRequest(this.request); this.chain.setTaskExecutor(new SyncTaskExecutor()); this.chain.setAsyncWebRequest(this.asyncWebRequest); - this.chain.addDelegatingCallable(this.resultSavingCallable); + this.chain.push(this.resultSavingCallable); } @Test @@ -79,29 +79,32 @@ public class AsyncExecutionChainTests { this.asyncWebRequest.startAsync(); assertTrue(this.chain.isAsyncStarted()); + } + @Test(expected=IllegalStateException.class) + public void setAsyncWebRequestAfterAsyncStarted() { + this.asyncWebRequest.startAsync(); this.chain.setAsyncWebRequest(null); - assertFalse(this.chain.isAsyncStarted()); } @Test public void startCallableChainProcessing() throws Exception { - this.chain.addDelegatingCallable(new IntegerIncrementingCallable()); - this.chain.addDelegatingCallable(new IntegerIncrementingCallable()); - this.chain.setCallable(new Callable() { + this.chain.push(new IntegerIncrementingCallable()); + this.chain.push(new IntegerIncrementingCallable()); + this.chain.setLastCallable(new Callable() { public Object call() throws Exception { return 1; } }); - this.chain.startCallableChainProcessing(); + this.chain.startCallableProcessing(); assertEquals(3, this.resultSavingCallable.result); } @Test public void startCallableChainProcessing_staleRequest() { - this.chain.setCallable(new Callable() { + this.chain.setLastCallable(new Callable() { public Object call() throws Exception { return 1; } @@ -109,7 +112,7 @@ public class AsyncExecutionChainTests { this.asyncWebRequest.startAsync(); this.asyncWebRequest.complete(); - this.chain.startCallableChainProcessing(); + this.chain.startCallableProcessing(); Exception ex = this.resultSavingCallable.exception; assertNotNull(ex); @@ -119,11 +122,11 @@ public class AsyncExecutionChainTests { @Test public void startCallableChainProcessing_requiredCallable() { try { - this.chain.startCallableChainProcessing(); + this.chain.startCallableProcessing(); fail("Expected exception"); } catch (IllegalStateException ex) { - assertThat(ex.getMessage(), containsString("last callable is required")); + assertEquals(ex.getMessage(), "The last Callable was not set"); } } @@ -131,18 +134,18 @@ public class AsyncExecutionChainTests { public void startCallableChainProcessing_requiredAsyncWebRequest() { this.chain.setAsyncWebRequest(null); try { - this.chain.startCallableChainProcessing(); + this.chain.startCallableProcessing(); fail("Expected exception"); } catch (IllegalStateException ex) { - assertThat(ex.getMessage(), containsString("AsyncWebRequest is required")); + assertEquals(ex.getMessage(), "AsyncWebRequest was not set"); } } @Test public void startDeferredResultProcessing() throws Exception { - this.chain.addDelegatingCallable(new IntegerIncrementingCallable()); - this.chain.addDelegatingCallable(new IntegerIncrementingCallable()); + this.chain.push(new IntegerIncrementingCallable()); + this.chain.push(new IntegerIncrementingCallable()); DeferredResult deferredResult = new DeferredResult(); this.chain.startDeferredResultProcessing(deferredResult); @@ -228,7 +231,7 @@ public class AsyncExecutionChainTests { public Object call() throws Exception { try { - this.result = getNextCallable().call(); + this.result = getNext().call(); } catch (Exception ex) { this.exception = ex; @@ -241,7 +244,7 @@ public class AsyncExecutionChainTests { private static class IntegerIncrementingCallable extends AbstractDelegatingCallable { public Object call() throws Exception { - return ((Integer) getNextCallable().call() + 1); + return ((Integer) getNext().call() + 1); } } diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallableTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallableTests.java index c676e5a47d..abd5c68bd2 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallableTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/StaleAsyncRequestCheckingCallableTests.java @@ -39,7 +39,7 @@ public class StaleAsyncRequestCheckingCallableTests { public void setUp() { this.asyncWebRequest = EasyMock.createMock(AsyncWebRequest.class); this.callable = new StaleAsyncRequestCheckingCallable(asyncWebRequest); - this.callable.setNextCallable(new Callable() { + this.callable.setNext(new Callable() { public Object call() throws Exception { return 1; } diff --git a/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java b/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java index f7e7594218..75263147e9 100644 --- a/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java +++ b/spring-web/src/test/java/org/springframework/web/context/request/async/StandardServletAsyncWebRequestTests.java @@ -141,7 +141,7 @@ public class StandardServletAsyncWebRequestTests { fail("expected exception"); } catch (IllegalStateException ex) { - assertEquals("Cannot use async request after completion", ex.getMessage()); + assertEquals("Cannot use async request that has completed", ex.getMessage()); } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/AsyncHandlerInterceptor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/AsyncHandlerInterceptor.java index df341c6489..167059e500 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/AsyncHandlerInterceptor.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/AsyncHandlerInterceptor.java @@ -64,19 +64,17 @@ public interface AsyncHandlerInterceptor extends HandlerInterceptor { AbstractDelegatingCallable getAsyncCallable(HttpServletRequest request, HttpServletResponse response, Object handler); /** - * Invoked after the execution of a handler if the handler started + * Invoked after the execution of a handler but only if the handler started * async processing instead of handling the request. Effectively this method - * is invoked on the way out of the main processing thread instead of - * {@link #postHandle(WebRequest, org.springframework.ui.ModelMap)}. The - * postHandle method is invoked after the request is handled - * in the async thread. - *

Implementations of this method can ensure ThreadLocal attributes bound - * to the main thread are cleared and also prepare for binding them to the - * async thread. + * is invoked instead of {@link #postHandle(WebRequest, org.springframework.ui.ModelMap)} + * on the way out of the main processing thread allowing implementations + * to ensure ThreadLocal attributes are cleared. The postHandle + * invocation is effectively delayed until after async processing when the + * request has actually been handled. * @param request current HTTP request * @param response current HTTP response * @param handler chosen handler to execute, for type and/or instance examination */ - void postHandleAsyncStarted(HttpServletRequest request, HttpServletResponse response, Object handler); + void postHandleAfterAsyncStarted(HttpServletRequest request, HttpServletResponse response, Object handler); } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java index 7cc2edc5ba..7119add643 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/DispatcherServlet.java @@ -817,8 +817,6 @@ public class DispatcherServlet extends FrameworkServlet { @Override protected void doService(HttpServletRequest request, HttpServletResponse response) throws Exception { - AsyncExecutionChain asyncChain = AsyncExecutionChain.getForCurrentRequest(request); - if (logger.isDebugEnabled()) { String requestUri = urlPathHelper.getRequestUri(request); logger.debug("DispatcherServlet with name '" + getServletName() + "' processing " + request.getMethod() + @@ -853,13 +851,14 @@ public class DispatcherServlet extends FrameworkServlet { request.setAttribute(OUTPUT_FLASH_MAP_ATTRIBUTE, new FlashMap()); request.setAttribute(FLASH_MAP_MANAGER_ATTRIBUTE, this.flashMapManager); - asyncChain.addDelegatingCallable(getServiceAsyncCallable(request, attributesSnapshot)); + AsyncExecutionChain asyncChain = AsyncExecutionChain.getForCurrentRequest(request); + asyncChain.push(getServiceAsyncCallable(request, attributesSnapshot)); try { doDispatch(request, response); } finally { - if (asyncChain.isAsyncStarted()) { + if (!asyncChain.pop()) { return; } // Restore the original attribute snapshot, in case of an include. @@ -881,7 +880,7 @@ public class DispatcherServlet extends FrameworkServlet { logger.debug("Resuming asynchronous processing of " + request.getMethod() + " request for [" + urlPathHelper.getRequestUri(request) + "]"); } - getNextCallable().call(); + getNext().call(); if (attributesSnapshot != null) { restoreAttributesAfterInclude(request, attributesSnapshot); } @@ -904,7 +903,9 @@ public class DispatcherServlet extends FrameworkServlet { protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception { HttpServletRequest processedRequest = request; HandlerExecutionChain mappedHandler = null; + AsyncExecutionChain asyncChain = AsyncExecutionChain.getForCurrentRequest(request); + boolean asyncStarted = false; try { ModelAndView mv = null; @@ -941,22 +942,23 @@ public class DispatcherServlet extends FrameworkServlet { return; } - mappedHandler.addDelegatingCallables(processedRequest, response); + mappedHandler.pushInterceptorCallables(processedRequest, response); + asyncChain.push(getDispatchAsyncCallable(mappedHandler, request, response, processedRequest)); - asyncChain.addDelegatingCallable( - getDispatchAsyncCallable(mappedHandler, request, response, processedRequest)); - - // Actually invoke the handler. - mv = ha.handle(processedRequest, response, mappedHandler.getHandler()); - - if (asyncChain.isAsyncStarted()) { - mappedHandler.applyPostHandleAsyncStarted(processedRequest, response); - logger.debug("Exiting request thread and leaving the response open"); - return; + try { + // Actually invoke the handler. + mv = ha.handle(processedRequest, response, mappedHandler.getHandler()); + } + finally { + asyncStarted = !asyncChain.pop(); + mappedHandler.popInterceptorCallables(processedRequest, response, asyncStarted); + if (asyncStarted) { + logger.debug("Exiting request thread and leaving the response open"); + return; + } } applyDefaultViewName(request, mv); - mappedHandler.applyPostHandle(processedRequest, response, mv); } catch (Exception ex) { @@ -971,7 +973,7 @@ public class DispatcherServlet extends FrameworkServlet { triggerAfterCompletionWithError(processedRequest, response, mappedHandler, err); } finally { - if (asyncChain.isAsyncStarted()) { + if (asyncStarted) { return; } // Clean up any resources used by a multipart request. @@ -1044,7 +1046,7 @@ public class DispatcherServlet extends FrameworkServlet { ModelAndView mv = null; Exception dispatchException = null; try { - mv = (ModelAndView) getNextCallable().call(); + mv = (ModelAndView) getNext().call(); applyDefaultViewName(processedRequest, mv); mappedHandler.applyPostHandle(request, response, mv); } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java index c50b827608..b9258a056d 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/FrameworkServlet.java @@ -906,7 +906,7 @@ public abstract class FrameworkServlet extends HttpServletBean { initContextHolders(request, localeContext, requestAttributes); AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.addDelegatingCallable(getAsyncCallable(startTime, request, response, + chain.push(getAsyncCallable(startTime, request, response, previousLocaleContext, previousAttributes, localeContext, requestAttributes)); try { @@ -917,7 +917,7 @@ public abstract class FrameworkServlet extends HttpServletBean { } finally { resetContextHolders(request, previousLocaleContext, previousAttributes); - if (chain.isAsyncStarted()) { + if (!chain.pop()) { return; } finalizeProcessing(startTime, request, response, requestAttributes, failureCause); @@ -1018,7 +1018,7 @@ public abstract class FrameworkServlet extends HttpServletBean { initContextHolders(request, localeContext, requestAttributes); Throwable unhandledFailure = null; try { - getNextCallable().call(); + getNext().call(); } catch (Throwable t) { unhandledFailure = t; diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java index 68c8ff96c6..6509a477a3 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/HandlerExecutionChain.java @@ -49,6 +49,8 @@ public class HandlerExecutionChain { private int interceptorIndex = -1; + private int pushedCallableCount; + /** * Create a new HandlerExecutionChain. * @param handler the handler object to execute @@ -124,9 +126,7 @@ public class HandlerExecutionChain { * next interceptor or the handler itself. Else, DispatcherServlet assumes * that this interceptor has already dealt with the response itself. */ - boolean applyPreHandle(HttpServletRequest request, HttpServletResponse response) - throws Exception { - + boolean applyPreHandle(HttpServletRequest request, HttpServletResponse response) throws Exception { if (getInterceptors() != null) { for (int i = 0; i < getInterceptors().length; i++) { HandlerInterceptor interceptor = getInterceptors()[i]; @@ -140,12 +140,31 @@ public class HandlerExecutionChain { return true; } + void pushInterceptorCallables(HttpServletRequest request, HttpServletResponse response) { + if (getInterceptors() == null) { + return; + } + for (HandlerInterceptor interceptor : getInterceptors()) { + if (interceptor instanceof AsyncHandlerInterceptor) { + try { + AsyncHandlerInterceptor asyncInterceptor = (AsyncHandlerInterceptor) interceptor; + AbstractDelegatingCallable callable = asyncInterceptor.getAsyncCallable(request, response, this.handler); + if (callable != null) { + AsyncExecutionChain.getForCurrentRequest(request).push(callable); + this.pushedCallableCount++; + } + } + catch (Throwable ex) { + logger.error("HandlerInterceptor failed to return an async Callable", ex); + } + } + } + } + /** * Apply postHandle methods of registered interceptors. */ - void applyPostHandle(HttpServletRequest request, HttpServletResponse response, ModelAndView mv) - throws Exception { - + void applyPostHandle(HttpServletRequest request, HttpServletResponse response, ModelAndView mv) throws Exception { if (getInterceptors() == null) { return; } @@ -156,50 +175,28 @@ public class HandlerExecutionChain { } /** - * Add delegating, async Callable instances to the {@link AsyncExecutionChain} - * for use in case of asynchronous request processing. + * Remove pushed callables and apply postHandleAsyncStarted callbacks. */ - void addDelegatingCallables(HttpServletRequest request, HttpServletResponse response) - throws Exception { + void popInterceptorCallables(HttpServletRequest request, HttpServletResponse response, + boolean asyncStarted) throws Exception { if (getInterceptors() == null) { return; } - for (int i = getInterceptors().length - 1; i >= 0; i--) { - HandlerInterceptor interceptor = getInterceptors()[i]; - if (interceptor instanceof AsyncHandlerInterceptor) { - try { - AsyncHandlerInterceptor asyncInterceptor = (AsyncHandlerInterceptor) interceptor; - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - AbstractDelegatingCallable callable = asyncInterceptor.getAsyncCallable(request, response, this.handler); - if (callable != null) { - chain.addDelegatingCallable(callable); + AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); + for ( ; this.pushedCallableCount > 0; this.pushedCallableCount--) { + chain.pop(); + } + if (asyncStarted) { + for (int i = getInterceptors().length - 1; i >= 0; i--) { + HandlerInterceptor interceptor = getInterceptors()[i]; + if (interceptor instanceof AsyncHandlerInterceptor) { + try { + ((AsyncHandlerInterceptor) interceptor).postHandleAfterAsyncStarted(request, response, this.handler); + } + catch (Throwable ex) { + logger.error("HandlerInterceptor.postHandleAsyncStarted(..) failed", ex); } - } - catch (Throwable ex) { - logger.error("HandlerInterceptor.addAsyncCallables threw exception", ex); - } - } - } - } - - /** - * Trigger postHandleAsyncStarted callbacks on the mapped HandlerInterceptors. - */ - void applyPostHandleAsyncStarted(HttpServletRequest request, HttpServletResponse response) - throws Exception { - - if (getInterceptors() == null) { - return; - } - for (int i = getInterceptors().length - 1; i >= 0; i--) { - HandlerInterceptor interceptor = getInterceptors()[i]; - if (interceptor instanceof AsyncHandlerInterceptor) { - try { - ((AsyncHandlerInterceptor) interceptor).postHandleAsyncStarted(request, response, this.handler); - } - catch (Throwable ex) { - logger.error("HandlerInterceptor.postHandleAsyncStarted threw exception", ex); } } } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/WebRequestHandlerInterceptorAdapter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/WebRequestHandlerInterceptorAdapter.java index 9585c63280..250baedcf0 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/WebRequestHandlerInterceptorAdapter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/WebRequestHandlerInterceptorAdapter.java @@ -68,7 +68,7 @@ public class WebRequestHandlerInterceptorAdapter implements AsyncHandlerIntercep return null; } - public void postHandleAsyncStarted(HttpServletRequest request, HttpServletResponse response, Object handler) { + public void postHandleAfterAsyncStarted(HttpServletRequest request, HttpServletResponse response, Object handler) { if (this.requestInterceptor instanceof AsyncWebRequestInterceptor) { AsyncWebRequestInterceptor asyncInterceptor = (AsyncWebRequestInterceptor) this.requestInterceptor; DispatcherServletWebRequest webRequest = new DispatcherServletWebRequest(request, response); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AsyncMethodReturnValueHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AsyncMethodReturnValueHandler.java index d7cd9d95d8..d07ff2e3f6 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AsyncMethodReturnValueHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AsyncMethodReturnValueHandler.java @@ -59,8 +59,8 @@ public class AsyncMethodReturnValueHandler implements HandlerMethodReturnValueHa AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(servletRequest); if (Callable.class.isAssignableFrom(paramType)) { - chain.setCallable((Callable) returnValue); - chain.startCallableChainProcessing(); + chain.setLastCallable((Callable) returnValue); + chain.startCallableProcessing(); } else if (DeferredResult.class.isAssignableFrom(paramType)) { chain.startDeferredResultProcessing((DeferredResult) returnValue); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java index 4f2d7e1fb3..e113f147e0 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/RequestMappingHandlerAdapter.java @@ -653,14 +653,17 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter i mavContainer.setIgnoreDefaultModelOnRedirect(this.ignoreDefaultModelOnRedirect); AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(request); - chain.addDelegatingCallable(getAsyncCallable(mavContainer, modelFactory, webRequest)); chain.setAsyncWebRequest(createAsyncWebRequest(request, response)); chain.setTaskExecutor(this.taskExecutor); + chain.push(getAsyncCallable(mavContainer, modelFactory, webRequest)); - requestMappingMethod.invokeAndHandle(webRequest, mavContainer); - - if (chain.isAsyncStarted()) { - return null; + try { + requestMappingMethod.invokeAndHandle(webRequest, mavContainer); + } + finally { + if (!chain.pop()) { + return null; + } } return getModelAndView(mavContainer, modelFactory, webRequest); @@ -758,7 +761,7 @@ public class RequestMappingHandlerAdapter extends AbstractHandlerMethodAdapter i return new AbstractDelegatingCallable() { public Object call() throws Exception { - getNextCallable().call(); + getNext().call(); return getModelAndView(mavContainer, modelFactory, webRequest); } }; diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ServletInvocableHandlerMethod.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ServletInvocableHandlerMethod.java index 32d4920750..def48a6ea1 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ServletInvocableHandlerMethod.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ServletInvocableHandlerMethod.java @@ -91,9 +91,6 @@ public class ServletInvocableHandlerMethod extends InvocableHandlerMethod { public final void invokeAndHandle(ServletWebRequest webRequest, ModelAndViewContainer mavContainer, Object... providedArgs) throws Exception { - AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(webRequest.getRequest()); - chain.addDelegatingCallable(geAsyncCallable(webRequest, mavContainer, providedArgs)); - Object returnValue = invokeForRequest(webRequest, mavContainer, providedArgs); setResponseStatus(webRequest); @@ -111,15 +108,21 @@ public class ServletInvocableHandlerMethod extends InvocableHandlerMethod { mavContainer.setRequestHandled(false); + AsyncExecutionChain chain = AsyncExecutionChain.getForCurrentRequest(webRequest.getRequest()); + chain.push(geAsyncCallable(webRequest, mavContainer, providedArgs)); + try { this.returnValueHandlers.handleReturnValue(returnValue, getReturnValueType(returnValue), mavContainer, webRequest); - - } catch (Exception ex) { + } + catch (Exception ex) { if (logger.isTraceEnabled()) { logger.trace(getReturnValueHandlingErrorMessage("Error handling return value", returnValue), ex); } throw ex; } + finally { + chain.pop(); + } } /** @@ -131,7 +134,7 @@ public class ServletInvocableHandlerMethod extends InvocableHandlerMethod { return new AbstractDelegatingCallable() { public Object call() throws Exception { mavContainer.setRequestHandled(false); - new CallableHandlerMethod(getNextCallable()).invokeAndHandle(webRequest, mavContainer, providedArgs); + new CallableHandlerMethod(getNext()).invokeAndHandle(webRequest, mavContainer, providedArgs); return null; } }; diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/HandlerExecutionChainTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/HandlerExecutionChainTests.java index da53c2b22d..33463473bc 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/HandlerExecutionChainTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/HandlerExecutionChainTests.java @@ -74,6 +74,10 @@ public class HandlerExecutionChainTests { expect(this.interceptor2.preHandle(this.request, this.response, this.handler)).andReturn(true); expect(this.interceptor3.preHandle(this.request, this.response, this.handler)).andReturn(true); + expect(this.interceptor1.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); + expect(this.interceptor2.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); + expect(this.interceptor3.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); + this.interceptor1.postHandle(this.request, this.response, this.handler, mav); this.interceptor2.postHandle(this.request, this.response, this.handler, mav); this.interceptor3.postHandle(this.request, this.response, this.handler, mav); @@ -85,6 +89,7 @@ public class HandlerExecutionChainTests { replay(this.interceptor1, this.interceptor2, this.interceptor3); this.chain.applyPreHandle(request, response); + this.chain.pushInterceptorCallables(request, response); this.chain.applyPostHandle(request, response, mav); this.chain.triggerAfterCompletion(this.request, this.response, null); @@ -103,9 +108,9 @@ public class HandlerExecutionChainTests { expect(this.interceptor2.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); expect(this.interceptor3.getAsyncCallable(request, response, this.handler)).andReturn(new TestAsyncCallable()); - this.interceptor1.postHandleAsyncStarted(request, response, this.handler); - this.interceptor2.postHandleAsyncStarted(request, response, this.handler); - this.interceptor3.postHandleAsyncStarted(request, response, this.handler); + this.interceptor1.postHandleAfterAsyncStarted(request, response, this.handler); + this.interceptor2.postHandleAfterAsyncStarted(request, response, this.handler); + this.interceptor3.postHandleAfterAsyncStarted(request, response, this.handler); this.interceptor1.postHandle(this.request, this.response, this.handler, mav); this.interceptor2.postHandle(this.request, this.response, this.handler, mav); @@ -118,8 +123,8 @@ public class HandlerExecutionChainTests { replay(this.interceptor1, this.interceptor2, this.interceptor3); this.chain.applyPreHandle(request, response); - this.chain.addDelegatingCallables(request, response); - this.chain.applyPostHandleAsyncStarted(request, response); + this.chain.pushInterceptorCallables(request, response); + this.chain.popInterceptorCallables(request, response, true); this.chain.applyPostHandle(request, response, mav); this.chain.triggerAfterCompletion(this.request, this.response, null); diff --git a/src/dist/changelog.txt b/src/dist/changelog.txt index e8d1a753e4..67b7c30e66 100644 --- a/src/dist/changelog.txt +++ b/src/dist/changelog.txt @@ -24,6 +24,7 @@ Changes in version 3.2 M2 (2012-08-xx) * handle BindException in DefaultHandlerExceptionResolver * parameterize DeferredResult type * use reflection to instantiate StandardServletAsyncWebRequest +* fix issue with forward before async request processing Changes in version 3.2 M1 (2012-05-28)