diff --git a/spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt b/spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt index 3e8181d8bb..9f7bc9df94 100644 --- a/spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt +++ b/spring-web/src/main/kotlin/org/springframework/web/server/CoWebFilter.kt @@ -17,6 +17,8 @@ package org.springframework.web.server import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.reactor.awaitSingleOrNull import kotlinx.coroutines.reactor.mono import reactor.core.publisher.Mono @@ -26,6 +28,7 @@ import reactor.core.publisher.Mono * using coroutines. * * @author Arjen Poutsma + * @author Sebastien Deleuze * @since 6.0.5 */ abstract class CoWebFilter : WebFilter { @@ -34,6 +37,7 @@ abstract class CoWebFilter : WebFilter { return mono(Dispatchers.Unconfined) { filter(exchange, object : CoWebFilterChain { override suspend fun filter(exchange: ServerWebExchange) { + exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key) chain.filter(exchange).awaitSingleOrNull() } })}.then() @@ -47,6 +51,12 @@ abstract class CoWebFilter : WebFilter { */ protected abstract suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) + companion object { + + @JvmField + val COROUTINE_CONTEXT_ATTRIBUTE = CoWebFilter::class.java.getName() + ".context" + } + } /** diff --git a/spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt b/spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt index 85c3b0c5d0..872f3a4bc0 100644 --- a/spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt +++ b/spring-web/src/test/kotlin/org/springframework/web/server/CoWebFilterTests.kt @@ -16,6 +16,8 @@ package org.springframework.web.server +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.withContext import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.mockito.BDDMockito.given @@ -24,9 +26,11 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe import org.springframework.web.testfixture.server.MockServerWebExchange import reactor.core.publisher.Mono import reactor.test.StepVerifier +import kotlin.coroutines.CoroutineContext /** * @author Arjen Poutsma + * @author Sebastien Deleuze */ class CoWebFilterTests { @@ -45,6 +49,26 @@ class CoWebFilterTests { assertThat(exchange.attributes["foo"]).isEqualTo("bar") } + + @Test + fun filterWithContext() { + val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com")) + + val chain = Mockito.mock(WebFilterChain::class.java) + given(chain.filter(exchange)).willReturn(Mono.empty()) + + val filter = MyCoWebFilterWithContext() + val result = filter.filter(exchange, chain) + + StepVerifier.create(result).verifyComplete() + + val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext + assertThat(context).isNotNull() + val coroutineName = context[CoroutineName.Key] as CoroutineName + assertThat(coroutineName).isNotNull() + assertThat(coroutineName.name).isEqualTo("foo") + } + } @@ -53,4 +77,12 @@ private class MyCoWebFilter : CoWebFilter() { exchange.attributes["foo"] = "bar" chain.filter(exchange) } -} \ No newline at end of file +} + +private class MyCoWebFilterWithContext : CoWebFilter() { + override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) { + withContext(CoroutineName("foo")) { + chain.filter(exchange) + } + } +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java index 278713647c..66fb38ecb7 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/InvocableHandlerMethod.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Objects; import java.util.stream.Stream; +import kotlin.coroutines.CoroutineContext; import kotlin.reflect.KFunction; import kotlin.reflect.KParameter; import kotlin.reflect.jvm.KCallablesJvm; @@ -48,6 +49,7 @@ import org.springframework.validation.method.MethodValidator; import org.springframework.web.method.HandlerMethod; import org.springframework.web.reactive.BindingContext; import org.springframework.web.reactive.HandlerResult; +import org.springframework.web.server.CoWebFilter; import org.springframework.web.server.ServerWebExchange; /** @@ -152,7 +154,7 @@ public class InvocableHandlerMethod extends HandlerMethod { * @param providedArgs optional list of argument values to match by type * @return a Mono with a {@link HandlerResult} */ - @SuppressWarnings({"KotlinInternalInJava", "unchecked"}) + @SuppressWarnings("unchecked") public Mono invoke( ServerWebExchange exchange, BindingContext bindingContext, Object... providedArgs) { @@ -167,12 +169,7 @@ public class InvocableHandlerMethod extends HandlerMethod { boolean isSuspendingFunction = KotlinDetector.isSuspendingFunction(method); try { if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) { - if (isSuspendingFunction) { - value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args); - } - else { - value = KotlinDelegate.invokeFunction(method, getBean(), args); - } + value = KotlinDelegate.invokeFunction(method, getBean(), args, isSuspendingFunction, exchange); } else { value = method.invoke(getBean(), args); @@ -297,25 +294,38 @@ public class InvocableHandlerMethod extends HandlerMethod { @Nullable @SuppressWarnings("deprecation") - public static Object invokeFunction(Method method, Object target, Object[] args) { - KFunction function = Objects.requireNonNull(ReflectJvmMapping.getKotlinFunction(method)); - if (method.isAccessible() && !KCallablesJvm.isAccessible(function)) { - KCallablesJvm.setAccessible(function, true); - } - Map argMap = CollectionUtils.newHashMap(args.length + 1); - int index = 0; - for (KParameter parameter : function.getParameters()) { - switch (parameter.getKind()) { - case INSTANCE -> argMap.put(parameter, target); - case VALUE -> { - if (!parameter.isOptional() || args[index] != null) { - argMap.put(parameter, args[index]); - } - index++; - } + public static Object invokeFunction(Method method, Object target, Object[] args, boolean isSuspendingFunction, + ServerWebExchange exchange) { + + if (isSuspendingFunction) { + Object coroutineContext = exchange.getAttribute(CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE); + if (coroutineContext == null) { + return CoroutinesUtils.invokeSuspendingFunction(method, target, args); + } + else { + return CoroutinesUtils.invokeSuspendingFunction((CoroutineContext) coroutineContext, method, target, args); } } - return function.callBy(argMap); + else { + KFunction function = Objects.requireNonNull(ReflectJvmMapping.getKotlinFunction(method)); + if (method.isAccessible() && !KCallablesJvm.isAccessible(function)) { + KCallablesJvm.setAccessible(function, true); + } + Map argMap = CollectionUtils.newHashMap(args.length + 1); + int index = 0; + for (KParameter parameter : function.getParameters()) { + switch (parameter.getKind()) { + case INSTANCE -> argMap.put(parameter, target); + case VALUE -> { + if (!parameter.isOptional() || args[index] != null) { + argMap.put(parameter, args[index]); + } + index++; + } + } + } + return function.callBy(argMap); + } } }