diff --git a/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java b/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java index d53e97505d..ba88785f25 100644 --- a/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java +++ b/spring-test/src/main/java/org/springframework/mock/web/MockAsyncContext.java @@ -28,7 +28,9 @@ import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.apache.commons.logging.Log; import org.springframework.beans.BeanUtils; +import org.springframework.util.Assert; import org.springframework.web.util.WebUtils; /** @@ -49,6 +51,8 @@ public class MockAsyncContext implements AsyncContext { private long timeout = 10 * 1000L; // 10 seconds is Tomcat's default + private final List dispatchHandlers = new ArrayList(); + public MockAsyncContext(ServletRequest request, ServletResponse response) { this.request = (HttpServletRequest) request; @@ -56,6 +60,11 @@ public class MockAsyncContext implements AsyncContext { } + public void addDispatchHandler(Runnable handler) { + Assert.notNull(handler); + this.dispatchHandlers.add(handler); + } + @Override public ServletRequest getRequest() { return this.request; @@ -84,6 +93,9 @@ public class MockAsyncContext implements AsyncContext { @Override public void dispatch(ServletContext context, String path) { this.dispatchedPath = path; + for (Runnable r : this.dispatchHandlers) { + r.run(); + } } public String getDispatchedPath() { diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index 96389241c0..31de6bd6ee 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -25,13 +25,10 @@ import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.springframework.mock.web.MockAsyncContext; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.request.NativeWebRequest; -import org.springframework.web.context.request.async.CallableProcessingInterceptorAdapter; -import org.springframework.web.context.request.async.DeferredResult; -import org.springframework.web.context.request.async.DeferredResultProcessingInterceptorAdapter; -import org.springframework.web.context.request.async.WebAsyncManager; -import org.springframework.web.context.request.async.WebAsyncUtils; +import org.springframework.web.context.request.async.*; import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.ModelAndView; @@ -50,6 +47,7 @@ final class TestDispatcherServlet extends DispatcherServlet { private static final String KEY = TestDispatcherServlet.class.getName() + ".interceptor"; + /** * Create a new instance with the given web application context. */ @@ -57,37 +55,44 @@ final class TestDispatcherServlet extends DispatcherServlet { super(webApplicationContext); } + @Override protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { - CountDownLatch latch = registerAsyncInterceptors(request); - getMvcResult(request).setAsyncResultLatch(latch); + registerAsyncResultInterceptors(request); super.service(request, response); + + if (request.isAsyncStarted()) { + addAsyncResultLatch(request); + } } - private CountDownLatch registerAsyncInterceptors(final HttpServletRequest servletRequest) { - - final CountDownLatch asyncResultLatch = new CountDownLatch(1); - - WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(servletRequest); - + private void registerAsyncResultInterceptors(final HttpServletRequest request) { + WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request); asyncManager.registerCallableInterceptor(KEY, new CallableProcessingInterceptorAdapter() { @Override - public void postProcess(NativeWebRequest request, Callable task, Object value) throws Exception { - getMvcResult(servletRequest).setAsyncResult(value); - asyncResultLatch.countDown(); + public void postProcess(NativeWebRequest r, Callable task, Object value) throws Exception { + getMvcResult(request).setAsyncResult(value); } }); asyncManager.registerDeferredResultInterceptor(KEY, new DeferredResultProcessingInterceptorAdapter() { @Override - public void postProcess(NativeWebRequest request, DeferredResult result, Object value) throws Exception { - getMvcResult(servletRequest).setAsyncResult(value); - asyncResultLatch.countDown(); + public void postProcess(NativeWebRequest r, DeferredResult result, Object value) throws Exception { + getMvcResult(request).setAsyncResult(value); } }); + } - return asyncResultLatch; + private void addAsyncResultLatch(HttpServletRequest request) { + final CountDownLatch latch = new CountDownLatch(1); + ((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }); + getMvcResult(request).setAsyncResultLatch(latch); } protected DefaultMvcResult getMvcResult(ServletRequest request) {