diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java index b801457c0ac..e78b416d3df 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,9 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.io.IOException; +import java.util.ArrayList; import java.util.LinkedHashSet; +import java.util.List; import java.util.Set; import java.util.function.Consumer; @@ -59,6 +61,7 @@ import org.springframework.util.ObjectUtils; * * @author Rossen Stoyanchev * @author Juergen Hoeller + * @author Brian Clozel * @since 4.2 */ public class ResponseBodyEmitter { @@ -271,19 +274,21 @@ public class ResponseBodyEmitter { /** * Register code to invoke when the async request times out. This method is * called from a container thread when an async request times out. + *

As of 6.2, one can register multiple callbacks for this event. */ public synchronized void onTimeout(Runnable callback) { - this.timeoutCallback.setDelegate(callback); + this.timeoutCallback.addDelegate(callback); } /** * Register code to invoke for an error during async request processing. * This method is called from a container thread when an error occurred * while processing an async request. + *

As of 6.2, one can register multiple callbacks for this event. * @since 5.0 */ public synchronized void onError(Consumer callback) { - this.errorCallback.setDelegate(callback); + this.errorCallback.addDelegate(callback); } /** @@ -291,9 +296,10 @@ public class ResponseBodyEmitter { * called from a container thread when an async request completed for any * reason including timeout and network error. This method is useful for * detecting that a {@code ResponseBodyEmitter} instance is no longer usable. + *

As of 6.2, one can register multiple callbacks for this event. */ public synchronized void onCompletion(Runnable callback) { - this.completionCallback.setDelegate(callback); + this.completionCallback.addDelegate(callback); } @@ -363,18 +369,17 @@ public class ResponseBodyEmitter { private class DefaultCallback implements Runnable { - @Nullable - private Runnable delegate; + private List delegates = new ArrayList<>(1); - public void setDelegate(Runnable delegate) { - this.delegate = delegate; + public void addDelegate(Runnable delegate) { + this.delegates.add(delegate); } @Override public void run() { ResponseBodyEmitter.this.complete = true; - if (this.delegate != null) { - this.delegate.run(); + for (Runnable delegate : this.delegates) { + delegate.run(); } } } @@ -382,18 +387,17 @@ public class ResponseBodyEmitter { private class ErrorCallback implements Consumer { - @Nullable - private Consumer delegate; + private List> delegates = new ArrayList<>(1); - public void setDelegate(Consumer callback) { - this.delegate = callback; + public void addDelegate(Consumer callback) { + this.delegates.add(callback); } @Override public void accept(Throwable t) { ResponseBodyEmitter.this.complete = true; - if (this.delegate != null) { - this.delegate.accept(t); + for(Consumer delegate : this.delegates) { + delegate.accept(t); } } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java index 10f7bc639af..83d49befb5a 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterTests.java @@ -17,6 +17,7 @@ package org.springframework.web.servlet.mvc.method.annotation; import java.io.IOException; +import java.util.function.Consumer; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -31,6 +32,7 @@ import static org.assertj.core.api.Assertions.assertThatIOException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anySet; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -41,6 +43,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; * * @author Rossen Stoyanchev * @author Tomasz Nurkiewicz + * @author Brian Clozel */ @ExtendWith(MockitoExtension.class) public class ResponseBodyEmitterTests { @@ -197,6 +200,25 @@ public class ResponseBodyEmitterTests { verify(runnable).run(); } + @Test + void multipleOnTimeoutCallbacks() throws Exception { + this.emitter.initialize(this.handler); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(this.handler).onTimeout(captor.capture()); + verify(this.handler).onCompletion(any()); + + Runnable first = mock(); + Runnable second = mock(); + this.emitter.onTimeout(first); + this.emitter.onTimeout(second); + + assertThat(captor.getValue()).isNotNull(); + captor.getValue().run(); + verify(first).run(); + verify(second).run(); + } + @Test void onCompletionBeforeHandlerInitialized() throws Exception { Runnable runnable = mock(); @@ -228,4 +250,42 @@ public class ResponseBodyEmitterTests { verify(runnable).run(); } + @Test + void multipleOnCompletionCallbacks() throws Exception { + this.emitter.initialize(this.handler); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Runnable.class); + verify(this.handler).onTimeout(any()); + verify(this.handler).onCompletion(captor.capture()); + + Runnable first = mock(); + Runnable second = mock(); + this.emitter.onCompletion(first); + this.emitter.onCompletion(second); + + assertThat(captor.getValue()).isNotNull(); + captor.getValue().run(); + verify(first).run(); + verify(second).run(); + } + + @Test + void multipleOnErrorCallbacks() throws Exception { + this.emitter.initialize(this.handler); + + ArgumentCaptor> captor = ArgumentCaptor., Consumer>forClass(Consumer.class); + verify(this.handler).onError(captor.capture()); + + Consumer first = mock(); + Consumer second = mock(); + this.emitter.onError(first); + this.emitter.onError(second); + + assertThat(captor.getValue()).isNotNull(); + IllegalStateException illegalStateException = new IllegalStateException(); + captor.getValue().accept(illegalStateException); + verify(first).accept(eq(illegalStateException)); + verify(second).accept(eq(illegalStateException)); + } + }