WebFlux support for SSE Fragment stream

See gh-33194
This commit is contained in:
rstoyanchev 2024-07-24 15:33:24 +01:00
parent aa6b47bfce
commit 6e55e78b22
2 changed files with 118 additions and 1 deletions

View File

@ -16,6 +16,8 @@
package org.springframework.web.reactive.result.view;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
@ -38,9 +40,11 @@ import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.Nullable;
@ -96,6 +100,8 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
private final List<View> defaultViews = new ArrayList<>(4);
private final List<FragmentFormatter> fragmentFormatters = List.of(new SseFragmentFormatter());
/**
* Basic constructor with a default {@link ReactiveAdapterRegistry}.
@ -337,8 +343,22 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
Mono.just(List.of(fragment.view())) :
resolveViews(fragment.viewName() != null ? fragment.viewName() : getDefaultViewName(exchange), locale));
FragmentFormatter fragmentFormatter = getFragmentFormatter(exchange);
return selectedViews.flatMap(views -> render(views, fragment.model(), bindingContext, mutatedExchange))
.then(Mono.fromSupplier(response::getBodyFlux));
.then(Mono.fromSupplier(() -> (fragmentFormatter != null ?
fragmentFormatter.format(response.getBodyFlux(), fragment, exchange) :
response.getBodyFlux())));
}
@Nullable
private FragmentFormatter getFragmentFormatter(ServerWebExchange exchange) {
for (FragmentFormatter formatter : this.fragmentFormatters) {
if (formatter.supports(exchange.getRequest())) {
return formatter;
}
}
return null;
}
private String getNameForReturnValue(MethodParameter returnType) {
@ -436,4 +456,71 @@ public class ViewResolutionResultHandler extends HandlerResultHandlerSupport imp
}
}
/**
* Strategy to render fragment with stream formatting.
*/
private interface FragmentFormatter {
/**
* Whether the formatter supports the given request.
*/
boolean supports(ServerHttpRequest request);
/**
* Format the given fragment.
* @param fragmentBuffers the fragment serialized to data buffers
* @param fragment the fragment being rendered
* @param exchange the current exchange
* @return the formatted fragment
*/
Flux<DataBuffer> format(Flux<DataBuffer> fragmentBuffers, Fragment fragment, ServerWebExchange exchange);
}
/**
* Formatter for Server-Sent Events formatting.
*/
private static class SseFragmentFormatter implements FragmentFormatter {
@Override
public boolean supports(ServerHttpRequest request) {
String header = request.getHeaders().getFirst(HttpHeaders.ACCEPT);
return (header != null && header.contains(MediaType.TEXT_EVENT_STREAM_VALUE));
}
@Override
public Flux<DataBuffer> format(
Flux<DataBuffer> fragmentBuffers, Fragment fragment, ServerWebExchange exchange) {
Charset charset = getCharset(exchange.getRequest());
DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
String eventLine = fragment.viewName() != null ? "event:" + fragment.viewName() + "\n" : "";
return Flux.concat(
Flux.just(encodeText(eventLine + "data:", charset, bufferFactory)),
fragmentBuffers,
Flux.just(encodeText("\n\n", charset, bufferFactory)));
}
private Charset getCharset(ServerHttpRequest request) {
for (MediaType mediaType : request.getHeaders().getAccept()) {
if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) {
if (mediaType.getCharset() != null) {
return mediaType.getCharset();
}
break;
}
}
return StandardCharsets.UTF_8;
}
private DataBuffer encodeText(String text, Charset charset, DataBufferFactory bufferFactory) {
byte[] bytes = text.getBytes(charset);
return bufferFactory.wrap(bytes);
}
}
}

View File

@ -90,6 +90,35 @@ public class FragmentViewResolutionResultHandlerTests {
assertThat(body).isEqualTo("<p>Hello Foo</p><p>Hello Bar</p>");
}
@Test
void renderSse() {
MockServerHttpRequest request = MockServerHttpRequest.get("/")
.accept(MediaType.TEXT_EVENT_STREAM)
.acceptLanguageAsLocales(Locale.ENGLISH)
.build();
MockServerWebExchange exchange = MockServerWebExchange.from(request);
HandlerResult result = new HandlerResult(
new Handler(),
Flux.just(fragment1, fragment2).subscribeOn(Schedulers.boundedElastic()),
on(Handler.class).resolveReturnType(Flux.class, Fragment.class),
new BindingContext());
String body = initHandler().handleResult(exchange, result)
.then(Mono.defer(() -> exchange.getResponse().getBodyAsString()))
.block(Duration.ofSeconds(60));
assertThat(body).isEqualTo("""
event:fragment1
data:<p>Hello Foo</p>
event:fragment2
data:<p>Hello Bar</p>
""");
}
private ViewResolutionResultHandler initHandler() {
AnnotationConfigApplicationContext context =
@ -98,6 +127,7 @@ public class FragmentViewResolutionResultHandlerTests {
String prefix = "org/springframework/web/reactive/result/view/script/kotlin/";
ScriptTemplateViewResolver viewResolver = new ScriptTemplateViewResolver(prefix, ".kts");
viewResolver.setApplicationContext(context);
viewResolver.setSupportedMediaTypes(List.of(MediaType.TEXT_HTML, MediaType.TEXT_EVENT_STREAM));
RequestedContentTypeResolver contentTypeResolver = new HeaderContentTypeResolver();
return new ViewResolutionResultHandler(List.of(viewResolver), contentTypeResolver);