Batch SSE events writes when possible

Prior to this commit, the `SseEventBuilder` would be used to create SSE
events and write them to the connection using the `ResponseBodyEmitter`.
This would send each data item one by one, effectively writing and
flushing to the network for each. Since multiple data lines are prepared
by the `SseEventBuilder`, a typical write of an SSE event performs
multiple flushes operations.

This commit adds a method on `ResponseBodyEmitter` to perform batch
writes (given a `Set<DataWithMediaType>`) and only flush once all
elements of the set have been written.
This also applies in case of early writes, where now all buffered
elements are written then flushed altogether.

Fixes gh-30912
This commit is contained in:
Brian Clozel 2023-08-04 10:08:50 +02:00
parent 18966d048c
commit e83793ba7f
6 changed files with 108 additions and 33 deletions

View File

@ -128,9 +128,7 @@ public class ResponseBodyEmitter {
this.handler = handler; this.handler = handler;
try { try {
for (DataWithMediaType sendAttempt : this.earlySendAttempts) { sendInternal(this.earlySendAttempts);
sendInternal(sendAttempt.getData(), sendAttempt.getMediaType());
}
} }
finally { finally {
this.earlySendAttempts.clear(); this.earlySendAttempts.clear();
@ -194,11 +192,7 @@ public class ResponseBodyEmitter {
*/ */
public synchronized void send(Object object, @Nullable MediaType mediaType) throws IOException { public synchronized void send(Object object, @Nullable MediaType mediaType) throws IOException {
Assert.state(!this.complete, () -> "ResponseBodyEmitter has already completed" + Assert.state(!this.complete, () -> "ResponseBodyEmitter has already completed" +
(this.failure != null ? " with error: " + this.failure : "")); (this.failure != null ? " with error: " + this.failure : ""));
sendInternal(object, mediaType);
}
private void sendInternal(Object object, @Nullable MediaType mediaType) throws IOException {
if (this.handler != null) { if (this.handler != null) {
try { try {
this.handler.send(object, mediaType); this.handler.send(object, mediaType);
@ -217,6 +211,43 @@ public class ResponseBodyEmitter {
} }
} }
/**
* Write a set of data and MediaType pairs in a batch.
* <p>Compared to {@link #send(Object, MediaType)}, this batches the write operations
* and flushes to the network at the end.
* @param items the object and media type pairs to write
* @throws IOException raised when an I/O error occurs
* @throws java.lang.IllegalStateException wraps any other errors
* @since 6.0.12
*/
public synchronized void send(Set<DataWithMediaType> items) throws IOException {
Assert.state(!this.complete, () -> "ResponseBodyEmitter has already completed" +
(this.failure != null ? " with error: " + this.failure : ""));
sendInternal(items);
}
private void sendInternal(Set<DataWithMediaType> items) throws IOException {
if (items.isEmpty()) {
return;
}
if (this.handler != null) {
try {
this.handler.send(items);
}
catch (IOException ex) {
this.sendFailed = true;
throw ex;
}
catch (Throwable ex) {
this.sendFailed = true;
throw new IllegalStateException("Failed to send " + items, ex);
}
}
else {
this.earlySendAttempts.addAll(items);
}
}
/** /**
* Complete request processing by performing a dispatch into the servlet * Complete request processing by performing a dispatch into the servlet
* container, where Spring MVC is invoked once more, and completes the * container, where Spring MVC is invoked once more, and completes the
@ -302,8 +333,17 @@ public class ResponseBodyEmitter {
*/ */
interface Handler { interface Handler {
/**
* Immediately write and flush the given data to the network.
*/
void send(Object data, @Nullable MediaType mediaType) throws IOException; void send(Object data, @Nullable MediaType mediaType) throws IOException;
/**
* Immediately write all data items then flush to the network.
* @since 6.0.12
*/
void send(Set<DataWithMediaType> items) throws IOException;
void complete(); void complete();
void completeWithError(Throwable failure); void completeWithError(Throwable failure);

View File

@ -20,6 +20,7 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletRequest;
@ -202,6 +203,15 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur
@Override @Override
public void send(Object data, @Nullable MediaType mediaType) throws IOException { public void send(Object data, @Nullable MediaType mediaType) throws IOException {
sendInternal(data, mediaType); sendInternal(data, mediaType);
this.outputMessage.flush();
}
@Override
public void send(Set<ResponseBodyEmitter.DataWithMediaType> items) throws IOException {
for (ResponseBodyEmitter.DataWithMediaType item : items) {
sendInternal(item.getData(), item.getMediaType());
}
this.outputMessage.flush();
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -209,7 +219,6 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur
for (HttpMessageConverter<?> converter : ResponseBodyEmitterReturnValueHandler.this.sseMessageConverters) { for (HttpMessageConverter<?> converter : ResponseBodyEmitterReturnValueHandler.this.sseMessageConverters) {
if (converter.canWrite(data.getClass(), mediaType)) { if (converter.canWrite(data.getClass(), mediaType)) {
((HttpMessageConverter<T>) converter).write(data, mediaType, this.outputMessage); ((HttpMessageConverter<T>) converter).write(data, mediaType, this.outputMessage);
this.outputMessage.flush();
return; return;
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2002-2021 the original author or authors. * Copyright 2002-2023 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -123,9 +123,7 @@ public class SseEmitter extends ResponseBodyEmitter {
public void send(SseEventBuilder builder) throws IOException { public void send(SseEventBuilder builder) throws IOException {
Set<DataWithMediaType> dataToSend = builder.build(); Set<DataWithMediaType> dataToSend = builder.build();
synchronized (this) { synchronized (this) {
for (DataWithMediaType entry : dataToSend) { super.send(dataToSend);
super.send(entry.getData(), entry.getMediaType());
}
} }
} }

View File

@ -365,6 +365,11 @@ public class ReactiveTypeHandlerTests {
this.values.add(data); this.values.add(data);
} }
@Override
public void send(Set<ResponseBodyEmitter.DataWithMediaType> items) throws IOException {
items.forEach(item -> this.values.add(item.getData()));
}
@Override @Override
public void complete() { public void complete() {
} }

View File

@ -30,9 +30,9 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIOException; import static org.assertj.core.api.Assertions.assertThatIOException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anySet;
import static org.mockito.BDDMockito.willThrow; import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -52,34 +52,33 @@ public class ResponseBodyEmitterTests {
@Test @Test
public void sendBeforeHandlerInitialized() throws Exception { void sendBeforeHandlerInitialized() throws Exception {
this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.send("foo", MediaType.TEXT_PLAIN);
this.emitter.send("bar", MediaType.TEXT_PLAIN); this.emitter.send("bar", MediaType.TEXT_PLAIN);
this.emitter.complete(); this.emitter.complete();
verifyNoMoreInteractions(this.handler); verifyNoMoreInteractions(this.handler);
this.emitter.initialize(this.handler); this.emitter.initialize(this.handler);
verify(this.handler).send("foo", MediaType.TEXT_PLAIN); verify(this.handler).send(anySet());
verify(this.handler).send("bar", MediaType.TEXT_PLAIN);
verify(this.handler).complete(); verify(this.handler).complete();
verifyNoMoreInteractions(this.handler); verifyNoMoreInteractions(this.handler);
} }
@Test @Test
public void sendDuplicateBeforeHandlerInitialized() throws Exception { void sendDuplicateBeforeHandlerInitialized() throws Exception {
this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.send("foo", MediaType.TEXT_PLAIN);
this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.send("foo", MediaType.TEXT_PLAIN);
this.emitter.complete(); this.emitter.complete();
verifyNoMoreInteractions(this.handler); verifyNoMoreInteractions(this.handler);
this.emitter.initialize(this.handler); this.emitter.initialize(this.handler);
verify(this.handler, times(2)).send("foo", MediaType.TEXT_PLAIN); verify(this.handler).send(anySet());
verify(this.handler).complete(); verify(this.handler).complete();
verifyNoMoreInteractions(this.handler); verifyNoMoreInteractions(this.handler);
} }
@Test @Test
public void sendBeforeHandlerInitializedWithError() throws Exception { void sendBeforeHandlerInitializedWithError() throws Exception {
IllegalStateException ex = new IllegalStateException(); IllegalStateException ex = new IllegalStateException();
this.emitter.send("foo", MediaType.TEXT_PLAIN); this.emitter.send("foo", MediaType.TEXT_PLAIN);
this.emitter.send("bar", MediaType.TEXT_PLAIN); this.emitter.send("bar", MediaType.TEXT_PLAIN);
@ -87,21 +86,20 @@ public class ResponseBodyEmitterTests {
verifyNoMoreInteractions(this.handler); verifyNoMoreInteractions(this.handler);
this.emitter.initialize(this.handler); this.emitter.initialize(this.handler);
verify(this.handler).send("foo", MediaType.TEXT_PLAIN); verify(this.handler).send(anySet());
verify(this.handler).send("bar", MediaType.TEXT_PLAIN);
verify(this.handler).completeWithError(ex); verify(this.handler).completeWithError(ex);
verifyNoMoreInteractions(this.handler); verifyNoMoreInteractions(this.handler);
} }
@Test @Test
public void sendFailsAfterComplete() throws Exception { void sendFailsAfterComplete() throws Exception {
this.emitter.complete(); this.emitter.complete();
assertThatIllegalStateException().isThrownBy(() -> assertThatIllegalStateException().isThrownBy(() ->
this.emitter.send("foo")); this.emitter.send("foo"));
} }
@Test @Test
public void sendAfterHandlerInitialized() throws Exception { void sendAfterHandlerInitialized() throws Exception {
this.emitter.initialize(this.handler); this.emitter.initialize(this.handler);
verify(this.handler).onTimeout(any()); verify(this.handler).onTimeout(any());
verify(this.handler).onError(any()); verify(this.handler).onError(any());
@ -119,7 +117,7 @@ public class ResponseBodyEmitterTests {
} }
@Test @Test
public void sendAfterHandlerInitializedWithError() throws Exception { void sendAfterHandlerInitializedWithError() throws Exception {
this.emitter.initialize(this.handler); this.emitter.initialize(this.handler);
verify(this.handler).onTimeout(any()); verify(this.handler).onTimeout(any());
verify(this.handler).onError(any()); verify(this.handler).onError(any());
@ -138,7 +136,7 @@ public class ResponseBodyEmitterTests {
} }
@Test @Test
public void sendWithError() throws Exception { void sendWithError() throws Exception {
this.emitter.initialize(this.handler); this.emitter.initialize(this.handler);
verify(this.handler).onTimeout(any()); verify(this.handler).onTimeout(any());
verify(this.handler).onError(any()); verify(this.handler).onError(any());
@ -154,7 +152,7 @@ public class ResponseBodyEmitterTests {
} }
@Test @Test
public void onTimeoutBeforeHandlerInitialized() throws Exception { void onTimeoutBeforeHandlerInitialized() throws Exception {
Runnable runnable = mock(); Runnable runnable = mock();
this.emitter.onTimeout(runnable); this.emitter.onTimeout(runnable);
this.emitter.initialize(this.handler); this.emitter.initialize(this.handler);
@ -169,7 +167,7 @@ public class ResponseBodyEmitterTests {
} }
@Test @Test
public void onTimeoutAfterHandlerInitialized() throws Exception { void onTimeoutAfterHandlerInitialized() throws Exception {
this.emitter.initialize(this.handler); this.emitter.initialize(this.handler);
ArgumentCaptor<Runnable> captor = ArgumentCaptor.forClass(Runnable.class); ArgumentCaptor<Runnable> captor = ArgumentCaptor.forClass(Runnable.class);
@ -185,7 +183,7 @@ public class ResponseBodyEmitterTests {
} }
@Test @Test
public void onCompletionBeforeHandlerInitialized() throws Exception { void onCompletionBeforeHandlerInitialized() throws Exception {
Runnable runnable = mock(); Runnable runnable = mock();
this.emitter.onCompletion(runnable); this.emitter.onCompletion(runnable);
this.emitter.initialize(this.handler); this.emitter.initialize(this.handler);
@ -200,7 +198,7 @@ public class ResponseBodyEmitterTests {
} }
@Test @Test
public void onCompletionAfterHandlerInitialized() throws Exception { void onCompletionAfterHandlerInitialized() throws Exception {
this.emitter.initialize(this.handler); this.emitter.initialize(this.handler);
ArgumentCaptor<Runnable> captor = ArgumentCaptor.forClass(Runnable.class); ArgumentCaptor<Runnable> captor = ArgumentCaptor.forClass(Runnable.class);

View File

@ -20,12 +20,14 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event; import static org.springframework.web.servlet.mvc.method.annotation.SseEmitter.event;
@ -60,6 +62,7 @@ public class SseEmitterTests {
this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8);
this.handler.assertObject(1, "foo"); this.handler.assertObject(1, "foo");
this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
} }
@Test @Test
@ -69,12 +72,14 @@ public class SseEmitterTests {
this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8);
this.handler.assertObject(1, "foo", MediaType.TEXT_PLAIN); this.handler.assertObject(1, "foo", MediaType.TEXT_PLAIN);
this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
} }
@Test @Test
public void sendEventEmpty() throws Exception { public void sendEventEmpty() throws Exception {
this.emitter.send(event()); this.emitter.send(event());
this.handler.assertSentObjectCount(0); this.handler.assertSentObjectCount(0);
this.handler.assertWriteCount(0);
} }
@Test @Test
@ -84,6 +89,7 @@ public class SseEmitterTests {
this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8); this.handler.assertObject(0, "data:", TEXT_PLAIN_UTF8);
this.handler.assertObject(1, "foo"); this.handler.assertObject(1, "foo");
this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
} }
@Test @Test
@ -95,6 +101,7 @@ public class SseEmitterTests {
this.handler.assertObject(2, "\ndata:", TEXT_PLAIN_UTF8); this.handler.assertObject(2, "\ndata:", TEXT_PLAIN_UTF8);
this.handler.assertObject(3, "bar"); this.handler.assertObject(3, "bar");
this.handler.assertObject(4, "\n\n", TEXT_PLAIN_UTF8); this.handler.assertObject(4, "\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
} }
@Test @Test
@ -104,6 +111,7 @@ public class SseEmitterTests {
this.handler.assertObject(0, ":blah\nevent:test\nretry:5000\nid:1\ndata:", TEXT_PLAIN_UTF8); this.handler.assertObject(0, ":blah\nevent:test\nretry:5000\nid:1\ndata:", TEXT_PLAIN_UTF8);
this.handler.assertObject(1, "foo"); this.handler.assertObject(1, "foo");
this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8); this.handler.assertObject(2, "\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
} }
@Test @Test
@ -115,14 +123,17 @@ public class SseEmitterTests {
this.handler.assertObject(2, "\ndata:", TEXT_PLAIN_UTF8); this.handler.assertObject(2, "\ndata:", TEXT_PLAIN_UTF8);
this.handler.assertObject(3, "bar"); this.handler.assertObject(3, "bar");
this.handler.assertObject(4, "\nevent:test\nretry:5000\nid:1\n\n", TEXT_PLAIN_UTF8); this.handler.assertObject(4, "\nevent:test\nretry:5000\nid:1\n\n", TEXT_PLAIN_UTF8);
this.handler.assertWriteCount(1);
} }
private static class TestHandler implements ResponseBodyEmitter.Handler { private static class TestHandler implements ResponseBodyEmitter.Handler {
private List<Object> objects = new ArrayList<>(); private final List<Object> objects = new ArrayList<>();
private List<MediaType> mediaTypes = new ArrayList<>(); private final List<MediaType> mediaTypes = new ArrayList<>();
private int writeCount;
public void assertSentObjectCount(int size) { public void assertSentObjectCount(int size) {
@ -139,10 +150,24 @@ public class SseEmitterTests {
assertThat(this.mediaTypes.get(index)).isEqualTo(mediaType); assertThat(this.mediaTypes.get(index)).isEqualTo(mediaType);
} }
public void assertWriteCount(int writeCount) {
assertThat(this.writeCount).isEqualTo(writeCount);
}
@Override @Override
public void send(Object data, MediaType mediaType) throws IOException { public void send(Object data, @Nullable MediaType mediaType) throws IOException {
this.objects.add(data); this.objects.add(data);
this.mediaTypes.add(mediaType); this.mediaTypes.add(mediaType);
this.writeCount++;
}
@Override
public void send(Set<ResponseBodyEmitter.DataWithMediaType> items) throws IOException {
for (ResponseBodyEmitter.DataWithMediaType item : items) {
this.objects.add(item.getData());
this.mediaTypes.add(item.getMediaType());
}
this.writeCount++;
} }
@Override @Override