Avoid collecting Flux elements in KotlinSerializationJsonEncoder

Closes gh-33428
This commit is contained in:
Sébastien Deleuze 2024-09-05 14:47:49 +02:00
parent 907859f2f3
commit 5e3c5d466f
3 changed files with 101 additions and 23 deletions

View File

@ -16,13 +16,13 @@
package org.springframework.http.codec;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import kotlin.text.Charsets;
import kotlinx.serialization.KSerializer;
import kotlinx.serialization.StringFormat;
import org.reactivestreams.Publisher;
@ -52,6 +52,11 @@ import org.springframework.util.MimeType;
public abstract class KotlinSerializationStringEncoder<T extends StringFormat> extends KotlinSerializationSupport<T>
implements Encoder<Object> {
private static final byte[] NEWLINE_SEPARATOR = {'\n'};
protected static final byte[] EMPTY_BYTES = new byte[0];
// CharSequence encoding needed for now, see https://github.com/Kotlin/kotlinx.serialization/issues/204 for more details
private final CharSequenceEncoder charSequenceEncoder = CharSequenceEncoder.allMimeTypes();
private final Set<MimeType> streamingMediaTypes = new HashSet<>();
@ -85,22 +90,40 @@ public abstract class KotlinSerializationStringEncoder<T extends StringFormat> e
return supportedMimeTypes();
}
@Override
public Flux<DataBuffer> encode(Publisher<?> inputStream, DataBufferFactory bufferFactory,
ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
if (inputStream instanceof Mono) {
return Mono.from(inputStream)
ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
if (inputStream instanceof Mono<?> mono) {
return mono
.map(value -> encodeValue(value, bufferFactory, elementType, mimeType, hints))
.flux();
}
if (mimeType != null && this.streamingMediaTypes.contains(mimeType)) {
return Flux.from(inputStream)
.map(list -> encodeValue(list, bufferFactory, elementType, mimeType, hints)
.write("\n", Charsets.UTF_8));
.map(value -> encodeStreamingValue(value, bufferFactory, elementType, mimeType, hints, EMPTY_BYTES,
NEWLINE_SEPARATOR));
}
return encodeNonStream(inputStream, bufferFactory, elementType, mimeType, hints);
}
protected DataBuffer encodeStreamingValue(Object value, DataBufferFactory bufferFactory,
ResolvableType valueType, @Nullable MimeType mimeType,
@Nullable Map<String, Object> hints, byte[] prefix, byte[] suffix) {
List<DataBuffer> buffers = new ArrayList<>(3);
if (prefix.length > 0) {
buffers.add(bufferFactory.allocateBuffer(prefix.length).write(prefix));
}
buffers.add(encodeValue(value, bufferFactory, valueType, mimeType, hints));
if (suffix.length > 0) {
buffers.add(bufferFactory.allocateBuffer(suffix.length).write(suffix));
}
return bufferFactory.join(buffers);
}
protected Flux<DataBuffer> encodeNonStream(Publisher<?> inputStream, DataBufferFactory bufferFactory,
ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
ResolvableType listType = ResolvableType.forClassWithGenerics(List.class, elementType);
return Flux.from(inputStream)
@ -109,7 +132,6 @@ public abstract class KotlinSerializationStringEncoder<T extends StringFormat> e
.flux();
}
@Override
public DataBuffer encodeValue(Object value, DataBufferFactory bufferFactory,
ResolvableType valueType, @Nullable MimeType mimeType,

View File

@ -17,11 +17,20 @@
package org.springframework.http.codec.json;
import java.util.List;
import java.util.Map;
import kotlinx.serialization.json.Json;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.MediaType;
import org.springframework.http.codec.KotlinSerializationStringEncoder;
import org.springframework.lang.Nullable;
import org.springframework.util.MimeType;
/**
* Encode from an {@code Object} stream to a byte stream of JSON objects using
@ -49,4 +58,38 @@ public class KotlinSerializationJsonEncoder extends KotlinSerializationStringEnc
setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON));
}
@Override
public Flux<DataBuffer> encodeNonStream(Publisher<?> inputStream, DataBufferFactory bufferFactory,
ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
JsonArrayJoinHelper helper = new JsonArrayJoinHelper();
return Flux.from(inputStream)
.map(value -> encodeStreamingValue(value, bufferFactory, elementType, mimeType, hints,
helper.getPrefix(), EMPTY_BYTES))
.switchIfEmpty(Mono.fromCallable(() -> bufferFactory.wrap(helper.getPrefix())))
.concatWith(Mono.fromCallable(() -> bufferFactory.wrap(helper.getSuffix())));
}
private static class JsonArrayJoinHelper {
private static final byte[] COMMA_SEPARATOR = {','};
private static final byte[] OPEN_BRACKET = {'['};
private static final byte[] CLOSE_BRACKET = {']'};
private boolean firstItemEmitted;
public byte[] getPrefix() {
byte[] prefix = (this.firstItemEmitted ? COMMA_SEPARATOR : OPEN_BRACKET);
this.firstItemEmitted = true;
return prefix;
}
public byte[] getSuffix() {
return CLOSE_BRACKET;
}
}
}

View File

@ -22,14 +22,12 @@ import org.junit.jupiter.api.Test
import org.springframework.core.MethodParameter
import org.springframework.core.Ordered
import org.springframework.core.ResolvableType
import org.springframework.core.io.buffer.DataBuffer
import org.springframework.core.io.buffer.DataBufferUtils
import org.springframework.core.testfixture.codec.AbstractEncoderTests
import org.springframework.http.MediaType
import org.springframework.http.codec.ServerSentEvent
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
import reactor.test.StepVerifier.FirstStep
import reactor.test.StepVerifier
import java.math.BigDecimal
import java.nio.charset.StandardCharsets
import kotlin.reflect.jvm.javaMethod
@ -72,15 +70,32 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests<KotlinSerializa
Pojo("foofoofoo", "barbarbar")
)
testEncode(input, Pojo::class.java) {
it.consumeNextWith(expectString("[" +
"{\"foo\":\"foo\",\"bar\":\"bar\"}," +
"{\"foo\":\"foofoo\",\"bar\":\"barbar\"}," +
"{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}]")
.andThen { dataBuffer -> DataBufferUtils.release(dataBuffer) })
it.consumeNextWith(expectString("[{\"foo\":\"foo\",\"bar\":\"bar\"}"))
.consumeNextWith(expectString(",{\"foo\":\"foofoo\",\"bar\":\"barbar\"}"))
.consumeNextWith(expectString(",{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}"))
.consumeNextWith(expectString("]"))
.verifyComplete()
}
}
@Test
fun encodeEmpty() {
testEncode(Flux.empty(), Pojo::class.java) {
it
.consumeNextWith(expectString("["))
.consumeNextWith(expectString("]"))
.verifyComplete()
}
}
@Test
fun encodeWithErrorAsFirstSignal() {
val message = "I'm a teapot"
val input = Flux.error<Any>(IllegalStateException(message))
val output = encoder.encode(input, this.bufferFactory, ResolvableType.forClass(Pojo::class.java), null, null)
StepVerifier.create(output).expectErrorMessage(message).verify()
}
@Test
fun encodeStream() {
val input = Flux.just(
@ -105,9 +120,8 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests<KotlinSerializa
fun encodeMono() {
val input = Mono.just(Pojo("foo", "bar"))
testEncode(input, Pojo::class.java) {
it.consumeNextWith(expectString("{\"foo\":\"foo\",\"bar\":\"bar\"}")
.andThen { dataBuffer: DataBuffer? -> DataBufferUtils.release(dataBuffer) })
.verifyComplete()
it.consumeNextWith(expectString("{\"foo\":\"foo\",\"bar\":\"bar\"}"))
.verifyComplete()
}
}
@ -116,8 +130,7 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests<KotlinSerializa
val input = Mono.just(mapOf("value" to null))
val methodParameter = MethodParameter.forExecutable(::handleMapWithNullable::javaMethod.get()!!, -1)
testEncode(input, ResolvableType.forMethodParameter(methodParameter), null, null) {
it.consumeNextWith(expectString("{\"value\":null}")
.andThen { dataBuffer: DataBuffer? -> DataBufferUtils.release(dataBuffer) })
it.consumeNextWith(expectString("{\"value\":null}"))
.verifyComplete()
}
}