Allow multiple listeners on ResponseBodyEmitter

Prior to this commit, `ResponseBodyEmitter` woud accept a single
`Runnable` callback on each of its `onTimeout`, `onError` or
`onCompletion` methods. This would limit the developers' ability to
register multiple sets of callbacks: one for managing the publication of
streaming values, another one for managing other concerns like
keep-alive signals to maintain the connection.

This commit now allows multiple calls to `onTimeout`, `onError` and
`onCompletion` and will register all callbacks accordingly.

Closes gh-33356
This commit is contained in:
Brian Clozel 2024-09-06 15:30:18 +02:00
parent 2b6639e587
commit 761fb8f6c9
2 changed files with 80 additions and 16 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2023 the original author or authors. * Copyright 2002-2024 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,7 +17,9 @@
package org.springframework.web.servlet.mvc.method.annotation; package org.springframework.web.servlet.mvc.method.annotation;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
@ -59,6 +61,7 @@ import org.springframework.util.ObjectUtils;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @author Juergen Hoeller * @author Juergen Hoeller
* @author Brian Clozel
* @since 4.2 * @since 4.2
*/ */
public class ResponseBodyEmitter { public class ResponseBodyEmitter {
@ -271,19 +274,21 @@ public class ResponseBodyEmitter {
/** /**
* Register code to invoke when the async request times out. This method is * Register code to invoke when the async request times out. This method is
* called from a container thread when an async request times out. * called from a container thread when an async request times out.
* <p>As of 6.2, one can register multiple callbacks for this event.
*/ */
public synchronized void onTimeout(Runnable callback) { public synchronized void onTimeout(Runnable callback) {
this.timeoutCallback.setDelegate(callback); this.timeoutCallback.addDelegate(callback);
} }
/** /**
* Register code to invoke for an error during async request processing. * Register code to invoke for an error during async request processing.
* This method is called from a container thread when an error occurred * This method is called from a container thread when an error occurred
* while processing an async request. * while processing an async request.
* <p>As of 6.2, one can register multiple callbacks for this event.
* @since 5.0 * @since 5.0
*/ */
public synchronized void onError(Consumer<Throwable> callback) { public synchronized void onError(Consumer<Throwable> callback) {
this.errorCallback.setDelegate(callback); this.errorCallback.addDelegate(callback);
} }
/** /**
@ -291,9 +296,10 @@ public class ResponseBodyEmitter {
* called from a container thread when an async request completed for any * called from a container thread when an async request completed for any
* reason including timeout and network error. This method is useful for * reason including timeout and network error. This method is useful for
* detecting that a {@code ResponseBodyEmitter} instance is no longer usable. * detecting that a {@code ResponseBodyEmitter} instance is no longer usable.
* <p>As of 6.2, one can register multiple callbacks for this event.
*/ */
public synchronized void onCompletion(Runnable callback) { public synchronized void onCompletion(Runnable callback) {
this.completionCallback.setDelegate(callback); this.completionCallback.addDelegate(callback);
} }
@ -363,18 +369,17 @@ public class ResponseBodyEmitter {
private class DefaultCallback implements Runnable { private class DefaultCallback implements Runnable {
@Nullable private List<Runnable> delegates = new ArrayList<>(1);
private Runnable delegate;
public void setDelegate(Runnable delegate) { public void addDelegate(Runnable delegate) {
this.delegate = delegate; this.delegates.add(delegate);
} }
@Override @Override
public void run() { public void run() {
ResponseBodyEmitter.this.complete = true; ResponseBodyEmitter.this.complete = true;
if (this.delegate != null) { for (Runnable delegate : this.delegates) {
this.delegate.run(); delegate.run();
} }
} }
} }
@ -382,18 +387,17 @@ public class ResponseBodyEmitter {
private class ErrorCallback implements Consumer<Throwable> { private class ErrorCallback implements Consumer<Throwable> {
@Nullable private List<Consumer<Throwable>> delegates = new ArrayList<>(1);
private Consumer<Throwable> delegate;
public void setDelegate(Consumer<Throwable> callback) { public void addDelegate(Consumer<Throwable> callback) {
this.delegate = callback; this.delegates.add(callback);
} }
@Override @Override
public void accept(Throwable t) { public void accept(Throwable t) {
ResponseBodyEmitter.this.complete = true; ResponseBodyEmitter.this.complete = true;
if (this.delegate != null) { for(Consumer<Throwable> delegate : this.delegates) {
this.delegate.accept(t); delegate.accept(t);
} }
} }
} }

View File

@ -17,6 +17,7 @@
package org.springframework.web.servlet.mvc.method.annotation; package org.springframework.web.servlet.mvc.method.annotation;
import java.io.IOException; import java.io.IOException;
import java.util.function.Consumer;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@ -31,6 +32,7 @@ import static org.assertj.core.api.Assertions.assertThatIOException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anySet; import static org.mockito.ArgumentMatchers.anySet;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.willThrow; import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -41,6 +43,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
* *
* @author Rossen Stoyanchev * @author Rossen Stoyanchev
* @author Tomasz Nurkiewicz * @author Tomasz Nurkiewicz
* @author Brian Clozel
*/ */
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
public class ResponseBodyEmitterTests { public class ResponseBodyEmitterTests {
@ -197,6 +200,25 @@ public class ResponseBodyEmitterTests {
verify(runnable).run(); verify(runnable).run();
} }
@Test
void multipleOnTimeoutCallbacks() throws Exception {
this.emitter.initialize(this.handler);
ArgumentCaptor<Runnable> captor = ArgumentCaptor.forClass(Runnable.class);
verify(this.handler).onTimeout(captor.capture());
verify(this.handler).onCompletion(any());
Runnable first = mock();
Runnable second = mock();
this.emitter.onTimeout(first);
this.emitter.onTimeout(second);
assertThat(captor.getValue()).isNotNull();
captor.getValue().run();
verify(first).run();
verify(second).run();
}
@Test @Test
void onCompletionBeforeHandlerInitialized() throws Exception { void onCompletionBeforeHandlerInitialized() throws Exception {
Runnable runnable = mock(); Runnable runnable = mock();
@ -228,4 +250,42 @@ public class ResponseBodyEmitterTests {
verify(runnable).run(); verify(runnable).run();
} }
@Test
void multipleOnCompletionCallbacks() throws Exception {
this.emitter.initialize(this.handler);
ArgumentCaptor<Runnable> captor = ArgumentCaptor.forClass(Runnable.class);
verify(this.handler).onTimeout(any());
verify(this.handler).onCompletion(captor.capture());
Runnable first = mock();
Runnable second = mock();
this.emitter.onCompletion(first);
this.emitter.onCompletion(second);
assertThat(captor.getValue()).isNotNull();
captor.getValue().run();
verify(first).run();
verify(second).run();
}
@Test
void multipleOnErrorCallbacks() throws Exception {
this.emitter.initialize(this.handler);
ArgumentCaptor<Consumer<Throwable>> captor = ArgumentCaptor.<Consumer<Throwable>, Consumer>forClass(Consumer.class);
verify(this.handler).onError(captor.capture());
Consumer<Throwable> first = mock();
Consumer<Throwable> second = mock();
this.emitter.onError(first);
this.emitter.onError(second);
assertThat(captor.getValue()).isNotNull();
IllegalStateException illegalStateException = new IllegalStateException();
captor.getValue().accept(illegalStateException);
verify(first).accept(eq(illegalStateException));
verify(second).accept(eq(illegalStateException));
}
} }