Propagate CoroutineContext in CoWebFilter

This provides an elegant and dynamic way to customize the
CoroutineContext in WebFlux with the annotation programming
model.

Closes gh-27522
This commit is contained in:
Sébastien Deleuze 2023-09-07 12:08:12 +02:00
parent 9d768a89d2
commit b0aa004d9d
3 changed files with 77 additions and 25 deletions

View File

@ -17,6 +17,8 @@
package org.springframework.web.server
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.reactor.awaitSingleOrNull
import kotlinx.coroutines.reactor.mono
import reactor.core.publisher.Mono
@ -26,6 +28,7 @@ import reactor.core.publisher.Mono
* using coroutines.
*
* @author Arjen Poutsma
* @author Sebastien Deleuze
* @since 6.0.5
*/
abstract class CoWebFilter : WebFilter {
@ -34,6 +37,7 @@ abstract class CoWebFilter : WebFilter {
return mono(Dispatchers.Unconfined) {
filter(exchange, object : CoWebFilterChain {
override suspend fun filter(exchange: ServerWebExchange) {
exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key)
chain.filter(exchange).awaitSingleOrNull()
}
})}.then()
@ -47,6 +51,12 @@ abstract class CoWebFilter : WebFilter {
*/
protected abstract suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain)
companion object {
@JvmField
val COROUTINE_CONTEXT_ATTRIBUTE = CoWebFilter::class.java.getName() + ".context"
}
}
/**

View File

@ -16,6 +16,8 @@
package org.springframework.web.server
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.mockito.BDDMockito.given
@ -24,9 +26,11 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe
import org.springframework.web.testfixture.server.MockServerWebExchange
import reactor.core.publisher.Mono
import reactor.test.StepVerifier
import kotlin.coroutines.CoroutineContext
/**
* @author Arjen Poutsma
* @author Sebastien Deleuze
*/
class CoWebFilterTests {
@ -45,6 +49,26 @@ class CoWebFilterTests {
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
}
@Test
fun filterWithContext() {
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
val chain = Mockito.mock(WebFilterChain::class.java)
given(chain.filter(exchange)).willReturn(Mono.empty())
val filter = MyCoWebFilterWithContext()
val result = filter.filter(exchange, chain)
StepVerifier.create(result).verifyComplete()
val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext
assertThat(context).isNotNull()
val coroutineName = context[CoroutineName.Key] as CoroutineName
assertThat(coroutineName).isNotNull()
assertThat(coroutineName.name).isEqualTo("foo")
}
}
@ -53,4 +77,12 @@ private class MyCoWebFilter : CoWebFilter() {
exchange.attributes["foo"] = "bar"
chain.filter(exchange)
}
}
}
private class MyCoWebFilterWithContext : CoWebFilter() {
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
withContext(CoroutineName("foo")) {
chain.filter(exchange)
}
}
}

View File

@ -26,6 +26,7 @@ import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import kotlin.coroutines.CoroutineContext;
import kotlin.reflect.KFunction;
import kotlin.reflect.KParameter;
import kotlin.reflect.jvm.KCallablesJvm;
@ -48,6 +49,7 @@ import org.springframework.validation.method.MethodValidator;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.BindingContext;
import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.server.CoWebFilter;
import org.springframework.web.server.ServerWebExchange;
/**
@ -152,7 +154,7 @@ public class InvocableHandlerMethod extends HandlerMethod {
* @param providedArgs optional list of argument values to match by type
* @return a Mono with a {@link HandlerResult}
*/
@SuppressWarnings({"KotlinInternalInJava", "unchecked"})
@SuppressWarnings("unchecked")
public Mono<HandlerResult> invoke(
ServerWebExchange exchange, BindingContext bindingContext, Object... providedArgs) {
@ -167,12 +169,7 @@ public class InvocableHandlerMethod extends HandlerMethod {
boolean isSuspendingFunction = KotlinDetector.isSuspendingFunction(method);
try {
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
if (isSuspendingFunction) {
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
}
else {
value = KotlinDelegate.invokeFunction(method, getBean(), args);
}
value = KotlinDelegate.invokeFunction(method, getBean(), args, isSuspendingFunction, exchange);
}
else {
value = method.invoke(getBean(), args);
@ -297,25 +294,38 @@ public class InvocableHandlerMethod extends HandlerMethod {
@Nullable
@SuppressWarnings("deprecation")
public static Object invokeFunction(Method method, Object target, Object[] args) {
KFunction<?> function = Objects.requireNonNull(ReflectJvmMapping.getKotlinFunction(method));
if (method.isAccessible() && !KCallablesJvm.isAccessible(function)) {
KCallablesJvm.setAccessible(function, true);
}
Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1);
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE -> {
if (!parameter.isOptional() || args[index] != null) {
argMap.put(parameter, args[index]);
}
index++;
}
public static Object invokeFunction(Method method, Object target, Object[] args, boolean isSuspendingFunction,
ServerWebExchange exchange) {
if (isSuspendingFunction) {
Object coroutineContext = exchange.getAttribute(CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE);
if (coroutineContext == null) {
return CoroutinesUtils.invokeSuspendingFunction(method, target, args);
}
else {
return CoroutinesUtils.invokeSuspendingFunction((CoroutineContext) coroutineContext, method, target, args);
}
}
return function.callBy(argMap);
else {
KFunction<?> function = Objects.requireNonNull(ReflectJvmMapping.getKotlinFunction(method));
if (method.isAccessible() && !KCallablesJvm.isAccessible(function)) {
KCallablesJvm.setAccessible(function, true);
}
Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1);
int index = 0;
for (KParameter parameter : function.getParameters()) {
switch (parameter.getKind()) {
case INSTANCE -> argMap.put(parameter, target);
case VALUE -> {
if (!parameter.isOptional() || args[index] != null) {
argMap.put(parameter, args[index]);
}
index++;
}
}
}
return function.callBy(argMap);
}
}
}