diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java index 9b12fb0955..2973ef9695 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java @@ -69,9 +69,9 @@ public class ResponseBodyEmitter { private Throwable failure; - private Runnable timeoutCallback; + private final DefaultCallback timeoutCallback = new DefaultCallback(); - private Runnable completionCallback; + private final DefaultCallback completionCallback = new DefaultCallback(); /** @@ -126,11 +126,8 @@ public class ResponseBodyEmitter { this.handler.complete(); } } - - if (this.timeoutCallback != null) { + else { this.handler.onTimeout(this.timeoutCallback); - } - if (this.completionCallback != null) { this.handler.onCompletion(this.completionCallback); } } @@ -168,11 +165,11 @@ public class ResponseBodyEmitter { this.handler.send(object, mediaType); } catch (IOException ex) { - this.handler.completeWithError(ex); + completeWithError(ex); throw ex; } catch (Throwable ex) { - this.handler.completeWithError(ex); + completeWithError(ex); throw new IllegalStateException("Failed to send " + object, ex); } } @@ -212,10 +209,7 @@ public class ResponseBodyEmitter { * called from a container thread when an async request times out. */ public synchronized void onTimeout(Runnable callback) { - this.timeoutCallback = callback; - if (this.handler != null) { - this.handler.onTimeout(callback); - } + this.timeoutCallback.setDelegate(callback); } /** @@ -225,10 +219,7 @@ public class ResponseBodyEmitter { * detecting that a {@code ResponseBodyEmitter} instance is no longer usable. */ public synchronized void onCompletion(Runnable callback) { - this.completionCallback = callback; - if (this.handler != null) { - this.handler.onCompletion(callback); - } + this.completionCallback.setDelegate(callback); } @@ -272,4 +263,22 @@ public class ResponseBodyEmitter { } } + private class DefaultCallback implements Runnable { + + private Runnable delegate; + + + public void setDelegate(Runnable delegate) { + this.delegate = delegate; + } + + @Override + public void run() { + ResponseBodyEmitter.this.complete = true; + if (this.delegate != null) { + this.delegate.run(); + } + } + } + } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java index f7112d3af0..44a22bb6ca 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java @@ -20,12 +20,14 @@ import java.io.IOException; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import static org.mockito.Mockito.*; import org.mockito.MockitoAnnotations; import org.springframework.http.MediaType; +import org.springframework.util.Assert; import static org.junit.Assert.fail; @@ -101,6 +103,8 @@ public class ResponseBodyEmitterTests { @Test public void sendAfterHandlerInitialized() throws Exception { this.emitter.initialize(this.handler); + verify(this.handler).onTimeout(any()); + verify(this.handler).onCompletion(any()); verifyNoMoreInteractions(this.handler); this.emitter.send("foo", MediaType.TEXT_PLAIN); @@ -116,6 +120,8 @@ public class ResponseBodyEmitterTests { @Test public void sendAfterHandlerInitializedWithError() throws Exception { this.emitter.initialize(this.handler); + verify(this.handler).onTimeout(any()); + verify(this.handler).onCompletion(any()); verifyNoMoreInteractions(this.handler); IllegalStateException ex = new IllegalStateException(); @@ -132,6 +138,8 @@ public class ResponseBodyEmitterTests { @Test public void sendWithError() throws Exception { this.emitter.initialize(this.handler); + verify(this.handler).onTimeout(any()); + verify(this.handler).onCompletion(any()); verifyNoMoreInteractions(this.handler); IOException failure = new IOException(); @@ -153,15 +161,30 @@ public class ResponseBodyEmitterTests { Runnable runnable = mock(Runnable.class); this.emitter.onTimeout(runnable); this.emitter.initialize(this.handler); - verify(this.handler).onTimeout(runnable); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(this.handler).onTimeout(captor.capture()); + verify(this.handler).onCompletion(any()); + + Assert.notNull(captor.getValue()); + captor.getValue().run(); + verify(runnable).run(); } @Test public void onTimeoutAfterHandlerInitialized() throws Exception { - Runnable runnable = mock(Runnable.class); this.emitter.initialize(this.handler); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(this.handler).onTimeout(captor.capture()); + verify(this.handler).onCompletion(any()); + + Runnable runnable = mock(Runnable.class); this.emitter.onTimeout(runnable); - verify(this.handler).onTimeout(runnable); + + Assert.notNull(captor.getValue()); + captor.getValue().run(); + verify(runnable).run(); } @Test @@ -169,15 +192,30 @@ public class ResponseBodyEmitterTests { Runnable runnable = mock(Runnable.class); this.emitter.onCompletion(runnable); this.emitter.initialize(this.handler); - verify(this.handler).onCompletion(runnable); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(this.handler).onTimeout(any()); + verify(this.handler).onCompletion(captor.capture()); + + Assert.notNull(captor.getValue()); + captor.getValue().run(); + verify(runnable).run(); } @Test public void onCompletionAfterHandlerInitialized() throws Exception { - Runnable runnable = mock(Runnable.class); this.emitter.initialize(this.handler); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(this.handler).onTimeout(any()); + verify(this.handler).onCompletion(captor.capture()); + + Runnable runnable = mock(Runnable.class); this.emitter.onCompletion(runnable); - verify(this.handler).onCompletion(runnable); + + Assert.notNull(captor.getValue()); + captor.getValue().run(); + verify(runnable).run(); } }