diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java
index b801457c0ac..e78b416d3df 100644
--- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java
+++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2023 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,7 +17,9 @@
package org.springframework.web.servlet.mvc.method.annotation;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.LinkedHashSet;
+import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
@@ -59,6 +61,7 @@ import org.springframework.util.ObjectUtils;
*
* @author Rossen Stoyanchev
* @author Juergen Hoeller
+ * @author Brian Clozel
* @since 4.2
*/
public class ResponseBodyEmitter {
@@ -271,19 +274,21 @@ public class ResponseBodyEmitter {
/**
* Register code to invoke when the async request times out. This method is
* called from a container thread when an async request times out.
+ *
As of 6.2, one can register multiple callbacks for this event.
*/
public synchronized void onTimeout(Runnable callback) {
- this.timeoutCallback.setDelegate(callback);
+ this.timeoutCallback.addDelegate(callback);
}
/**
* Register code to invoke for an error during async request processing.
* This method is called from a container thread when an error occurred
* while processing an async request.
+ *
As of 6.2, one can register multiple callbacks for this event.
* @since 5.0
*/
public synchronized void onError(Consumer callback) {
- this.errorCallback.setDelegate(callback);
+ this.errorCallback.addDelegate(callback);
}
/**
@@ -291,9 +296,10 @@ public class ResponseBodyEmitter {
* called from a container thread when an async request completed for any
* reason including timeout and network error. This method is useful for
* detecting that a {@code ResponseBodyEmitter} instance is no longer usable.
+ * As of 6.2, one can register multiple callbacks for this event.
*/
public synchronized void onCompletion(Runnable callback) {
- this.completionCallback.setDelegate(callback);
+ this.completionCallback.addDelegate(callback);
}
@@ -363,18 +369,17 @@ public class ResponseBodyEmitter {
private class DefaultCallback implements Runnable {
- @Nullable
- private Runnable delegate;
+ private List delegates = new ArrayList<>(1);
- public void setDelegate(Runnable delegate) {
- this.delegate = delegate;
+ public void addDelegate(Runnable delegate) {
+ this.delegates.add(delegate);
}
@Override
public void run() {
ResponseBodyEmitter.this.complete = true;
- if (this.delegate != null) {
- this.delegate.run();
+ for (Runnable delegate : this.delegates) {
+ delegate.run();
}
}
}
@@ -382,18 +387,17 @@ public class ResponseBodyEmitter {
private class ErrorCallback implements Consumer {
- @Nullable
- private Consumer delegate;
+ private List> delegates = new ArrayList<>(1);
- public void setDelegate(Consumer callback) {
- this.delegate = callback;
+ public void addDelegate(Consumer callback) {
+ this.delegates.add(callback);
}
@Override
public void accept(Throwable t) {
ResponseBodyEmitter.this.complete = true;
- if (this.delegate != null) {
- this.delegate.accept(t);
+ for(Consumer delegate : this.delegates) {
+ delegate.accept(t);
}
}
}
diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java
index 10f7bc639af..83d49befb5a 100644
--- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java
+++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java
@@ -17,6 +17,7 @@
package org.springframework.web.servlet.mvc.method.annotation;
import java.io.IOException;
+import java.util.function.Consumer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -31,6 +32,7 @@ import static org.assertj.core.api.Assertions.assertThatIOException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anySet;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@@ -41,6 +43,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
*
* @author Rossen Stoyanchev
* @author Tomasz Nurkiewicz
+ * @author Brian Clozel
*/
@ExtendWith(MockitoExtension.class)
public class ResponseBodyEmitterTests {
@@ -197,6 +200,25 @@ public class ResponseBodyEmitterTests {
verify(runnable).run();
}
+ @Test
+ void multipleOnTimeoutCallbacks() throws Exception {
+ this.emitter.initialize(this.handler);
+
+ ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class);
+ verify(this.handler).onTimeout(captor.capture());
+ verify(this.handler).onCompletion(any());
+
+ Runnable first = mock();
+ Runnable second = mock();
+ this.emitter.onTimeout(first);
+ this.emitter.onTimeout(second);
+
+ assertThat(captor.getValue()).isNotNull();
+ captor.getValue().run();
+ verify(first).run();
+ verify(second).run();
+ }
+
@Test
void onCompletionBeforeHandlerInitialized() throws Exception {
Runnable runnable = mock();
@@ -228,4 +250,42 @@ public class ResponseBodyEmitterTests {
verify(runnable).run();
}
+ @Test
+ void multipleOnCompletionCallbacks() throws Exception {
+ this.emitter.initialize(this.handler);
+
+ ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class);
+ verify(this.handler).onTimeout(any());
+ verify(this.handler).onCompletion(captor.capture());
+
+ Runnable first = mock();
+ Runnable second = mock();
+ this.emitter.onCompletion(first);
+ this.emitter.onCompletion(second);
+
+ assertThat(captor.getValue()).isNotNull();
+ captor.getValue().run();
+ verify(first).run();
+ verify(second).run();
+ }
+
+ @Test
+ void multipleOnErrorCallbacks() throws Exception {
+ this.emitter.initialize(this.handler);
+
+ ArgumentCaptor> captor = ArgumentCaptor., Consumer>forClass(Consumer.class);
+ verify(this.handler).onError(captor.capture());
+
+ Consumer first = mock();
+ Consumer second = mock();
+ this.emitter.onError(first);
+ this.emitter.onError(second);
+
+ assertThat(captor.getValue()).isNotNull();
+ IllegalStateException illegalStateException = new IllegalStateException();
+ captor.getValue().accept(illegalStateException);
+ verify(first).accept(eq(illegalStateException));
+ verify(second).accept(eq(illegalStateException));
+ }
+
}