diff --git a/spring-web/src/main/java/org/springframework/http/codec/KotlinSerializationSupport.java b/spring-web/src/main/java/org/springframework/http/codec/KotlinSerializationSupport.java index 792b2b5abc1..bded9bb17cb 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/KotlinSerializationSupport.java +++ b/spring-web/src/main/java/org/springframework/http/codec/KotlinSerializationSupport.java @@ -16,6 +16,7 @@ package org.springframework.http.codec; +import java.lang.reflect.Method; import java.lang.reflect.Type; import java.util.Arrays; import java.util.HashSet; @@ -23,14 +24,21 @@ import java.util.List; import java.util.Map; import java.util.Set; +import kotlin.reflect.KFunction; +import kotlin.reflect.KType; +import kotlin.reflect.full.KCallables; +import kotlin.reflect.jvm.ReflectJvmMapping; import kotlinx.serialization.KSerializer; import kotlinx.serialization.SerialFormat; import kotlinx.serialization.SerializersKt; import kotlinx.serialization.descriptors.PolymorphicKind; import kotlinx.serialization.descriptors.SerialDescriptor; +import org.springframework.core.KotlinDetector; +import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.ConcurrentReferenceHashMap; import org.springframework.util.MimeType; @@ -46,7 +54,10 @@ import org.springframework.util.MimeType; */ public abstract class KotlinSerializationSupport { - private final Map> serializerCache = new ConcurrentReferenceHashMap<>(); + private final Map> typeSerializerCache = new ConcurrentReferenceHashMap<>(); + + private final Map> kTypeSerializerCache = new ConcurrentReferenceHashMap<>(); + private final T format; @@ -117,8 +128,33 @@ public abstract class KotlinSerializationSupport { */ @Nullable protected final KSerializer serializer(ResolvableType resolvableType) { + if (resolvableType.getSource() instanceof MethodParameter parameter) { + Method method = parameter.getMethod(); + Assert.notNull(method, "Method must not be null"); + if (KotlinDetector.isKotlinType(method.getDeclaringClass())) { + KFunction function = ReflectJvmMapping.getKotlinFunction(method); + Assert.notNull(function, "Kotlin function must not be null"); + KType type = (parameter.getParameterIndex() == -1 ? function.getReturnType() : + KCallables.getValueParameters(function).get(parameter.getParameterIndex()).getType()); + KSerializer serializer = this.kTypeSerializerCache.get(type); + if (serializer == null) { + try { + serializer = SerializersKt.serializerOrNull(this.format.getSerializersModule(), type); + } + catch (IllegalArgumentException ignored) { + } + if (serializer != null) { + if (hasPolymorphism(serializer.getDescriptor(), new HashSet<>())) { + return null; + } + this.kTypeSerializerCache.put(type, serializer); + } + } + return serializer; + } + } Type type = resolvableType.getType(); - KSerializer serializer = this.serializerCache.get(type); + KSerializer serializer = this.typeSerializerCache.get(type); if (serializer == null) { try { serializer = SerializersKt.serializerOrNull(this.format.getSerializersModule(), type); @@ -129,7 +165,7 @@ public abstract class KotlinSerializationSupport { if (hasPolymorphism(serializer.getDescriptor(), new HashSet<>())) { return null; } - this.serializerCache.put(type, serializer); + this.typeSerializerCache.put(type, serializer); } } return serializer; diff --git a/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonDecoderTests.kt b/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonDecoderTests.kt index 6c5fe1b5c5b..b7178566c7d 100644 --- a/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonDecoderTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonDecoderTests.kt @@ -19,9 +19,11 @@ package org.springframework.http.codec.json import kotlinx.serialization.Serializable import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test +import org.springframework.core.MethodParameter import org.springframework.core.Ordered import org.springframework.core.ResolvableType import org.springframework.core.io.buffer.DataBuffer +import org.springframework.core.io.buffer.DataBufferUtils import org.springframework.core.testfixture.codec.AbstractDecoderTests import org.springframework.http.MediaType import reactor.core.publisher.Flux @@ -32,6 +34,7 @@ import java.lang.UnsupportedOperationException import java.math.BigDecimal import java.nio.charset.Charset import java.nio.charset.StandardCharsets +import kotlin.reflect.jvm.javaMethod /** * Tests for the JSON decoding using kotlinx.serialization. @@ -128,6 +131,22 @@ class KotlinSerializationJsonDecoderTests : AbstractDecoderTests { return stringBuffer(value, StandardCharsets.UTF_8) } @@ -145,4 +164,6 @@ class KotlinSerializationJsonDecoderTests : AbstractDecoderTests) = map + } diff --git a/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonEncoderTests.kt b/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonEncoderTests.kt index 8e00528f8c0..1a366f91df3 100644 --- a/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonEncoderTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonEncoderTests.kt @@ -19,6 +19,7 @@ package org.springframework.http.codec.json import kotlinx.serialization.Serializable import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test +import org.springframework.core.MethodParameter import org.springframework.core.Ordered import org.springframework.core.ResolvableType import org.springframework.core.io.buffer.DataBuffer @@ -31,6 +32,7 @@ import reactor.core.publisher.Mono import reactor.test.StepVerifier.FirstStep import java.math.BigDecimal import java.nio.charset.StandardCharsets +import kotlin.reflect.jvm.javaMethod /** * Tests for the JSON encoding using kotlinx.serialization. @@ -109,6 +111,17 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests DataBufferUtils.release(dataBuffer) }) + .verifyComplete() + } + } + @Test fun canNotEncode() { assertThat(encoder.canEncode(ResolvableType.forClass(String::class.java), null)).isFalse() @@ -123,4 +136,6 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests) = map + }