diff --git a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java index 66917a8e75..7c6e8cab0a 100644 --- a/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java +++ b/spring-core/src/main/java/org/springframework/core/CoroutinesUtils.java @@ -21,6 +21,7 @@ import java.lang.reflect.Method; import java.util.Objects; import kotlin.Unit; +import kotlin.coroutines.Continuation; import kotlin.coroutines.CoroutineContext; import kotlin.jvm.JvmClassMappingKt; import kotlin.reflect.KClass; @@ -67,6 +68,10 @@ public abstract class CoroutinesUtils { (scope, continuation) -> MonoKt.awaitSingleOrNull(source, continuation)); } + public static Object awaitSingleOrNull(Mono source, Continuation continuation) { + return MonoKt.awaitSingleOrNull(source, continuation); + } + /** * Invoke a suspending function and converts it to {@link Mono} or * {@link Flux}. Uses an {@linkplain Dispatchers#getUnconfined() unconfined} diff --git a/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java b/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java index b0095c0028..ea54fe1423 100644 --- a/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java +++ b/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceMethod.java @@ -28,6 +28,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.core.DefaultParameterNameDiscoverer; +import org.springframework.core.KotlinDetector; import org.springframework.core.MethodParameter; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ReactiveAdapter; @@ -83,6 +84,11 @@ final class HttpServiceMethod { if (count == 0) { return new MethodParameter[0]; } + + if (KotlinDetector.isSuspendingFunction(method)) { + count -= 1; + } + DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer(); MethodParameter[] parameters = new MethodParameter[count]; for (int i = 0; i < count; i++) { @@ -306,6 +312,10 @@ final class HttpServiceMethod { MethodParameter returnParam = new MethodParameter(method, -1); Class returnType = returnParam.getParameterType(); + if (KotlinDetector.isSuspendingFunction(method)) { + returnType = Mono.class; + } + ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType); MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional()); diff --git a/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceProxyFactory.java b/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceProxyFactory.java index af509524ba..6b1b4f2268 100644 --- a/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceProxyFactory.java +++ b/spring-web/src/main/java/org/springframework/web/service/invoker/HttpServiceProxyFactory.java @@ -25,11 +25,15 @@ import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; +import kotlin.coroutines.Continuation; import org.aopalliance.intercept.MethodInterceptor; import org.aopalliance.intercept.MethodInvocation; +import reactor.core.publisher.Mono; import org.springframework.aop.framework.ProxyFactory; import org.springframework.aop.framework.ReflectiveMethodInvocation; +import org.springframework.core.CoroutinesUtils; +import org.springframework.core.KotlinDetector; import org.springframework.core.MethodIntrospector; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.annotation.AnnotatedElementUtils; @@ -264,10 +268,18 @@ public final class HttpServiceProxyFactory { } @Override + @SuppressWarnings({"unchecked"}) public Object invoke(MethodInvocation invocation) throws Throwable { Method method = invocation.getMethod(); HttpServiceMethod httpServiceMethod = this.httpServiceMethods.get(method); if (httpServiceMethod != null) { + if (KotlinDetector.isSuspendingFunction(method)) { + Object[] arguments = getSuspendedFunctionArgs(invocation.getArguments()); + Continuation continuation = resolveContinuationArgument(invocation.getArguments()); + Mono wrapped = (Mono) httpServiceMethod.invoke(arguments); + return CoroutinesUtils.awaitSingleOrNull(wrapped, continuation); + } + return httpServiceMethod.invoke(invocation.getArguments()); } if (method.isDefault()) { @@ -278,6 +290,17 @@ public final class HttpServiceProxyFactory { } throw new IllegalStateException("Unexpected method invocation: " + method); } + + @SuppressWarnings({"unchecked"}) + private static Continuation resolveContinuationArgument(Object[] args) { + return (Continuation) args[args.length - 1]; + } + + private static Object[] getSuspendedFunctionArgs(Object[] args) { + Object[] functionArgs = new Object[args.length - 1]; + System.arraycopy(args, 0, functionArgs, 0, args.length - 1); + return functionArgs; + } } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/support/WebClientHttpServiceProxyKotlinTests.kt b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/support/WebClientHttpServiceProxyKotlinTests.kt new file mode 100644 index 0000000000..6706d4b16b --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/support/WebClientHttpServiceProxyKotlinTests.kt @@ -0,0 +1,158 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.web.reactive.function.client.support + +import kotlinx.coroutines.reactor.mono +import kotlinx.coroutines.runBlocking +import okhttp3.mockwebserver.MockResponse +import okhttp3.mockwebserver.MockWebServer +import org.assertj.core.api.Assertions +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.springframework.web.bind.annotation.RequestAttribute +import org.springframework.web.reactive.function.client.ClientRequest +import org.springframework.web.reactive.function.client.ExchangeFunction +import org.springframework.web.reactive.function.client.WebClient +import org.springframework.web.service.annotation.GetExchange +import org.springframework.web.service.invoker.HttpServiceProxyFactory +import reactor.core.publisher.Mono +import reactor.test.StepVerifier +import java.time.Duration +import java.util.function.Consumer + +/** + * Integration tests for [HTTP Service proxy][HttpServiceProxyFactory] + * using [WebClient] and [MockWebServer]. + * + * @author DongHyeon Kim (wplong11) + */ +class WebClientHttpServiceProxyKotlinTests { + private var server: MockWebServer? = null + @BeforeEach + fun setUp() { + server = MockWebServer() + } + + @AfterEach + fun shutdown() { + server?.shutdown() + } + + @Test + fun greeting() { + prepareResponse { response: MockResponse -> + response.setHeader( + "Content-Type", + "text/plain" + ).setBody("Hello Spring!") + } + StepVerifier.create(mono { initHttpService().getGreeting() }) + .expectNext("Hello Spring!") + .expectComplete() + .verify(Duration.ofSeconds(5)) + } + + @Test + fun greetingMono() { + prepareResponse { response: MockResponse -> + response.setHeader( + "Content-Type", + "text/plain" + ).setBody("Hello Spring!") + } + StepVerifier.create(initHttpService().getGreetingMono()) + .expectNext("Hello Spring!") + .expectComplete() + .verify(Duration.ofSeconds(5)) + } + + @Test + fun greetingBlocking() { + prepareResponse { response: MockResponse -> + response.setHeader( + "Content-Type", + "text/plain" + ).setBody("Hello Spring!") + } + StepVerifier.create(mono { initHttpService().getGreetingBlocking() }) + .expectNext("Hello Spring!") + .expectComplete() + .verify(Duration.ofSeconds(5)) + } + + @Test + fun greetingWithRequestAttribute() { + val attributes: MutableMap = HashMap() + val webClient = WebClient.builder() + .baseUrl(server!!.url("/").toString()) + .filter { request: ClientRequest, next: ExchangeFunction -> + attributes.putAll(request.attributes()) + next.exchange(request) + } + .build() + prepareResponse { response: MockResponse -> + response.setHeader( + "Content-Type", + "text/plain" + ).setBody("Hello Spring!") + } + + val service = initHttpService(webClient) + val value = runBlocking { + service.getGreetingWithAttribute("myAttributeValue") + } + StepVerifier.create(mono { value }) + .expectNext("Hello Spring!") + .expectComplete() + .verify(Duration.ofSeconds(5)) + Assertions.assertThat(attributes).containsEntry("myAttribute", "myAttributeValue") + } + + private fun initHttpService(): TestHttpService { + val webClient = WebClient.builder().baseUrl( + server!!.url("/").toString() + ).build() + return initHttpService(webClient) + } + + private fun initHttpService(webClient: WebClient): TestHttpService { + return HttpServiceProxyFactory.builder() + .clientAdapter(WebClientAdapter.forClient(webClient)) + .build() + .createClient(TestHttpService::class.java) + } + + private fun prepareResponse(consumer: Consumer) { + val response = MockResponse() + consumer.accept(response) + server!!.enqueue(response) + } + + private interface TestHttpService { + @GetExchange("/greeting") + suspend fun getGreeting(): String + + @GetExchange("/greeting") + fun getGreetingMono(): Mono + + @GetExchange("/greeting") + fun getGreetingBlocking(): String + + @GetExchange("/greeting") + suspend fun getGreetingWithAttribute(@RequestAttribute myAttribute: String): String + } +}