From 74de35df1ecccd13e982d6a684bddfcd69183b09 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 5 Mar 2014 12:48:33 -0500 Subject: [PATCH] Refactor async result handling in Spring MVC Test This change removes the use of a CountDownLatch to wait for the asynchronously computed controller method return value. Instead we check in a loop every 200 milliseconds if the result has been set. If the result is not set within the specified amount of time to wait an IllegalStateException is raised. Additional changes: - Use AtomicReference to hold the async result - Remove @Ignore annotations on AsyncTests methods - Remove checks for the presence of Servlet 3 Issue: SPR-11516 --- .../test/web/servlet/DefaultMvcResult.java | 50 ++++++------- .../test/web/servlet/MvcResult.java | 26 +++---- .../web/servlet/TestDispatcherServlet.java | 23 +----- .../servlet/result/PrintingResultHandler.java | 20 ++---- .../web/servlet/DefaultMvcResultTests.java | 72 +++---------------- .../test/web/servlet/StubMvcResult.java | 2 +- .../samples/standalone/AsyncTests.java | 5 -- 7 files changed, 53 insertions(+), 145 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java b/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java index 0cb7119bee6..8c138078f64 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/DefaultMvcResult.java @@ -16,12 +16,11 @@ package org.springframework.test.web.servlet; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -import javax.servlet.http.HttpServletRequest; +import java.util.concurrent.atomic.AtomicReference; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.util.Assert; import org.springframework.web.servlet.FlashMap; import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView; @@ -51,7 +50,7 @@ class DefaultMvcResult implements MvcResult { private Exception resolvedException; - private Object asyncResult = RESULT_NONE; + private final AtomicReference asyncResult = new AtomicReference(RESULT_NONE); private CountDownLatch asyncResultLatch; @@ -116,7 +115,7 @@ class DefaultMvcResult implements MvcResult { } public void setAsyncResult(Object asyncResult) { - this.asyncResult = asyncResult; + this.asyncResult.set(asyncResult); } @Override @@ -125,35 +124,30 @@ class DefaultMvcResult implements MvcResult { } @Override - public Object getAsyncResult(long timeout) { - if (this.asyncResult == RESULT_NONE) { - if ((timeout != 0) && this.mockRequest.isAsyncStarted()) { - if (timeout == -1) { - timeout = this.mockRequest.getAsyncContext().getTimeout(); + public Object getAsyncResult(long timeToWait) { + + if (this.mockRequest.getAsyncContext() != null) { + timeToWait = (timeToWait == -1 ? this.mockRequest.getAsyncContext().getTimeout() : timeToWait); + } + + if (timeToWait > 0) { + long endTime = System.currentTimeMillis() + timeToWait; + while (System.currentTimeMillis() < endTime && this.asyncResult.get() == RESULT_NONE) { + try { + Thread.sleep(200); } - if (!awaitAsyncResult(timeout) && this.asyncResult == RESULT_NONE) { - throw new IllegalStateException( - "Gave up waiting on async result from handler [" + this.handler + "] to complete"); + catch (InterruptedException ex) { + throw new IllegalStateException("Interrupted while waiting for " + + "async result to be set for handler [" + this.handler + "]", ex); } } } - return (this.asyncResult == RESULT_NONE ? null : this.asyncResult); - } - private boolean awaitAsyncResult(long timeout) { - if (this.asyncResultLatch != null) { - try { - return this.asyncResultLatch.await(timeout, TimeUnit.MILLISECONDS); - } - catch (InterruptedException e) { - return false; - } - } - return true; - } + Assert.state(this.asyncResult.get() != RESULT_NONE, + "Async result for handler [" + this.handler + "] " + + "was not set during the specified timeToWait=" + timeToWait); - public void setAsyncResultLatch(CountDownLatch asyncResultLatch) { - this.asyncResultLatch = asyncResultLatch; + return this.asyncResult.get(); } } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/MvcResult.java b/spring-test/src/main/java/org/springframework/test/web/servlet/MvcResult.java index f886e7bc1f3..0aabe54563e 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/MvcResult.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/MvcResult.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License; Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -76,24 +76,24 @@ public interface MvcResult { FlashMap getFlashMap(); /** - * Get the result of asynchronous execution or {@code null} if concurrent - * handling did not start. This method will hold and await the completion - * of concurrent handling. + * Get the result of async execution. This method will wait for the async result + * to be set for up to the amount of time configured on the async request, + * i.e. {@link org.springframework.mock.web.MockAsyncContext#getTimeout()}. * - * @throws IllegalStateException if concurrent handling does not complete - * within the allocated async timeout value. + * @throws IllegalStateException if the async result was not set. */ Object getAsyncResult(); /** - * Get the result of asynchronous execution or {@code null} if concurrent - * handling did not start. This method will wait for up to the given timeout - * for the completion of concurrent handling. + * Get the result of async execution. This method will wait for the async result + * to be set for up to the specified amount of time. * - * @param timeout how long to wait for the async result to be set in - * milliseconds; if -1, the wait will be as long as the async timeout set - * on the Servlet request + * @param timeToWait how long to wait for the async result to be set, in + * milliseconds; if -1, then the async request timeout value is used, + * i.e.{@link org.springframework.mock.web.MockAsyncContext#getTimeout()}. + * + * @throws IllegalStateException if the async result was not set. */ - Object getAsyncResult(long timeout); + Object getAsyncResult(long timeToWait); } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index 31de6bd6ee6..4ff099059f0 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,14 +18,12 @@ package org.springframework.test.web.servlet; import java.io.IOException; import java.util.concurrent.Callable; -import java.util.concurrent.CountDownLatch; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.springframework.mock.web.MockAsyncContext; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.async.*; @@ -57,15 +55,11 @@ final class TestDispatcherServlet extends DispatcherServlet { @Override - protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + protected void service(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { registerAsyncResultInterceptors(request); - super.service(request, response); - - if (request.isAsyncStarted()) { - addAsyncResultLatch(request); - } } private void registerAsyncResultInterceptors(final HttpServletRequest request) { @@ -84,17 +78,6 @@ final class TestDispatcherServlet extends DispatcherServlet { }); } - private void addAsyncResultLatch(HttpServletRequest request) { - final CountDownLatch latch = new CountDownLatch(1); - ((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(new Runnable() { - @Override - public void run() { - latch.countDown(); - } - }); - getMvcResult(request).setAsyncResultLatch(latch); - } - protected DefaultMvcResult getMvcResult(ServletRequest request) { return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE); } diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/result/PrintingResultHandler.java b/spring-test/src/main/java/org/springframework/test/web/servlet/result/PrintingResultHandler.java index 3796841d6ab..9faacfdd6ff 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/result/PrintingResultHandler.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/result/PrintingResultHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ package org.springframework.test.web.servlet.result; import java.util.Enumeration; import java.util.Map; -import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest; import org.springframework.http.HttpHeaders; @@ -27,7 +26,6 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultHandler; -import org.springframework.util.ClassUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.validation.BindingResult; @@ -48,8 +46,6 @@ import org.springframework.web.servlet.support.RequestContextUtils; */ public class PrintingResultHandler implements ResultHandler { - private static final boolean servlet3Present = ClassUtils.hasMethod(ServletRequest.class, "startAsync"); - private final ResultValuePrinter printer; @@ -80,10 +76,8 @@ public class PrintingResultHandler implements ResultHandler { this.printer.printHeading("Handler"); printHandler(result.getHandler(), result.getInterceptors()); - if (servlet3Present) { - this.printer.printHeading("Async"); - printAsyncResult(result); - } + this.printer.printHeading("Async"); + printAsyncResult(result); this.printer.printHeading("Resolved Exception"); printResolvedException(result.getResolvedException()); @@ -133,11 +127,9 @@ public class PrintingResultHandler implements ResultHandler { } protected void printAsyncResult(MvcResult result) throws Exception { - if (servlet3Present) { - HttpServletRequest request = result.getRequest(); - this.printer.printValue("Was async started", request.isAsyncStarted()); - this.printer.printValue("Async result", result.getAsyncResult(0)); - } + HttpServletRequest request = result.getRequest(); + this.printer.printValue("Was async started", request.isAsyncStarted()); + this.printer.printValue("Async result", (request.isAsyncStarted() ? result.getAsyncResult(0) : null)); } /** Print the handler */ diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java index c06329b35ab..575483a61ff 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/DefaultMvcResultTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import org.junit.Before; import org.junit.Test; import org.springframework.mock.web.MockHttpServletRequest; +import static org.junit.Assert.assertEquals; import static org.mockito.BDDMockito.*; /** @@ -37,80 +38,23 @@ public class DefaultMvcResultTests { private DefaultMvcResult mvcResult; - private CountDownLatch countDownLatch; - @Before public void setup() { - ExtendedMockHttpServletRequest request = new ExtendedMockHttpServletRequest(); + MockHttpServletRequest request = new MockHttpServletRequest(); request.setAsyncStarted(true); - - this.countDownLatch = mock(CountDownLatch.class); - this.mvcResult = new DefaultMvcResult(request, null); - this.mvcResult.setAsyncResultLatch(this.countDownLatch); } @Test - public void getAsyncResultWithTimeout() throws Exception { - long timeout = 1234L; - given(this.countDownLatch.await(timeout, TimeUnit.MILLISECONDS)).willReturn(true); - this.mvcResult.getAsyncResult(timeout); - verify(this.countDownLatch).await(timeout, TimeUnit.MILLISECONDS); + public void getAsyncResultSuccess() throws Exception { + this.mvcResult.setAsyncResult("Foo"); + assertEquals("Foo", this.mvcResult.getAsyncResult(10 * 1000)); } - @Test - public void getAsyncResultWithTimeoutNegativeOne() throws Exception { - given(this.countDownLatch.await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS)).willReturn(true); - this.mvcResult.getAsyncResult(-1); - verify(this.countDownLatch).await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); - } - - @Test - public void getAsyncResultWithoutTimeout() throws Exception { - given(this.countDownLatch.await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS)).willReturn(true); - this.mvcResult.getAsyncResult(); - verify(this.countDownLatch).await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); - } - - @Test - public void getAsyncResultWithTimeoutZero() throws Exception { + @Test(expected = IllegalStateException.class) + public void getAsyncResultFailure() throws Exception { this.mvcResult.getAsyncResult(0); - verifyZeroInteractions(this.countDownLatch); - } - - @Test(expected=IllegalStateException.class) - public void getAsyncResultAndTimeOut() throws Exception { - this.mvcResult.getAsyncResult(-1); - verify(this.countDownLatch).await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); - } - - - private static class ExtendedMockHttpServletRequest extends MockHttpServletRequest { - - private boolean asyncStarted; - private AsyncContext asyncContext; - - public ExtendedMockHttpServletRequest() { - super(); - this.asyncContext = mock(AsyncContext.class); - given(this.asyncContext.getTimeout()).willReturn(new Long(DEFAULT_TIMEOUT)); - } - - @Override - public void setAsyncStarted(boolean asyncStarted) { - this.asyncStarted = asyncStarted; - } - - @Override - public boolean isAsyncStarted() { - return this.asyncStarted; - } - - @Override - public AsyncContext getAsyncContext() { - return asyncContext; - } } } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java b/spring-test/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java index 60426d8075c..a05d89a9157 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/StubMvcResult.java @@ -133,7 +133,7 @@ public class StubMvcResult implements MvcResult { } @Override - public Object getAsyncResult(long timeout) { + public Object getAsyncResult(long timeToWait) { return null; } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java index 2c8f1f9fc41..ca39033ded7 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/AsyncTests.java @@ -59,23 +59,19 @@ public class AsyncTests { } @Test - @Ignore public void testCallable() throws Exception { MvcResult mvcResult = this.mockMvc.perform(get("/1").param("callable", "true")) - .andDo(print()) .andExpect(request().asyncStarted()) .andExpect(request().asyncResult(new Person("Joe"))) .andReturn(); this.mockMvc.perform(asyncDispatch(mvcResult)) - .andDo(print()) .andExpect(status().isOk()) .andExpect(content().contentType(MediaType.APPLICATION_JSON)) .andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}")); } @Test - @Ignore public void testDeferredResult() throws Exception { MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResult", "true")) .andExpect(request().asyncStarted()) @@ -90,7 +86,6 @@ public class AsyncTests { } @Test - @Ignore public void testDeferredResultWithSetValue() throws Exception { MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResultWithSetValue", "true")) .andExpect(request().asyncStarted())