Support Coroutines in HttpServiceProxyFactory
See gh-29527
This commit is contained in:
parent
ce85fdc5c7
commit
1d4bf58e8d
|
@ -21,6 +21,7 @@ import java.lang.reflect.Method;
|
|||
import java.util.Objects;
|
||||
|
||||
import kotlin.Unit;
|
||||
import kotlin.coroutines.Continuation;
|
||||
import kotlin.coroutines.CoroutineContext;
|
||||
import kotlin.jvm.JvmClassMappingKt;
|
||||
import kotlin.reflect.KClass;
|
||||
|
@ -67,6 +68,10 @@ public abstract class CoroutinesUtils {
|
|||
(scope, continuation) -> MonoKt.awaitSingleOrNull(source, continuation));
|
||||
}
|
||||
|
||||
public static <T> Object awaitSingleOrNull(Mono<T> source, Continuation<T> continuation) {
|
||||
return MonoKt.awaitSingleOrNull(source, continuation);
|
||||
}
|
||||
|
||||
/**
|
||||
* Invoke a suspending function and converts it to {@link Mono} or
|
||||
* {@link Flux}. Uses an {@linkplain Dispatchers#getUnconfined() unconfined}
|
||||
|
|
|
@ -28,6 +28,7 @@ import reactor.core.publisher.Flux;
|
|||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.core.DefaultParameterNameDiscoverer;
|
||||
import org.springframework.core.KotlinDetector;
|
||||
import org.springframework.core.MethodParameter;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.ReactiveAdapter;
|
||||
|
@ -83,6 +84,11 @@ final class HttpServiceMethod {
|
|||
if (count == 0) {
|
||||
return new MethodParameter[0];
|
||||
}
|
||||
|
||||
if (KotlinDetector.isSuspendingFunction(method)) {
|
||||
count -= 1;
|
||||
}
|
||||
|
||||
DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
|
||||
MethodParameter[] parameters = new MethodParameter[count];
|
||||
for (int i = 0; i < count; i++) {
|
||||
|
@ -306,6 +312,10 @@ final class HttpServiceMethod {
|
|||
|
||||
MethodParameter returnParam = new MethodParameter(method, -1);
|
||||
Class<?> returnType = returnParam.getParameterType();
|
||||
if (KotlinDetector.isSuspendingFunction(method)) {
|
||||
returnType = Mono.class;
|
||||
}
|
||||
|
||||
ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);
|
||||
|
||||
MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());
|
||||
|
|
|
@ -25,11 +25,15 @@ import java.util.Map;
|
|||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import kotlin.coroutines.Continuation;
|
||||
import org.aopalliance.intercept.MethodInterceptor;
|
||||
import org.aopalliance.intercept.MethodInvocation;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.aop.framework.ProxyFactory;
|
||||
import org.springframework.aop.framework.ReflectiveMethodInvocation;
|
||||
import org.springframework.core.CoroutinesUtils;
|
||||
import org.springframework.core.KotlinDetector;
|
||||
import org.springframework.core.MethodIntrospector;
|
||||
import org.springframework.core.ReactiveAdapterRegistry;
|
||||
import org.springframework.core.annotation.AnnotatedElementUtils;
|
||||
|
@ -264,10 +268,18 @@ public final class HttpServiceProxyFactory {
|
|||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings({"unchecked"})
|
||||
public Object invoke(MethodInvocation invocation) throws Throwable {
|
||||
Method method = invocation.getMethod();
|
||||
HttpServiceMethod httpServiceMethod = this.httpServiceMethods.get(method);
|
||||
if (httpServiceMethod != null) {
|
||||
if (KotlinDetector.isSuspendingFunction(method)) {
|
||||
Object[] arguments = getSuspendedFunctionArgs(invocation.getArguments());
|
||||
Continuation<Object> continuation = resolveContinuationArgument(invocation.getArguments());
|
||||
Mono<Object> wrapped = (Mono<Object>) httpServiceMethod.invoke(arguments);
|
||||
return CoroutinesUtils.awaitSingleOrNull(wrapped, continuation);
|
||||
}
|
||||
|
||||
return httpServiceMethod.invoke(invocation.getArguments());
|
||||
}
|
||||
if (method.isDefault()) {
|
||||
|
@ -278,6 +290,17 @@ public final class HttpServiceProxyFactory {
|
|||
}
|
||||
throw new IllegalStateException("Unexpected method invocation: " + method);
|
||||
}
|
||||
|
||||
@SuppressWarnings({"unchecked"})
|
||||
private static <T> Continuation<T> resolveContinuationArgument(Object[] args) {
|
||||
return (Continuation<T>) args[args.length - 1];
|
||||
}
|
||||
|
||||
private static Object[] getSuspendedFunctionArgs(Object[] args) {
|
||||
Object[] functionArgs = new Object[args.length - 1];
|
||||
System.arraycopy(args, 0, functionArgs, 0, args.length - 1);
|
||||
return functionArgs;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
/*
|
||||
* Copyright 2002-2022 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.web.reactive.function.client.support
|
||||
|
||||
import kotlinx.coroutines.reactor.mono
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import okhttp3.mockwebserver.MockResponse
|
||||
import okhttp3.mockwebserver.MockWebServer
|
||||
import org.assertj.core.api.Assertions
|
||||
import org.junit.jupiter.api.AfterEach
|
||||
import org.junit.jupiter.api.BeforeEach
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.springframework.web.bind.annotation.RequestAttribute
|
||||
import org.springframework.web.reactive.function.client.ClientRequest
|
||||
import org.springframework.web.reactive.function.client.ExchangeFunction
|
||||
import org.springframework.web.reactive.function.client.WebClient
|
||||
import org.springframework.web.service.annotation.GetExchange
|
||||
import org.springframework.web.service.invoker.HttpServiceProxyFactory
|
||||
import reactor.core.publisher.Mono
|
||||
import reactor.test.StepVerifier
|
||||
import java.time.Duration
|
||||
import java.util.function.Consumer
|
||||
|
||||
/**
|
||||
* Integration tests for [HTTP Service proxy][HttpServiceProxyFactory]
|
||||
* using [WebClient] and [MockWebServer].
|
||||
*
|
||||
* @author DongHyeon Kim (wplong11)
|
||||
*/
|
||||
class WebClientHttpServiceProxyKotlinTests {
|
||||
private var server: MockWebServer? = null
|
||||
@BeforeEach
|
||||
fun setUp() {
|
||||
server = MockWebServer()
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
fun shutdown() {
|
||||
server?.shutdown()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun greeting() {
|
||||
prepareResponse { response: MockResponse ->
|
||||
response.setHeader(
|
||||
"Content-Type",
|
||||
"text/plain"
|
||||
).setBody("Hello Spring!")
|
||||
}
|
||||
StepVerifier.create(mono<String> { initHttpService().getGreeting() })
|
||||
.expectNext("Hello Spring!")
|
||||
.expectComplete()
|
||||
.verify(Duration.ofSeconds(5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun greetingMono() {
|
||||
prepareResponse { response: MockResponse ->
|
||||
response.setHeader(
|
||||
"Content-Type",
|
||||
"text/plain"
|
||||
).setBody("Hello Spring!")
|
||||
}
|
||||
StepVerifier.create(initHttpService().getGreetingMono())
|
||||
.expectNext("Hello Spring!")
|
||||
.expectComplete()
|
||||
.verify(Duration.ofSeconds(5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun greetingBlocking() {
|
||||
prepareResponse { response: MockResponse ->
|
||||
response.setHeader(
|
||||
"Content-Type",
|
||||
"text/plain"
|
||||
).setBody("Hello Spring!")
|
||||
}
|
||||
StepVerifier.create(mono<String> { initHttpService().getGreetingBlocking() })
|
||||
.expectNext("Hello Spring!")
|
||||
.expectComplete()
|
||||
.verify(Duration.ofSeconds(5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun greetingWithRequestAttribute() {
|
||||
val attributes: MutableMap<String, Any> = HashMap()
|
||||
val webClient = WebClient.builder()
|
||||
.baseUrl(server!!.url("/").toString())
|
||||
.filter { request: ClientRequest, next: ExchangeFunction ->
|
||||
attributes.putAll(request.attributes())
|
||||
next.exchange(request)
|
||||
}
|
||||
.build()
|
||||
prepareResponse { response: MockResponse ->
|
||||
response.setHeader(
|
||||
"Content-Type",
|
||||
"text/plain"
|
||||
).setBody("Hello Spring!")
|
||||
}
|
||||
|
||||
val service = initHttpService(webClient)
|
||||
val value = runBlocking {
|
||||
service.getGreetingWithAttribute("myAttributeValue")
|
||||
}
|
||||
StepVerifier.create(mono<String> { value })
|
||||
.expectNext("Hello Spring!")
|
||||
.expectComplete()
|
||||
.verify(Duration.ofSeconds(5))
|
||||
Assertions.assertThat(attributes).containsEntry("myAttribute", "myAttributeValue")
|
||||
}
|
||||
|
||||
private fun initHttpService(): TestHttpService {
|
||||
val webClient = WebClient.builder().baseUrl(
|
||||
server!!.url("/").toString()
|
||||
).build()
|
||||
return initHttpService(webClient)
|
||||
}
|
||||
|
||||
private fun initHttpService(webClient: WebClient): TestHttpService {
|
||||
return HttpServiceProxyFactory.builder()
|
||||
.clientAdapter(WebClientAdapter.forClient(webClient))
|
||||
.build()
|
||||
.createClient(TestHttpService::class.java)
|
||||
}
|
||||
|
||||
private fun prepareResponse(consumer: Consumer<MockResponse>) {
|
||||
val response = MockResponse()
|
||||
consumer.accept(response)
|
||||
server!!.enqueue(response)
|
||||
}
|
||||
|
||||
private interface TestHttpService {
|
||||
@GetExchange("/greeting")
|
||||
suspend fun getGreeting(): String
|
||||
|
||||
@GetExchange("/greeting")
|
||||
fun getGreetingMono(): Mono<String>
|
||||
|
||||
@GetExchange("/greeting")
|
||||
fun getGreetingBlocking(): String
|
||||
|
||||
@GetExchange("/greeting")
|
||||
suspend fun getGreetingWithAttribute(@RequestAttribute myAttribute: String): String
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue