Propagate CoroutineContext in CoWebFilter
This provides an elegant and dynamic way to customize the CoroutineContext in WebFlux with the annotation programming model. Closes gh-27522
This commit is contained in:
parent
9d768a89d2
commit
b0aa004d9d
|
@ -17,6 +17,8 @@
|
|||
package org.springframework.web.server
|
||||
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.currentCoroutineContext
|
||||
import kotlinx.coroutines.reactor.awaitSingleOrNull
|
||||
import kotlinx.coroutines.reactor.mono
|
||||
import reactor.core.publisher.Mono
|
||||
|
@ -26,6 +28,7 @@ import reactor.core.publisher.Mono
|
|||
* using coroutines.
|
||||
*
|
||||
* @author Arjen Poutsma
|
||||
* @author Sebastien Deleuze
|
||||
* @since 6.0.5
|
||||
*/
|
||||
abstract class CoWebFilter : WebFilter {
|
||||
|
@ -34,6 +37,7 @@ abstract class CoWebFilter : WebFilter {
|
|||
return mono(Dispatchers.Unconfined) {
|
||||
filter(exchange, object : CoWebFilterChain {
|
||||
override suspend fun filter(exchange: ServerWebExchange) {
|
||||
exchange.attributes[COROUTINE_CONTEXT_ATTRIBUTE] = currentCoroutineContext().minusKey(Job.Key)
|
||||
chain.filter(exchange).awaitSingleOrNull()
|
||||
}
|
||||
})}.then()
|
||||
|
@ -47,6 +51,12 @@ abstract class CoWebFilter : WebFilter {
|
|||
*/
|
||||
protected abstract suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain)
|
||||
|
||||
companion object {
|
||||
|
||||
@JvmField
|
||||
val COROUTINE_CONTEXT_ATTRIBUTE = CoWebFilter::class.java.getName() + ".context"
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
package org.springframework.web.server
|
||||
|
||||
import kotlinx.coroutines.CoroutineName
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.mockito.BDDMockito.given
|
||||
|
@ -24,9 +26,11 @@ import org.springframework.web.testfixture.http.server.reactive.MockServerHttpRe
|
|||
import org.springframework.web.testfixture.server.MockServerWebExchange
|
||||
import reactor.core.publisher.Mono
|
||||
import reactor.test.StepVerifier
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
/**
|
||||
* @author Arjen Poutsma
|
||||
* @author Sebastien Deleuze
|
||||
*/
|
||||
class CoWebFilterTests {
|
||||
|
||||
|
@ -45,6 +49,26 @@ class CoWebFilterTests {
|
|||
|
||||
assertThat(exchange.attributes["foo"]).isEqualTo("bar")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun filterWithContext() {
|
||||
val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("https://example.com"))
|
||||
|
||||
val chain = Mockito.mock(WebFilterChain::class.java)
|
||||
given(chain.filter(exchange)).willReturn(Mono.empty())
|
||||
|
||||
val filter = MyCoWebFilterWithContext()
|
||||
val result = filter.filter(exchange, chain)
|
||||
|
||||
StepVerifier.create(result).verifyComplete()
|
||||
|
||||
val context = exchange.attributes[CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE] as CoroutineContext
|
||||
assertThat(context).isNotNull()
|
||||
val coroutineName = context[CoroutineName.Key] as CoroutineName
|
||||
assertThat(coroutineName).isNotNull()
|
||||
assertThat(coroutineName.name).isEqualTo("foo")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -53,4 +77,12 @@ private class MyCoWebFilter : CoWebFilter() {
|
|||
exchange.attributes["foo"] = "bar"
|
||||
chain.filter(exchange)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class MyCoWebFilterWithContext : CoWebFilter() {
|
||||
override suspend fun filter(exchange: ServerWebExchange, chain: CoWebFilterChain) {
|
||||
withContext(CoroutineName("foo")) {
|
||||
chain.filter(exchange)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ import java.util.Map;
|
|||
import java.util.Objects;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import kotlin.coroutines.CoroutineContext;
|
||||
import kotlin.reflect.KFunction;
|
||||
import kotlin.reflect.KParameter;
|
||||
import kotlin.reflect.jvm.KCallablesJvm;
|
||||
|
@ -48,6 +49,7 @@ import org.springframework.validation.method.MethodValidator;
|
|||
import org.springframework.web.method.HandlerMethod;
|
||||
import org.springframework.web.reactive.BindingContext;
|
||||
import org.springframework.web.reactive.HandlerResult;
|
||||
import org.springframework.web.server.CoWebFilter;
|
||||
import org.springframework.web.server.ServerWebExchange;
|
||||
|
||||
/**
|
||||
|
@ -152,7 +154,7 @@ public class InvocableHandlerMethod extends HandlerMethod {
|
|||
* @param providedArgs optional list of argument values to match by type
|
||||
* @return a Mono with a {@link HandlerResult}
|
||||
*/
|
||||
@SuppressWarnings({"KotlinInternalInJava", "unchecked"})
|
||||
@SuppressWarnings("unchecked")
|
||||
public Mono<HandlerResult> invoke(
|
||||
ServerWebExchange exchange, BindingContext bindingContext, Object... providedArgs) {
|
||||
|
||||
|
@ -167,12 +169,7 @@ public class InvocableHandlerMethod extends HandlerMethod {
|
|||
boolean isSuspendingFunction = KotlinDetector.isSuspendingFunction(method);
|
||||
try {
|
||||
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
|
||||
if (isSuspendingFunction) {
|
||||
value = CoroutinesUtils.invokeSuspendingFunction(method, getBean(), args);
|
||||
}
|
||||
else {
|
||||
value = KotlinDelegate.invokeFunction(method, getBean(), args);
|
||||
}
|
||||
value = KotlinDelegate.invokeFunction(method, getBean(), args, isSuspendingFunction, exchange);
|
||||
}
|
||||
else {
|
||||
value = method.invoke(getBean(), args);
|
||||
|
@ -297,25 +294,38 @@ public class InvocableHandlerMethod extends HandlerMethod {
|
|||
|
||||
@Nullable
|
||||
@SuppressWarnings("deprecation")
|
||||
public static Object invokeFunction(Method method, Object target, Object[] args) {
|
||||
KFunction<?> function = Objects.requireNonNull(ReflectJvmMapping.getKotlinFunction(method));
|
||||
if (method.isAccessible() && !KCallablesJvm.isAccessible(function)) {
|
||||
KCallablesJvm.setAccessible(function, true);
|
||||
}
|
||||
Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1);
|
||||
int index = 0;
|
||||
for (KParameter parameter : function.getParameters()) {
|
||||
switch (parameter.getKind()) {
|
||||
case INSTANCE -> argMap.put(parameter, target);
|
||||
case VALUE -> {
|
||||
if (!parameter.isOptional() || args[index] != null) {
|
||||
argMap.put(parameter, args[index]);
|
||||
}
|
||||
index++;
|
||||
}
|
||||
public static Object invokeFunction(Method method, Object target, Object[] args, boolean isSuspendingFunction,
|
||||
ServerWebExchange exchange) {
|
||||
|
||||
if (isSuspendingFunction) {
|
||||
Object coroutineContext = exchange.getAttribute(CoWebFilter.COROUTINE_CONTEXT_ATTRIBUTE);
|
||||
if (coroutineContext == null) {
|
||||
return CoroutinesUtils.invokeSuspendingFunction(method, target, args);
|
||||
}
|
||||
else {
|
||||
return CoroutinesUtils.invokeSuspendingFunction((CoroutineContext) coroutineContext, method, target, args);
|
||||
}
|
||||
}
|
||||
return function.callBy(argMap);
|
||||
else {
|
||||
KFunction<?> function = Objects.requireNonNull(ReflectJvmMapping.getKotlinFunction(method));
|
||||
if (method.isAccessible() && !KCallablesJvm.isAccessible(function)) {
|
||||
KCallablesJvm.setAccessible(function, true);
|
||||
}
|
||||
Map<KParameter, Object> argMap = CollectionUtils.newHashMap(args.length + 1);
|
||||
int index = 0;
|
||||
for (KParameter parameter : function.getParameters()) {
|
||||
switch (parameter.getKind()) {
|
||||
case INSTANCE -> argMap.put(parameter, target);
|
||||
case VALUE -> {
|
||||
if (!parameter.isOptional() || args[index] != null) {
|
||||
argMap.put(parameter, args[index]);
|
||||
}
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return function.callBy(argMap);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue