Support Coroutines in HttpServiceProxyFactory
See gh-29527
This commit is contained in:
parent
ce85fdc5c7
commit
1d4bf58e8d
|
@ -21,6 +21,7 @@ import java.lang.reflect.Method;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
import kotlin.Unit;
|
import kotlin.Unit;
|
||||||
|
import kotlin.coroutines.Continuation;
|
||||||
import kotlin.coroutines.CoroutineContext;
|
import kotlin.coroutines.CoroutineContext;
|
||||||
import kotlin.jvm.JvmClassMappingKt;
|
import kotlin.jvm.JvmClassMappingKt;
|
||||||
import kotlin.reflect.KClass;
|
import kotlin.reflect.KClass;
|
||||||
|
@ -67,6 +68,10 @@ public abstract class CoroutinesUtils {
|
||||||
(scope, continuation) -> MonoKt.awaitSingleOrNull(source, continuation));
|
(scope, continuation) -> MonoKt.awaitSingleOrNull(source, continuation));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static <T> Object awaitSingleOrNull(Mono<T> source, Continuation<T> continuation) {
|
||||||
|
return MonoKt.awaitSingleOrNull(source, continuation);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Invoke a suspending function and converts it to {@link Mono} or
|
* Invoke a suspending function and converts it to {@link Mono} or
|
||||||
* {@link Flux}. Uses an {@linkplain Dispatchers#getUnconfined() unconfined}
|
* {@link Flux}. Uses an {@linkplain Dispatchers#getUnconfined() unconfined}
|
||||||
|
|
|
@ -28,6 +28,7 @@ import reactor.core.publisher.Flux;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
|
||||||
import org.springframework.core.DefaultParameterNameDiscoverer;
|
import org.springframework.core.DefaultParameterNameDiscoverer;
|
||||||
|
import org.springframework.core.KotlinDetector;
|
||||||
import org.springframework.core.MethodParameter;
|
import org.springframework.core.MethodParameter;
|
||||||
import org.springframework.core.ParameterizedTypeReference;
|
import org.springframework.core.ParameterizedTypeReference;
|
||||||
import org.springframework.core.ReactiveAdapter;
|
import org.springframework.core.ReactiveAdapter;
|
||||||
|
@ -83,6 +84,11 @@ final class HttpServiceMethod {
|
||||||
if (count == 0) {
|
if (count == 0) {
|
||||||
return new MethodParameter[0];
|
return new MethodParameter[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (KotlinDetector.isSuspendingFunction(method)) {
|
||||||
|
count -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
|
DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer();
|
||||||
MethodParameter[] parameters = new MethodParameter[count];
|
MethodParameter[] parameters = new MethodParameter[count];
|
||||||
for (int i = 0; i < count; i++) {
|
for (int i = 0; i < count; i++) {
|
||||||
|
@ -306,6 +312,10 @@ final class HttpServiceMethod {
|
||||||
|
|
||||||
MethodParameter returnParam = new MethodParameter(method, -1);
|
MethodParameter returnParam = new MethodParameter(method, -1);
|
||||||
Class<?> returnType = returnParam.getParameterType();
|
Class<?> returnType = returnParam.getParameterType();
|
||||||
|
if (KotlinDetector.isSuspendingFunction(method)) {
|
||||||
|
returnType = Mono.class;
|
||||||
|
}
|
||||||
|
|
||||||
ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);
|
ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType);
|
||||||
|
|
||||||
MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());
|
MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional());
|
||||||
|
|
|
@ -25,11 +25,15 @@ import java.util.Map;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import kotlin.coroutines.Continuation;
|
||||||
import org.aopalliance.intercept.MethodInterceptor;
|
import org.aopalliance.intercept.MethodInterceptor;
|
||||||
import org.aopalliance.intercept.MethodInvocation;
|
import org.aopalliance.intercept.MethodInvocation;
|
||||||
|
import reactor.core.publisher.Mono;
|
||||||
|
|
||||||
import org.springframework.aop.framework.ProxyFactory;
|
import org.springframework.aop.framework.ProxyFactory;
|
||||||
import org.springframework.aop.framework.ReflectiveMethodInvocation;
|
import org.springframework.aop.framework.ReflectiveMethodInvocation;
|
||||||
|
import org.springframework.core.CoroutinesUtils;
|
||||||
|
import org.springframework.core.KotlinDetector;
|
||||||
import org.springframework.core.MethodIntrospector;
|
import org.springframework.core.MethodIntrospector;
|
||||||
import org.springframework.core.ReactiveAdapterRegistry;
|
import org.springframework.core.ReactiveAdapterRegistry;
|
||||||
import org.springframework.core.annotation.AnnotatedElementUtils;
|
import org.springframework.core.annotation.AnnotatedElementUtils;
|
||||||
|
@ -264,10 +268,18 @@ public final class HttpServiceProxyFactory {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@SuppressWarnings({"unchecked"})
|
||||||
public Object invoke(MethodInvocation invocation) throws Throwable {
|
public Object invoke(MethodInvocation invocation) throws Throwable {
|
||||||
Method method = invocation.getMethod();
|
Method method = invocation.getMethod();
|
||||||
HttpServiceMethod httpServiceMethod = this.httpServiceMethods.get(method);
|
HttpServiceMethod httpServiceMethod = this.httpServiceMethods.get(method);
|
||||||
if (httpServiceMethod != null) {
|
if (httpServiceMethod != null) {
|
||||||
|
if (KotlinDetector.isSuspendingFunction(method)) {
|
||||||
|
Object[] arguments = getSuspendedFunctionArgs(invocation.getArguments());
|
||||||
|
Continuation<Object> continuation = resolveContinuationArgument(invocation.getArguments());
|
||||||
|
Mono<Object> wrapped = (Mono<Object>) httpServiceMethod.invoke(arguments);
|
||||||
|
return CoroutinesUtils.awaitSingleOrNull(wrapped, continuation);
|
||||||
|
}
|
||||||
|
|
||||||
return httpServiceMethod.invoke(invocation.getArguments());
|
return httpServiceMethod.invoke(invocation.getArguments());
|
||||||
}
|
}
|
||||||
if (method.isDefault()) {
|
if (method.isDefault()) {
|
||||||
|
@ -278,6 +290,17 @@ public final class HttpServiceProxyFactory {
|
||||||
}
|
}
|
||||||
throw new IllegalStateException("Unexpected method invocation: " + method);
|
throw new IllegalStateException("Unexpected method invocation: " + method);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings({"unchecked"})
|
||||||
|
private static <T> Continuation<T> resolveContinuationArgument(Object[] args) {
|
||||||
|
return (Continuation<T>) args[args.length - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Object[] getSuspendedFunctionArgs(Object[] args) {
|
||||||
|
Object[] functionArgs = new Object[args.length - 1];
|
||||||
|
System.arraycopy(args, 0, functionArgs, 0, args.length - 1);
|
||||||
|
return functionArgs;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,158 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2002-2022 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.springframework.web.reactive.function.client.support
|
||||||
|
|
||||||
|
import kotlinx.coroutines.reactor.mono
|
||||||
|
import kotlinx.coroutines.runBlocking
|
||||||
|
import okhttp3.mockwebserver.MockResponse
|
||||||
|
import okhttp3.mockwebserver.MockWebServer
|
||||||
|
import org.assertj.core.api.Assertions
|
||||||
|
import org.junit.jupiter.api.AfterEach
|
||||||
|
import org.junit.jupiter.api.BeforeEach
|
||||||
|
import org.junit.jupiter.api.Test
|
||||||
|
import org.springframework.web.bind.annotation.RequestAttribute
|
||||||
|
import org.springframework.web.reactive.function.client.ClientRequest
|
||||||
|
import org.springframework.web.reactive.function.client.ExchangeFunction
|
||||||
|
import org.springframework.web.reactive.function.client.WebClient
|
||||||
|
import org.springframework.web.service.annotation.GetExchange
|
||||||
|
import org.springframework.web.service.invoker.HttpServiceProxyFactory
|
||||||
|
import reactor.core.publisher.Mono
|
||||||
|
import reactor.test.StepVerifier
|
||||||
|
import java.time.Duration
|
||||||
|
import java.util.function.Consumer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Integration tests for [HTTP Service proxy][HttpServiceProxyFactory]
|
||||||
|
* using [WebClient] and [MockWebServer].
|
||||||
|
*
|
||||||
|
* @author DongHyeon Kim (wplong11)
|
||||||
|
*/
|
||||||
|
class WebClientHttpServiceProxyKotlinTests {
|
||||||
|
private var server: MockWebServer? = null
|
||||||
|
@BeforeEach
|
||||||
|
fun setUp() {
|
||||||
|
server = MockWebServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
fun shutdown() {
|
||||||
|
server?.shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun greeting() {
|
||||||
|
prepareResponse { response: MockResponse ->
|
||||||
|
response.setHeader(
|
||||||
|
"Content-Type",
|
||||||
|
"text/plain"
|
||||||
|
).setBody("Hello Spring!")
|
||||||
|
}
|
||||||
|
StepVerifier.create(mono<String> { initHttpService().getGreeting() })
|
||||||
|
.expectNext("Hello Spring!")
|
||||||
|
.expectComplete()
|
||||||
|
.verify(Duration.ofSeconds(5))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun greetingMono() {
|
||||||
|
prepareResponse { response: MockResponse ->
|
||||||
|
response.setHeader(
|
||||||
|
"Content-Type",
|
||||||
|
"text/plain"
|
||||||
|
).setBody("Hello Spring!")
|
||||||
|
}
|
||||||
|
StepVerifier.create(initHttpService().getGreetingMono())
|
||||||
|
.expectNext("Hello Spring!")
|
||||||
|
.expectComplete()
|
||||||
|
.verify(Duration.ofSeconds(5))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun greetingBlocking() {
|
||||||
|
prepareResponse { response: MockResponse ->
|
||||||
|
response.setHeader(
|
||||||
|
"Content-Type",
|
||||||
|
"text/plain"
|
||||||
|
).setBody("Hello Spring!")
|
||||||
|
}
|
||||||
|
StepVerifier.create(mono<String> { initHttpService().getGreetingBlocking() })
|
||||||
|
.expectNext("Hello Spring!")
|
||||||
|
.expectComplete()
|
||||||
|
.verify(Duration.ofSeconds(5))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun greetingWithRequestAttribute() {
|
||||||
|
val attributes: MutableMap<String, Any> = HashMap()
|
||||||
|
val webClient = WebClient.builder()
|
||||||
|
.baseUrl(server!!.url("/").toString())
|
||||||
|
.filter { request: ClientRequest, next: ExchangeFunction ->
|
||||||
|
attributes.putAll(request.attributes())
|
||||||
|
next.exchange(request)
|
||||||
|
}
|
||||||
|
.build()
|
||||||
|
prepareResponse { response: MockResponse ->
|
||||||
|
response.setHeader(
|
||||||
|
"Content-Type",
|
||||||
|
"text/plain"
|
||||||
|
).setBody("Hello Spring!")
|
||||||
|
}
|
||||||
|
|
||||||
|
val service = initHttpService(webClient)
|
||||||
|
val value = runBlocking {
|
||||||
|
service.getGreetingWithAttribute("myAttributeValue")
|
||||||
|
}
|
||||||
|
StepVerifier.create(mono<String> { value })
|
||||||
|
.expectNext("Hello Spring!")
|
||||||
|
.expectComplete()
|
||||||
|
.verify(Duration.ofSeconds(5))
|
||||||
|
Assertions.assertThat(attributes).containsEntry("myAttribute", "myAttributeValue")
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun initHttpService(): TestHttpService {
|
||||||
|
val webClient = WebClient.builder().baseUrl(
|
||||||
|
server!!.url("/").toString()
|
||||||
|
).build()
|
||||||
|
return initHttpService(webClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun initHttpService(webClient: WebClient): TestHttpService {
|
||||||
|
return HttpServiceProxyFactory.builder()
|
||||||
|
.clientAdapter(WebClientAdapter.forClient(webClient))
|
||||||
|
.build()
|
||||||
|
.createClient(TestHttpService::class.java)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun prepareResponse(consumer: Consumer<MockResponse>) {
|
||||||
|
val response = MockResponse()
|
||||||
|
consumer.accept(response)
|
||||||
|
server!!.enqueue(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
private interface TestHttpService {
|
||||||
|
@GetExchange("/greeting")
|
||||||
|
suspend fun getGreeting(): String
|
||||||
|
|
||||||
|
@GetExchange("/greeting")
|
||||||
|
fun getGreetingMono(): Mono<String>
|
||||||
|
|
||||||
|
@GetExchange("/greeting")
|
||||||
|
fun getGreetingBlocking(): String
|
||||||
|
|
||||||
|
@GetExchange("/greeting")
|
||||||
|
suspend fun getGreetingWithAttribute(@RequestAttribute myAttribute: String): String
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue