diff --git a/spring-web/src/main/java/org/springframework/http/codec/KotlinSerializationStringEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/KotlinSerializationStringEncoder.java index 82007c3ee7..f9ade37bce 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/KotlinSerializationStringEncoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/KotlinSerializationStringEncoder.java @@ -16,13 +16,13 @@ package org.springframework.http.codec; +import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import kotlin.text.Charsets; import kotlinx.serialization.KSerializer; import kotlinx.serialization.StringFormat; import org.reactivestreams.Publisher; @@ -52,6 +52,11 @@ import org.springframework.util.MimeType; public abstract class KotlinSerializationStringEncoder extends KotlinSerializationSupport implements Encoder { + private static final byte[] NEWLINE_SEPARATOR = {'\n'}; + + protected static final byte[] EMPTY_BYTES = new byte[0]; + + // CharSequence encoding needed for now, see https://github.com/Kotlin/kotlinx.serialization/issues/204 for more details private final CharSequenceEncoder charSequenceEncoder = CharSequenceEncoder.allMimeTypes(); private final Set streamingMediaTypes = new HashSet<>(); @@ -85,22 +90,40 @@ public abstract class KotlinSerializationStringEncoder e return supportedMimeTypes(); } - @Override public Flux encode(Publisher inputStream, DataBufferFactory bufferFactory, - ResolvableType elementType, - @Nullable MimeType mimeType, @Nullable Map hints) { - if (inputStream instanceof Mono) { - return Mono.from(inputStream) + ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { + + if (inputStream instanceof Mono mono) { + return mono .map(value -> encodeValue(value, bufferFactory, elementType, mimeType, hints)) .flux(); } - if (mimeType != null && this.streamingMediaTypes.contains(mimeType)) { return Flux.from(inputStream) - .map(list -> encodeValue(list, bufferFactory, elementType, mimeType, hints) - .write("\n", Charsets.UTF_8)); + .map(value -> encodeStreamingValue(value, bufferFactory, elementType, mimeType, hints, EMPTY_BYTES, + NEWLINE_SEPARATOR)); } + return encodeNonStream(inputStream, bufferFactory, elementType, mimeType, hints); + } + + protected DataBuffer encodeStreamingValue(Object value, DataBufferFactory bufferFactory, + ResolvableType valueType, @Nullable MimeType mimeType, + @Nullable Map hints, byte[] prefix, byte[] suffix) { + + List buffers = new ArrayList<>(3); + if (prefix.length > 0) { + buffers.add(bufferFactory.allocateBuffer(prefix.length).write(prefix)); + } + buffers.add(encodeValue(value, bufferFactory, valueType, mimeType, hints)); + if (suffix.length > 0) { + buffers.add(bufferFactory.allocateBuffer(suffix.length).write(suffix)); + } + return bufferFactory.join(buffers); + } + + protected Flux encodeNonStream(Publisher inputStream, DataBufferFactory bufferFactory, + ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { ResolvableType listType = ResolvableType.forClassWithGenerics(List.class, elementType); return Flux.from(inputStream) @@ -109,7 +132,6 @@ public abstract class KotlinSerializationStringEncoder e .flux(); } - @Override public DataBuffer encodeValue(Object value, DataBufferFactory bufferFactory, ResolvableType valueType, @Nullable MimeType mimeType, diff --git a/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonEncoder.java b/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonEncoder.java index 897fbaccfd..92cd20e3b8 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonEncoder.java +++ b/spring-web/src/main/java/org/springframework/http/codec/json/KotlinSerializationJsonEncoder.java @@ -17,11 +17,20 @@ package org.springframework.http.codec.json; import java.util.List; +import java.util.Map; import kotlinx.serialization.json.Json; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.http.MediaType; import org.springframework.http.codec.KotlinSerializationStringEncoder; +import org.springframework.lang.Nullable; +import org.springframework.util.MimeType; /** * Encode from an {@code Object} stream to a byte stream of JSON objects using @@ -49,4 +58,38 @@ public class KotlinSerializationJsonEncoder extends KotlinSerializationStringEnc setStreamingMediaTypes(List.of(MediaType.APPLICATION_NDJSON)); } + @Override + public Flux encodeNonStream(Publisher inputStream, DataBufferFactory bufferFactory, + ResolvableType elementType, @Nullable MimeType mimeType, @Nullable Map hints) { + + JsonArrayJoinHelper helper = new JsonArrayJoinHelper(); + return Flux.from(inputStream) + .map(value -> encodeStreamingValue(value, bufferFactory, elementType, mimeType, hints, + helper.getPrefix(), EMPTY_BYTES)) + .switchIfEmpty(Mono.fromCallable(() -> bufferFactory.wrap(helper.getPrefix()))) + .concatWith(Mono.fromCallable(() -> bufferFactory.wrap(helper.getSuffix()))); + } + + + private static class JsonArrayJoinHelper { + + private static final byte[] COMMA_SEPARATOR = {','}; + + private static final byte[] OPEN_BRACKET = {'['}; + + private static final byte[] CLOSE_BRACKET = {']'}; + + private boolean firstItemEmitted; + + public byte[] getPrefix() { + byte[] prefix = (this.firstItemEmitted ? COMMA_SEPARATOR : OPEN_BRACKET); + this.firstItemEmitted = true; + return prefix; + } + + public byte[] getSuffix() { + return CLOSE_BRACKET; + } + } + } diff --git a/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonEncoderTests.kt b/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonEncoderTests.kt index 1a366f91df..86c89a0d16 100644 --- a/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonEncoderTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonEncoderTests.kt @@ -22,14 +22,12 @@ import org.junit.jupiter.api.Test import org.springframework.core.MethodParameter import org.springframework.core.Ordered import org.springframework.core.ResolvableType -import org.springframework.core.io.buffer.DataBuffer -import org.springframework.core.io.buffer.DataBufferUtils import org.springframework.core.testfixture.codec.AbstractEncoderTests import org.springframework.http.MediaType import org.springframework.http.codec.ServerSentEvent import reactor.core.publisher.Flux import reactor.core.publisher.Mono -import reactor.test.StepVerifier.FirstStep +import reactor.test.StepVerifier import java.math.BigDecimal import java.nio.charset.StandardCharsets import kotlin.reflect.jvm.javaMethod @@ -72,15 +70,32 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests DataBufferUtils.release(dataBuffer) }) + it.consumeNextWith(expectString("[{\"foo\":\"foo\",\"bar\":\"bar\"}")) + .consumeNextWith(expectString(",{\"foo\":\"foofoo\",\"bar\":\"barbar\"}")) + .consumeNextWith(expectString(",{\"foo\":\"foofoofoo\",\"bar\":\"barbarbar\"}")) + .consumeNextWith(expectString("]")) .verifyComplete() } } + @Test + fun encodeEmpty() { + testEncode(Flux.empty(), Pojo::class.java) { + it + .consumeNextWith(expectString("[")) + .consumeNextWith(expectString("]")) + .verifyComplete() + } + } + + @Test + fun encodeWithErrorAsFirstSignal() { + val message = "I'm a teapot" + val input = Flux.error(IllegalStateException(message)) + val output = encoder.encode(input, this.bufferFactory, ResolvableType.forClass(Pojo::class.java), null, null) + StepVerifier.create(output).expectErrorMessage(message).verify() + } + @Test fun encodeStream() { val input = Flux.just( @@ -105,9 +120,8 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests DataBufferUtils.release(dataBuffer) }) - .verifyComplete() + it.consumeNextWith(expectString("{\"foo\":\"foo\",\"bar\":\"bar\"}")) + .verifyComplete() } } @@ -116,8 +130,7 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests DataBufferUtils.release(dataBuffer) }) + it.consumeNextWith(expectString("{\"value\":null}")) .verifyComplete() } }