Leverage KType in Kotlin Serialization WebFlux support

In order to take in account properly Kotlin null-safety with the
annotation programming model.

Closes gh-33016
This commit is contained in:
Sébastien Deleuze 2024-07-01 14:55:25 +02:00
parent 23dccc5977
commit 98e89d8fba
3 changed files with 75 additions and 3 deletions

View File

@ -16,6 +16,7 @@
package org.springframework.http.codec;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.HashSet;
@ -23,14 +24,21 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import kotlin.reflect.KFunction;
import kotlin.reflect.KType;
import kotlin.reflect.full.KCallables;
import kotlin.reflect.jvm.ReflectJvmMapping;
import kotlinx.serialization.KSerializer;
import kotlinx.serialization.SerialFormat;
import kotlinx.serialization.SerializersKt;
import kotlinx.serialization.descriptors.PolymorphicKind;
import kotlinx.serialization.descriptors.SerialDescriptor;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ConcurrentReferenceHashMap;
import org.springframework.util.MimeType;
@ -46,7 +54,10 @@ import org.springframework.util.MimeType;
*/
public abstract class KotlinSerializationSupport<T extends SerialFormat> {
private final Map<Type, KSerializer<Object>> serializerCache = new ConcurrentReferenceHashMap<>();
private final Map<Type, KSerializer<Object>> typeSerializerCache = new ConcurrentReferenceHashMap<>();
private final Map<KType, KSerializer<Object>> kTypeSerializerCache = new ConcurrentReferenceHashMap<>();
private final T format;
@ -117,8 +128,33 @@ public abstract class KotlinSerializationSupport<T extends SerialFormat> {
*/
@Nullable
protected final KSerializer<Object> serializer(ResolvableType resolvableType) {
if (resolvableType.getSource() instanceof MethodParameter parameter) {
Method method = parameter.getMethod();
Assert.notNull(method, "Method must not be null");
if (KotlinDetector.isKotlinType(method.getDeclaringClass())) {
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
Assert.notNull(function, "Kotlin function must not be null");
KType type = (parameter.getParameterIndex() == -1 ? function.getReturnType() :
KCallables.getValueParameters(function).get(parameter.getParameterIndex()).getType());
KSerializer<Object> serializer = this.kTypeSerializerCache.get(type);
if (serializer == null) {
try {
serializer = SerializersKt.serializerOrNull(this.format.getSerializersModule(), type);
}
catch (IllegalArgumentException ignored) {
}
if (serializer != null) {
if (hasPolymorphism(serializer.getDescriptor(), new HashSet<>())) {
return null;
}
this.kTypeSerializerCache.put(type, serializer);
}
}
return serializer;
}
}
Type type = resolvableType.getType();
KSerializer<Object> serializer = this.serializerCache.get(type);
KSerializer<Object> serializer = this.typeSerializerCache.get(type);
if (serializer == null) {
try {
serializer = SerializersKt.serializerOrNull(this.format.getSerializersModule(), type);
@ -129,7 +165,7 @@ public abstract class KotlinSerializationSupport<T extends SerialFormat> {
if (hasPolymorphism(serializer.getDescriptor(), new HashSet<>())) {
return null;
}
this.serializerCache.put(type, serializer);
this.typeSerializerCache.put(type, serializer);
}
}
return serializer;

View File

@ -19,9 +19,11 @@ package org.springframework.http.codec.json
import kotlinx.serialization.Serializable
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.springframework.core.MethodParameter
import org.springframework.core.Ordered
import org.springframework.core.ResolvableType
import org.springframework.core.io.buffer.DataBuffer
import org.springframework.core.io.buffer.DataBufferUtils
import org.springframework.core.testfixture.codec.AbstractDecoderTests
import org.springframework.http.MediaType
import reactor.core.publisher.Flux
@ -32,6 +34,7 @@ import java.lang.UnsupportedOperationException
import java.math.BigDecimal
import java.nio.charset.Charset
import java.nio.charset.StandardCharsets
import kotlin.reflect.jvm.javaMethod
/**
* Tests for the JSON decoding using kotlinx.serialization.
@ -128,6 +131,22 @@ class KotlinSerializationJsonDecoderTests : AbstractDecoderTests<KotlinSerializa
}, null, null)
}
@Test
fun decodeToMonoWithNullableWithNull() {
val input = Flux.concat(
stringBuffer("{\"value\":null}\n"),
)
val methodParameter = MethodParameter.forExecutable(::handleMapWithNullable::javaMethod.get()!!, -1)
val elementType = ResolvableType.forMethodParameter(methodParameter)
testDecodeToMonoAll(input, elementType, {
it.expectNext(mapOf("value" to null))
.expectComplete()
.verify()
}, null, null)
}
private fun stringBuffer(value: String): Mono<DataBuffer> {
return stringBuffer(value, StandardCharsets.UTF_8)
}
@ -145,4 +164,6 @@ class KotlinSerializationJsonDecoderTests : AbstractDecoderTests<KotlinSerializa
@Serializable
data class Pojo(val foo: String, val bar: String, val pojo: Pojo? = null)
fun handleMapWithNullable(map: Map<String, String?>) = map
}

View File

@ -19,6 +19,7 @@ package org.springframework.http.codec.json
import kotlinx.serialization.Serializable
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.springframework.core.MethodParameter
import org.springframework.core.Ordered
import org.springframework.core.ResolvableType
import org.springframework.core.io.buffer.DataBuffer
@ -31,6 +32,7 @@ import reactor.core.publisher.Mono
import reactor.test.StepVerifier.FirstStep
import java.math.BigDecimal
import java.nio.charset.StandardCharsets
import kotlin.reflect.jvm.javaMethod
/**
* Tests for the JSON encoding using kotlinx.serialization.
@ -109,6 +111,17 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests<KotlinSerializa
}
}
@Test
fun encodeMonoWithNullableWithNull() {
val input = Mono.just(mapOf("value" to null))
val methodParameter = MethodParameter.forExecutable(::handleMapWithNullable::javaMethod.get()!!, -1)
testEncode(input, ResolvableType.forMethodParameter(methodParameter), null, null) {
it.consumeNextWith(expectString("{\"value\":null}")
.andThen { dataBuffer: DataBuffer? -> DataBufferUtils.release(dataBuffer) })
.verifyComplete()
}
}
@Test
fun canNotEncode() {
assertThat(encoder.canEncode(ResolvableType.forClass(String::class.java), null)).isFalse()
@ -123,4 +136,6 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests<KotlinSerializa
@Serializable
data class Pojo(val foo: String, val bar: String, val pojo: Pojo? = null)
fun handleMapWithNullable(map: Map<String, String?>) = map
}