Refactor async result handling in Spring MVC Test

This change removes the use of a CountDownLatch to wait for the
asynchronously computed controller method return value. Instead we
check in a loop every 200 milliseconds if the result has been set.
If the result is not set within the specified amount of time to wait
an IllegalStateException is raised.

Additional changes:
 - Use AtomicReference to hold the async result
 - Remove @Ignore annotations on AsyncTests methods
 - Remove checks for the presence of Servlet 3

Issue: SPR-11516
This commit is contained in:
Rossen Stoyanchev 2014-03-05 12:48:33 -05:00
parent e50cff47c1
commit 74de35df1e
7 changed files with 53 additions and 145 deletions

View File

@ -16,12 +16,11 @@
package org.springframework.test.web.servlet;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest;
import java.util.concurrent.atomic.AtomicReference;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.util.Assert;
import org.springframework.web.servlet.FlashMap;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
@ -51,7 +50,7 @@ class DefaultMvcResult implements MvcResult {
private Exception resolvedException;
private Object asyncResult = RESULT_NONE;
private final AtomicReference<Object> asyncResult = new AtomicReference<Object>(RESULT_NONE);
private CountDownLatch asyncResultLatch;
@ -116,7 +115,7 @@ class DefaultMvcResult implements MvcResult {
}
public void setAsyncResult(Object asyncResult) {
this.asyncResult = asyncResult;
this.asyncResult.set(asyncResult);
}
@Override
@ -125,35 +124,30 @@ class DefaultMvcResult implements MvcResult {
}
@Override
public Object getAsyncResult(long timeout) {
if (this.asyncResult == RESULT_NONE) {
if ((timeout != 0) && this.mockRequest.isAsyncStarted()) {
if (timeout == -1) {
timeout = this.mockRequest.getAsyncContext().getTimeout();
public Object getAsyncResult(long timeToWait) {
if (this.mockRequest.getAsyncContext() != null) {
timeToWait = (timeToWait == -1 ? this.mockRequest.getAsyncContext().getTimeout() : timeToWait);
}
if (timeToWait > 0) {
long endTime = System.currentTimeMillis() + timeToWait;
while (System.currentTimeMillis() < endTime && this.asyncResult.get() == RESULT_NONE) {
try {
Thread.sleep(200);
}
if (!awaitAsyncResult(timeout) && this.asyncResult == RESULT_NONE) {
throw new IllegalStateException(
"Gave up waiting on async result from handler [" + this.handler + "] to complete");
catch (InterruptedException ex) {
throw new IllegalStateException("Interrupted while waiting for " +
"async result to be set for handler [" + this.handler + "]", ex);
}
}
}
return (this.asyncResult == RESULT_NONE ? null : this.asyncResult);
}
private boolean awaitAsyncResult(long timeout) {
if (this.asyncResultLatch != null) {
try {
return this.asyncResultLatch.await(timeout, TimeUnit.MILLISECONDS);
}
catch (InterruptedException e) {
return false;
}
}
return true;
}
Assert.state(this.asyncResult.get() != RESULT_NONE,
"Async result for handler [" + this.handler + "] " +
"was not set during the specified timeToWait=" + timeToWait);
public void setAsyncResultLatch(CountDownLatch asyncResultLatch) {
this.asyncResultLatch = asyncResultLatch;
return this.asyncResult.get();
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2012 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License; Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -76,24 +76,24 @@ public interface MvcResult {
FlashMap getFlashMap();
/**
* Get the result of asynchronous execution or {@code null} if concurrent
* handling did not start. This method will hold and await the completion
* of concurrent handling.
* Get the result of async execution. This method will wait for the async result
* to be set for up to the amount of time configured on the async request,
* i.e. {@link org.springframework.mock.web.MockAsyncContext#getTimeout()}.
*
* @throws IllegalStateException if concurrent handling does not complete
* within the allocated async timeout value.
* @throws IllegalStateException if the async result was not set.
*/
Object getAsyncResult();
/**
* Get the result of asynchronous execution or {@code null} if concurrent
* handling did not start. This method will wait for up to the given timeout
* for the completion of concurrent handling.
* Get the result of async execution. This method will wait for the async result
* to be set for up to the specified amount of time.
*
* @param timeout how long to wait for the async result to be set in
* milliseconds; if -1, the wait will be as long as the async timeout set
* on the Servlet request
* @param timeToWait how long to wait for the async result to be set, in
* milliseconds; if -1, then the async request timeout value is used,
* i.e.{@link org.springframework.mock.web.MockAsyncContext#getTimeout()}.
*
* @throws IllegalStateException if the async result was not set.
*/
Object getAsyncResult(long timeout);
Object getAsyncResult(long timeToWait);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2012 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,14 +18,12 @@ package org.springframework.test.web.servlet;
import java.io.IOException;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.mock.web.MockAsyncContext;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.async.*;
@ -57,15 +55,11 @@ final class TestDispatcherServlet extends DispatcherServlet {
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
protected void service(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {
registerAsyncResultInterceptors(request);
super.service(request, response);
if (request.isAsyncStarted()) {
addAsyncResultLatch(request);
}
}
private void registerAsyncResultInterceptors(final HttpServletRequest request) {
@ -84,17 +78,6 @@ final class TestDispatcherServlet extends DispatcherServlet {
});
}
private void addAsyncResultLatch(HttpServletRequest request) {
final CountDownLatch latch = new CountDownLatch(1);
((MockAsyncContext) request.getAsyncContext()).addDispatchHandler(new Runnable() {
@Override
public void run() {
latch.countDown();
}
});
getMvcResult(request).setAsyncResultLatch(latch);
}
protected DefaultMvcResult getMvcResult(ServletRequest request) {
return (DefaultMvcResult) request.getAttribute(MockMvc.MVC_RESULT_ATTRIBUTE);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2012 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,7 +19,6 @@ package org.springframework.test.web.servlet.result;
import java.util.Enumeration;
import java.util.Map;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import org.springframework.http.HttpHeaders;
@ -27,7 +26,6 @@ import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.ResultHandler;
import org.springframework.util.ClassUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.validation.BindingResult;
@ -48,8 +46,6 @@ import org.springframework.web.servlet.support.RequestContextUtils;
*/
public class PrintingResultHandler implements ResultHandler {
private static final boolean servlet3Present = ClassUtils.hasMethod(ServletRequest.class, "startAsync");
private final ResultValuePrinter printer;
@ -80,10 +76,8 @@ public class PrintingResultHandler implements ResultHandler {
this.printer.printHeading("Handler");
printHandler(result.getHandler(), result.getInterceptors());
if (servlet3Present) {
this.printer.printHeading("Async");
printAsyncResult(result);
}
this.printer.printHeading("Async");
printAsyncResult(result);
this.printer.printHeading("Resolved Exception");
printResolvedException(result.getResolvedException());
@ -133,11 +127,9 @@ public class PrintingResultHandler implements ResultHandler {
}
protected void printAsyncResult(MvcResult result) throws Exception {
if (servlet3Present) {
HttpServletRequest request = result.getRequest();
this.printer.printValue("Was async started", request.isAsyncStarted());
this.printer.printValue("Async result", result.getAsyncResult(0));
}
HttpServletRequest request = result.getRequest();
this.printer.printValue("Was async started", request.isAsyncStarted());
this.printer.printValue("Async result", (request.isAsyncStarted() ? result.getAsyncResult(0) : null));
}
/** Print the handler */

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2013 the original author or authors.
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,6 +24,7 @@ import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import static org.junit.Assert.assertEquals;
import static org.mockito.BDDMockito.*;
/**
@ -37,80 +38,23 @@ public class DefaultMvcResultTests {
private DefaultMvcResult mvcResult;
private CountDownLatch countDownLatch;
@Before
public void setup() {
ExtendedMockHttpServletRequest request = new ExtendedMockHttpServletRequest();
MockHttpServletRequest request = new MockHttpServletRequest();
request.setAsyncStarted(true);
this.countDownLatch = mock(CountDownLatch.class);
this.mvcResult = new DefaultMvcResult(request, null);
this.mvcResult.setAsyncResultLatch(this.countDownLatch);
}
@Test
public void getAsyncResultWithTimeout() throws Exception {
long timeout = 1234L;
given(this.countDownLatch.await(timeout, TimeUnit.MILLISECONDS)).willReturn(true);
this.mvcResult.getAsyncResult(timeout);
verify(this.countDownLatch).await(timeout, TimeUnit.MILLISECONDS);
public void getAsyncResultSuccess() throws Exception {
this.mvcResult.setAsyncResult("Foo");
assertEquals("Foo", this.mvcResult.getAsyncResult(10 * 1000));
}
@Test
public void getAsyncResultWithTimeoutNegativeOne() throws Exception {
given(this.countDownLatch.await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS)).willReturn(true);
this.mvcResult.getAsyncResult(-1);
verify(this.countDownLatch).await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS);
}
@Test
public void getAsyncResultWithoutTimeout() throws Exception {
given(this.countDownLatch.await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS)).willReturn(true);
this.mvcResult.getAsyncResult();
verify(this.countDownLatch).await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS);
}
@Test
public void getAsyncResultWithTimeoutZero() throws Exception {
@Test(expected = IllegalStateException.class)
public void getAsyncResultFailure() throws Exception {
this.mvcResult.getAsyncResult(0);
verifyZeroInteractions(this.countDownLatch);
}
@Test(expected=IllegalStateException.class)
public void getAsyncResultAndTimeOut() throws Exception {
this.mvcResult.getAsyncResult(-1);
verify(this.countDownLatch).await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS);
}
private static class ExtendedMockHttpServletRequest extends MockHttpServletRequest {
private boolean asyncStarted;
private AsyncContext asyncContext;
public ExtendedMockHttpServletRequest() {
super();
this.asyncContext = mock(AsyncContext.class);
given(this.asyncContext.getTimeout()).willReturn(new Long(DEFAULT_TIMEOUT));
}
@Override
public void setAsyncStarted(boolean asyncStarted) {
this.asyncStarted = asyncStarted;
}
@Override
public boolean isAsyncStarted() {
return this.asyncStarted;
}
@Override
public AsyncContext getAsyncContext() {
return asyncContext;
}
}
}

View File

@ -133,7 +133,7 @@ public class StubMvcResult implements MvcResult {
}
@Override
public Object getAsyncResult(long timeout) {
public Object getAsyncResult(long timeToWait) {
return null;
}

View File

@ -59,23 +59,19 @@ public class AsyncTests {
}
@Test
@Ignore
public void testCallable() throws Exception {
MvcResult mvcResult = this.mockMvc.perform(get("/1").param("callable", "true"))
.andDo(print())
.andExpect(request().asyncStarted())
.andExpect(request().asyncResult(new Person("Joe")))
.andReturn();
this.mockMvc.perform(asyncDispatch(mvcResult))
.andDo(print())
.andExpect(status().isOk())
.andExpect(content().contentType(MediaType.APPLICATION_JSON))
.andExpect(content().string("{\"name\":\"Joe\",\"someDouble\":0.0,\"someBoolean\":false}"));
}
@Test
@Ignore
public void testDeferredResult() throws Exception {
MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResult", "true"))
.andExpect(request().asyncStarted())
@ -90,7 +86,6 @@ public class AsyncTests {
}
@Test
@Ignore
public void testDeferredResultWithSetValue() throws Exception {
MvcResult mvcResult = this.mockMvc.perform(get("/1").param("deferredResultWithSetValue", "true"))
.andExpect(request().asyncStarted())