diff --git a/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java b/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java index 6d88bae497f..701d640d897 100644 --- a/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java +++ b/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java @@ -52,6 +52,7 @@ import org.springframework.web.service.annotation.HttpExchange; * by delegating to an {@link HttpClientAdapter} to perform actual requests. * * @author Rossen Stoyanchev + * @author Sebastien Deleuze * @since 6.0 */ final class HttpServiceMethod { @@ -311,14 +312,15 @@ final class HttpServiceMethod { MethodParameter returnParam = new MethodParameter(method, -1); Class returnType = returnParam.getParameterType(); - if (KotlinDetector.isSuspendingFunction(method)) { + boolean isSuspending = KotlinDetector.isSuspendingFunction(method); + if (isSuspending) { returnType = Mono.class; } ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType); MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional()); - Class actualType = actualParam.getNestedParameterType(); + Class actualType = isSuspending ? actualParam.getParameterType() : actualParam.getNestedParameterType(); Function> responseFunction; if (actualType.equals(void.class) || actualType.equals(Void.class)) { @@ -331,18 +333,18 @@ final class HttpServiceMethod { responseFunction = client::requestToHeaders; } else if (actualType.equals(ResponseEntity.class)) { - MethodParameter bodyParam = actualParam.nested(); + MethodParameter bodyParam = isSuspending ? actualParam : actualParam.nested(); Class bodyType = bodyParam.getNestedParameterType(); if (bodyType.equals(Void.class)) { responseFunction = client::requestToBodilessEntity; } else { ReactiveAdapter bodyAdapter = reactiveRegistry.getAdapter(bodyType); - responseFunction = initResponseEntityFunction(client, bodyParam, bodyAdapter); + responseFunction = initResponseEntityFunction(client, bodyParam, bodyAdapter, isSuspending); } } else { - responseFunction = initBodyFunction(client, actualParam, reactiveAdapter); + responseFunction = initBodyFunction(client, actualParam, reactiveAdapter, isSuspending); } boolean blockForOptional = returnType.equals(Optional.class); @@ -350,8 +352,8 @@ final class HttpServiceMethod { } @SuppressWarnings("ConstantConditions") - private static Function> initResponseEntityFunction( - HttpClientAdapter client, MethodParameter methodParam, @Nullable ReactiveAdapter reactiveAdapter) { + private static Function> initResponseEntityFunction(HttpClientAdapter client, + MethodParameter methodParam, @Nullable ReactiveAdapter reactiveAdapter, boolean isSuspending) { if (reactiveAdapter == null) { return request -> client.requestToEntity( @@ -362,7 +364,8 @@ final class HttpServiceMethod { "ResponseEntity body must be a concrete value or a multi-value Publisher"); ParameterizedTypeReference bodyType = - ParameterizedTypeReference.forType(methodParam.nested().getNestedGenericParameterType()); + ParameterizedTypeReference.forType(isSuspending ? methodParam.nested().getGenericParameterType() : + methodParam.nested().getNestedGenericParameterType()); // Shortcut for Flux if (reactiveAdapter.getReactiveType().equals(Flux.class)) { @@ -376,11 +379,12 @@ final class HttpServiceMethod { }); } - private static Function> initBodyFunction( - HttpClientAdapter client, MethodParameter methodParam, @Nullable ReactiveAdapter reactiveAdapter) { + private static Function> initBodyFunction(HttpClientAdapter client, + MethodParameter methodParam, @Nullable ReactiveAdapter reactiveAdapter, boolean isSuspending) { ParameterizedTypeReference bodyType = - ParameterizedTypeReference.forType(methodParam.getNestedGenericParameterType()); + ParameterizedTypeReference.forType(isSuspending ? methodParam.getGenericParameterType() : + methodParam.getNestedGenericParameterType()); return (reactiveAdapter != null && reactiveAdapter.isMultiValue() ? request -> client.requestToBodyFlux(request, bodyType) : diff --git a/spring-web/src/test/kotlin/org/springframework/web/service/invoker/KotlinHttpServiceMethodTests.kt b/spring-web/src/test/kotlin/org/springframework/web/service/invoker/KotlinHttpServiceMethodTests.kt new file mode 100644 index 00000000000..d5a98da00dc --- /dev/null +++ b/spring-web/src/test/kotlin/org/springframework/web/service/invoker/KotlinHttpServiceMethodTests.kt @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.service.invoker + +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.springframework.core.ParameterizedTypeReference +import org.springframework.http.HttpStatus +import org.springframework.http.ResponseEntity +import org.springframework.lang.Nullable +import org.springframework.web.service.annotation.GetExchange + +/** + * Kotlin tests for [HttpServiceMethod]. + * + * @author Sebastien Deleuze + */ +class KotlinHttpServiceMethodTests { + + private val client = TestHttpClientAdapter() + private val proxyFactory = HttpServiceProxyFactory.builder(client).build() + + @Test + fun coroutinesService(): Unit = runBlocking { + val service = proxyFactory.createClient(CoroutinesService::class.java) + + val stringBody = service.stringBody() + assertThat(stringBody).isEqualTo("requestToBody") + verifyClientInvocation("requestToBody", object : ParameterizedTypeReference() {}) + + service.listBody() + verifyClientInvocation("requestToBody", object : ParameterizedTypeReference>() {}) + + val flowBody = service.flowBody() + assertThat(flowBody.toList()).containsExactly("request", "To", "Body", "Flux") + verifyClientInvocation("requestToBodyFlux", object : ParameterizedTypeReference() {}) + + val stringEntity = service.stringEntity() + assertThat(stringEntity).isEqualTo(ResponseEntity.ok("requestToEntity")) + verifyClientInvocation("requestToEntity", object : ParameterizedTypeReference() {}) + + service.listEntity() + verifyClientInvocation("requestToEntity", object : ParameterizedTypeReference>() {}) + + val flowEntity = service.flowEntity() + assertThat(flowEntity.statusCode).isEqualTo(HttpStatus.OK) + assertThat(flowEntity.body!!.toList()).containsExactly("request", "To", "Entity", "Flux") + verifyClientInvocation("requestToEntityFlux", object : ParameterizedTypeReference() {}) + } + + private fun verifyClientInvocation(methodName: String, @Nullable expectedBodyType: ParameterizedTypeReference<*>) { + assertThat(client.invokedMethodName).isEqualTo(methodName) + assertThat(client.bodyType).isEqualTo(expectedBodyType) + } + + private interface CoroutinesService { + + @GetExchange + suspend fun stringBody(): String + + @GetExchange + suspend fun listBody(): MutableList + + @GetExchange + fun flowBody(): Flow + + @GetExchange + suspend fun stringEntity(): ResponseEntity + + @GetExchange + suspend fun listEntity(): ResponseEntity> + + @GetExchange + fun flowEntity(): ResponseEntity> + } + +}