Support Coroutines in HttpServiceProxyFactory

See gh-29527
This commit is contained in:
Donghyeon Kim 2022-11-19 22:52:18 +09:00 committed by Sébastien Deleuze
parent ce85fdc5c7
commit 1d4bf58e8d
4 changed files with 196 additions and 0 deletions

View File

@ -21,6 +21,7 @@ import java.lang.reflect.Method;
import java.util.Objects;
import kotlin.Unit;
import kotlin.coroutines.Continuation;
import kotlin.coroutines.CoroutineContext;
import kotlin.jvm.JvmClassMappingKt;
import kotlin.reflect.KClass;
@ -67,6 +68,10 @@ public abstract class CoroutinesUtils {
(scope, continuation) -> MonoKt.awaitSingleOrNull(source, continuation));
}
public static <T> Object awaitSingleOrNull(Mono<T> source, Continuation<T> continuation) {
return MonoKt.awaitSingleOrNull(source, continuation);
}
/**
* Invoke a suspending function and converts it to {@link Mono} or
* {@link Flux}. Uses an {@linkplain Dispatchers#getUnconfined() unconfined}

View File

@ -28,6 +28,7 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ReactiveAdapter;
@ -83,6 +84,11 @@ final class HttpServiceMethod {
if (count == 0) {
return new MethodParameter[0];
}
if (KotlinDetector.isSuspendingFunction(method)) {
count -= 1;
}
DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
MethodParameter[] parameters = new MethodParameter[count];
for (int i = 0; i < count; i++) {
@ -306,6 +312,10 @@ final class HttpServiceMethod {
MethodParameter returnParam = new MethodParameter(method, -1);
Class<?> returnType = returnParam.getParameterType();
if (KotlinDetector.isSuspendingFunction(method)) {
returnType = Mono.class;
}
ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);
MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());

View File

@ -25,11 +25,15 @@ import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import kotlin.coroutines.Continuation;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import reactor.core.publisher.Mono;
import org.springframework.aop.framework.ProxyFactory;
import org.springframework.aop.framework.ReflectiveMethodInvocation;
import org.springframework.core.CoroutinesUtils;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.annotation.AnnotatedElementUtils;
@ -264,10 +268,18 @@ public final class HttpServiceProxyFactory {
}
@Override
@SuppressWarnings({"unchecked"})
public Object invoke(MethodInvocation invocation) throws Throwable {
Method method = invocation.getMethod();
HttpServiceMethod httpServiceMethod = this.httpServiceMethods.get(method);
if (httpServiceMethod != null) {
if (KotlinDetector.isSuspendingFunction(method)) {
Object[] arguments = getSuspendedFunctionArgs(invocation.getArguments());
Continuation<Object> continuation = resolveContinuationArgument(invocation.getArguments());
Mono<Object> wrapped = (Mono<Object>) httpServiceMethod.invoke(arguments);
return CoroutinesUtils.awaitSingleOrNull(wrapped, continuation);
}
return httpServiceMethod.invoke(invocation.getArguments());
}
if (method.isDefault()) {
@ -278,6 +290,17 @@ public final class HttpServiceProxyFactory {
}
throw new IllegalStateException("Unexpected method invocation: " + method);
}
@SuppressWarnings({"unchecked"})
private static <T> Continuation<T> resolveContinuationArgument(Object[] args) {
return (Continuation<T>) args[args.length - 1];
}
private static Object[] getSuspendedFunctionArgs(Object[] args) {
Object[] functionArgs = new Object[args.length - 1];
System.arraycopy(args, 0, functionArgs, 0, args.length - 1);
return functionArgs;
}
}
}

View File

@ -0,0 +1,158 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.reactive.function.client.support
import kotlinx.coroutines.reactor.mono
import kotlinx.coroutines.runBlocking
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer
import org.assertj.core.api.Assertions
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.web.bind.annotation.RequestAttribute
import org.springframework.web.reactive.function.client.ClientRequest
import org.springframework.web.reactive.function.client.ExchangeFunction
import org.springframework.web.reactive.function.client.WebClient
import org.springframework.web.service.annotation.GetExchange
import org.springframework.web.service.invoker.HttpServiceProxyFactory
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
import java.time.Duration
import java.util.function.Consumer
/**
* Integration tests for [HTTP Service proxy][HttpServiceProxyFactory]
* using [WebClient] and [MockWebServer].
*
* @author DongHyeon Kim (wplong11)
*/
class WebClientHttpServiceProxyKotlinTests {
private var server: MockWebServer? = null
@BeforeEach
fun setUp() {
server = MockWebServer()
}
@AfterEach
fun shutdown() {
server?.shutdown()
}
@Test
fun greeting() {
prepareResponse { response: MockResponse ->
response.setHeader(
"Content-Type",
"text/plain"
).setBody("Hello Spring!")
}
StepVerifier.create(mono<String> { initHttpService().getGreeting() })
.expectNext("Hello Spring!")
.expectComplete()
.verify(Duration.ofSeconds(5))
}
@Test
fun greetingMono() {
prepareResponse { response: MockResponse ->
response.setHeader(
"Content-Type",
"text/plain"
).setBody("Hello Spring!")
}
StepVerifier.create(initHttpService().getGreetingMono())
.expectNext("Hello Spring!")
.expectComplete()
.verify(Duration.ofSeconds(5))
}
@Test
fun greetingBlocking() {
prepareResponse { response: MockResponse ->
response.setHeader(
"Content-Type",
"text/plain"
).setBody("Hello Spring!")
}
StepVerifier.create(mono<String> { initHttpService().getGreetingBlocking() })
.expectNext("Hello Spring!")
.expectComplete()
.verify(Duration.ofSeconds(5))
}
@Test
fun greetingWithRequestAttribute() {
val attributes: MutableMap<String, Any> = HashMap()
val webClient = WebClient.builder()
.baseUrl(server!!.url("/").toString())
.filter { request: ClientRequest, next: ExchangeFunction ->
attributes.putAll(request.attributes())
next.exchange(request)
}
.build()
prepareResponse { response: MockResponse ->
response.setHeader(
"Content-Type",
"text/plain"
).setBody("Hello Spring!")
}
val service = initHttpService(webClient)
val value = runBlocking {
service.getGreetingWithAttribute("myAttributeValue")
}
StepVerifier.create(mono<String> { value })
.expectNext("Hello Spring!")
.expectComplete()
.verify(Duration.ofSeconds(5))
Assertions.assertThat(attributes).containsEntry("myAttribute", "myAttributeValue")
}
private fun initHttpService(): TestHttpService {
val webClient = WebClient.builder().baseUrl(
server!!.url("/").toString()
).build()
return initHttpService(webClient)
}
private fun initHttpService(webClient: WebClient): TestHttpService {
return HttpServiceProxyFactory.builder()
.clientAdapter(WebClientAdapter.forClient(webClient))
.build()
.createClient(TestHttpService::class.java)
}
private fun prepareResponse(consumer: Consumer<MockResponse>) {
val response = MockResponse()
consumer.accept(response)
server!!.enqueue(response)
}
private interface TestHttpService {
@GetExchange("/greeting")
suspend fun getGreeting(): String
@GetExchange("/greeting")
fun getGreetingMono(): Mono<String>
@GetExchange("/greeting")
fun getGreetingBlocking(): String
@GetExchange("/greeting")
suspend fun getGreetingWithAttribute(@RequestAttribute myAttribute: String): String
}
}