WebFlux support for SSE with multiline fragments

See gh-33194
This commit is contained in:
rstoyanchev 2024-08-06 11:06:21 +03:00
parent b734156f32
commit 8e2b27e5d8
9 changed files with 138 additions and 39 deletions

View File

@ -41,6 +41,7 @@ import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
@ -538,21 +539,25 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
@Override
public Flux<DataBuffer> format(
Flux<DataBuffer> fragmentContent, Fragment fragment, ServerWebExchange exchange) {
Flux<DataBuffer> fragmentFlux, Fragment fragment, ServerWebExchange exchange) {
Charset charset = StandardCharsets.UTF_8;
MediaType contentType = exchange.getResponse().getHeaders().getContentType();
if (contentType != null && contentType.getCharset() != null) {
charset = contentType.getCharset();
}
MediaType mediaType = exchange.getResponse().getHeaders().getContentType();
Charset charset = (mediaType != null && mediaType.getCharset() != null ?
mediaType.getCharset() : StandardCharsets.UTF_8);
DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
String eventLine = fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : "";
String eventLine = (fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : "");
DataBuffer prefix = encodeText(eventLine + "data:", charset, bufferFactory);
DataBuffer suffix = encodeText("\n\n", charset, bufferFactory);
return Flux.concat(Flux.just(prefix), fragmentContent, Flux.just(suffix));
Mono<DataBuffer> content = DataBufferUtils.join(fragmentFlux)
.map(dataBuffer -> {
String s = dataBuffer.toString(charset).replace("\n", "\ndata:");
return bufferFactory.wrap(s.getBytes(charset));
});
return Flux.concat(Flux.just(prefix), content, Flux.just(suffix));
}
private DataBuffer encodeText(String text, Charset charset, DataBufferFactory bufferFactory) {

View File

@ -43,6 +43,7 @@ import org.springframework.web.reactive.accept.RequestedContentTypeResolver;
import org.springframework.web.reactive.result.view.script.ScriptTemplateConfigurer;
import org.springframework.web.reactive.result.view.script.ScriptTemplateViewResolver;
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest;
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpResponse;
import org.springframework.web.testfixture.server.MockServerWebExchange;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@ -87,7 +88,14 @@ public class FragmentViewResolutionResultHandlerTests {
.then(Mono.defer(() -> exchange.getResponse().getBodyAsString()))
.block(Duration.ofSeconds(60));
assertThat(body).isEqualTo("<p>Hello Foo</p><p>Hello Bar</p>");
assertThat(exchange.getResponse().getHeaders().getContentType()).isEqualTo(MediaType.TEXT_HTML);
assertThat(body).isEqualTo("""
<p>
Hello Foo
</p>\
<p>
Hello Bar
</p>""");
}
@Test
@ -98,6 +106,7 @@ public class FragmentViewResolutionResultHandlerTests {
.build();
MockServerWebExchange exchange = MockServerWebExchange.from(request);
MockServerHttpResponse response = exchange.getResponse();
HandlerResult result = new HandlerResult(
new Handler(),
@ -106,15 +115,20 @@ public class FragmentViewResolutionResultHandlerTests {
new BindingContext());
String body = initHandler().handleResult(exchange, result)
.then(Mono.defer(() -> exchange.getResponse().getBodyAsString()))
.then(Mono.defer(response::getBodyAsString))
.block(Duration.ofSeconds(60));
assertThat(response.getHeaders().getContentType()).isEqualTo(MediaType.TEXT_EVENT_STREAM);
assertThat(body).isEqualTo("""
event:fragment1
data:<p>Hello Foo</p>
data:<p>
data: Hello Foo
data:</p>
event:fragment2
data:<p>Hello Bar</p>
data:<p>
data: Hello Bar
data:</p>
""");
}

View File

@ -1,3 +1,6 @@
import org.springframework.web.reactive.result.view.script.*
"""<p>${i18n("hello")} $foo</p>"""
"""
|<p>
| ${i18n("hello")} $foo
|</p>""".trimMargin()

View File

@ -1,3 +1,6 @@
import org.springframework.web.reactive.result.view.script.*
"""<p>${i18n("hello")} $bar</p>"""
"""
|<p>
| ${i18n("hello")} $bar
|</p>""".trimMargin()

View File

@ -30,6 +30,7 @@ import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.ModelAndView;
/**
* A specialization of {@link ResponseBodyEmitter} for sending
@ -203,6 +204,8 @@ public class SseEmitter extends ResponseBodyEmitter {
@Nullable
private StringBuilder sb;
private boolean hasName;
@Override
public SseEventBuilder id(String id) {
append("id:").append(id).append('\n');
@ -211,6 +214,7 @@ public class SseEmitter extends ResponseBodyEmitter {
@Override
public SseEventBuilder name(String name) {
this.hasName = true;
append("event:").append(name).append('\n');
return this;
}
@ -234,6 +238,9 @@ public class SseEmitter extends ResponseBodyEmitter {
@Override
public SseEventBuilder data(Object object, @Nullable MediaType mediaType) {
if (object instanceof ModelAndView mav && !this.hasName && mav.getViewName() != null) {
name(mav.getViewName());
}
append("data:");
saveAppendedText();
if (object instanceof String text) {

View File

@ -19,7 +19,9 @@ package org.springframework.web.servlet.mvc.method.annotation;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Bean;
@ -28,7 +30,8 @@ import org.springframework.context.support.ResourceBundleMessageSource;
import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.web.accept.ContentNegotiationManager;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.context.request.ServletWebRequest;
@ -51,8 +54,20 @@ import static org.springframework.web.testfixture.method.ResolvableMethod.on;
*/
public class FragmentRenderingStreamTests {
@Test
void streamFragments() throws Exception {
private final MockHttpServletRequest request = new MockHttpServletRequest();
private final MockHttpServletResponse response = new MockHttpServletResponse();
private final NativeWebRequest webRequest = new ServletWebRequest(request, response);
private ResponseBodyEmitterReturnValueHandler handler;
@BeforeEach
void setUp() {
AsyncWebRequest asyncWebRequest = new StandardServletAsyncWebRequest(this.request, this.response);
WebAsyncUtils.getAsyncManager(this.webRequest).setAsyncWebRequest(asyncWebRequest);
this.request.setAsyncSupported(true);
AnnotationConfigApplicationContext context =
new AnnotationConfigApplicationContext(ScriptTemplatingConfiguration.class);
@ -61,44 +76,84 @@ public class FragmentRenderingStreamTests {
ScriptTemplateViewResolver viewResolver = new ScriptTemplateViewResolver(prefix, ".kts");
viewResolver.setApplicationContext(context);
ResponseBodyEmitterReturnValueHandler handler = new ResponseBodyEmitterReturnValueHandler(
List.of(new MappingJackson2HttpMessageConverter()),
this.handler = new ResponseBodyEmitterReturnValueHandler(
List.of(new StringHttpMessageConverter()),
ReactiveAdapterRegistry.getSharedInstance(), new SyncTaskExecutor(),
new ContentNegotiationManager(),
List.of(viewResolver), null);
}
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
NativeWebRequest webRequest = new ServletWebRequest(request, response);
AsyncWebRequest asyncWebRequest = new StandardServletAsyncWebRequest(request, response);
WebAsyncUtils.getAsyncManager(webRequest).setAsyncWebRequest(asyncWebRequest);
request.setAsyncSupported(true);
@Test
void streamWithSseEmitter() throws Exception {
MethodParameter type = on(TestController.class).resolveReturnType(SseEmitter.class);
SseEmitter emitter = new SseEmitter();
handler.handleReturnValue(emitter, type, new ModelAndViewContainer(), webRequest);
assertThat(request.isAsyncStarted()).isTrue();
assertThat(response.getStatus()).isEqualTo(200);
SseEmitter emitter = new SseEmitter();
this.handler.handleReturnValue(emitter, type, new ModelAndViewContainer(), webRequest);
assertThat(this.request.isAsyncStarted()).isTrue();
assertThat(this.response.getStatus()).isEqualTo(200);
ModelAndView mav1 = new ModelAndView("fragment1", Map.of("foo", "Foo"));
ModelAndView mav2 = new ModelAndView("fragment2", Map.of("bar", "Bar"));
emitter.send(SseEmitter.event().data(mav1).data(mav2));
emitter.send(SseEmitter.event().data(mav1));
emitter.send(SseEmitter.event().data(mav2));
assertThat(response.getContentType()).isEqualTo("text/event-stream");
assertThat(response.getContentAsString()).isEqualTo(("""
data:<p>Hello Foo</p>
data:<p>Hello Bar</p>
assertThat(this.response.getContentType()).isEqualTo("text/event-stream");
assertThat(this.response.getContentAsString()).isEqualTo(("""
event:fragment1
data:<p>
data: Hello Foo
data:</p>
event:fragment2
data:<p>
data: Hello Bar
data:</p>
"""));
}
@Test
void streamWithFlux() throws Exception {
MethodParameter type = on(TestController.class).resolveReturnType(Flux.class, ModelAndView.class);
this.request.addHeader(HttpHeaders.ACCEPT, "text/event-stream");
Flux<ModelAndView> flux = Flux.just(
new ModelAndView("fragment1", Map.of("foo", "Foo")),
new ModelAndView("fragment2", Map.of("bar", "Bar")));
this.handler.handleReturnValue(flux, type, new ModelAndViewContainer(), webRequest);
assertThat(this.request.isAsyncStarted()).isTrue();
assertThat(this.response.getStatus()).isEqualTo(200);
assertThat(this.response.getContentType()).isEqualTo("text/event-stream");
assertThat(this.response.getContentAsString()).isEqualTo(("""
event:fragment1
data:<p>
data: Hello Foo
data:</p>
event:fragment2
data:<p>
data: Hello Bar
data:</p>
"""));
}
@SuppressWarnings({"unused", "DataFlowIssue"})
private static class TestController {
SseEmitter handle() {
SseEmitter handleWithSseEmitter() {
return null;
}
Flux<ModelAndView> handleWithFlux() {
return null;
}
}

View File

@ -64,7 +64,13 @@ public class DefaultFragmentsRenderingTests {
view.resolveNestedViews(viewResolver, Locale.ENGLISH);
view.render(Collections.emptyMap(), request, response);
assertThat(response.getContentAsString()).isEqualTo("<p>Hello Foo</p><p>Hello Bar</p>");
assertThat(response.getContentAsString()).isEqualTo("""
<p>
Hello Foo
</p>\
<p>
Hello Bar
</p>""");
}

View File

@ -1,3 +1,6 @@
import org.springframework.web.servlet.view.script.*
"""<p>${i18n("hello")} $foo</p>"""
"""
|<p>
| ${i18n("hello")} $foo
|</p>""".trimMargin()

View File

@ -1,3 +1,6 @@
import org.springframework.web.servlet.view.script.*
"""<p>${i18n("hello")} $bar</p>"""
"""
|<p>
| ${i18n("hello")} $bar
|</p>""".trimMargin()