Propagate the context in Coroutines transactions

This commit ensures that CoroutineContext is properly
propagated in transactional suspending functions. Both
annotation and functional variants are supported.

Closes gh-27308
This commit is contained in:
Sébastien Deleuze 2023-02-02 14:06:29 +01:00
parent 3e2f58cdd2
commit 45ae00fda3
4 changed files with 128 additions and 11 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,6 +22,8 @@ import java.util.concurrent.ConcurrentMap;
import io.vavr.control.Try;
import kotlin.coroutines.Continuation;
import kotlin.coroutines.CoroutineContext;
import kotlinx.coroutines.Job;
import kotlinx.coroutines.reactive.AwaitKt;
import kotlinx.coroutines.reactive.ReactiveFlowKt;
import org.apache.commons.logging.Log;
@ -363,7 +365,7 @@ public abstract class TransactionAspectSupport implements BeanFactoryAware, Init
InvocationCallback callback = invocation;
if (corInv != null) {
callback = () -> CoroutinesUtils.invokeSuspendingFunction(method, corInv.getTarget(), corInv.getArguments());
callback = () -> KotlinDelegate.invokeSuspendingFunction(method, corInv);
}
Object result = txSupport.invokeWithinTransaction(method, targetClass, callback, txAttr, (ReactiveTransactionManager) tm);
if (corInv != null) {
@ -883,6 +885,12 @@ public abstract class TransactionAspectSupport implements BeanFactoryAware, Init
private static Object awaitSingleOrNull(Publisher<?> publisher, Object continuation) {
return AwaitKt.awaitSingleOrNull(publisher, (Continuation<Object>) continuation);
}
public static Publisher<?> invokeSuspendingFunction(Method method, CoroutinesInvocationCallback callback) {
CoroutineContext coroutineContext = ((Continuation<?>) callback.getContinuation()).getContext().minusKey(Job.Key);
return CoroutinesUtils.invokeSuspendingFunction(coroutineContext, method, callback.getTarget(), callback.getArguments());
}
}

View File

@ -16,14 +16,17 @@
package org.springframework.transaction.reactive
import java.util.Optional
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.reactive.asFlow
import kotlinx.coroutines.reactive.awaitLast
import kotlinx.coroutines.reactor.asFlux
import kotlinx.coroutines.reactor.mono
import org.springframework.transaction.ReactiveTransaction
import java.util.*
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
/**
* Coroutines variant of [TransactionalOperator.transactional] as a [Flow] extension.
@ -31,8 +34,8 @@ import org.springframework.transaction.ReactiveTransaction
* @author Sebastien Deleuze
* @since 5.2
*/
fun <T : Any> Flow<T>.transactional(operator: TransactionalOperator): Flow<T> =
operator.transactional(asFlux()).asFlow()
fun <T : Any> Flow<T>.transactional(operator: TransactionalOperator, context: CoroutineContext = EmptyCoroutineContext): Flow<T> =
operator.transactional(asFlux(context)).asFlow()
/**
* Coroutines variant of [TransactionalOperator.execute] with a suspending lambda
@ -42,6 +45,8 @@ fun <T : Any> Flow<T>.transactional(operator: TransactionalOperator): Flow<T> =
* @author Mark Paluch
* @since 5.2
*/
suspend fun <T> TransactionalOperator.executeAndAwait(f: suspend (ReactiveTransaction) -> T): T =
execute { status -> mono(Dispatchers.Unconfined) { f(status) } }.map { value -> Optional.ofNullable(value) }
suspend fun <T> TransactionalOperator.executeAndAwait(f: suspend (ReactiveTransaction) -> T): T {
val context = currentCoroutineContext().minusKey(Job.Key)
return execute { status -> mono(context) { f(status) } }.map { value -> Optional.ofNullable(value) }
.defaultIfEmpty(Optional.empty()).awaitLast().orElse(null)
}

View File

@ -21,12 +21,15 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.fail
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.*
import org.junit.jupiter.api.Test
import org.springframework.aop.framework.ProxyFactory
import org.springframework.transaction.interceptor.TransactionInterceptor
import org.springframework.transaction.testfixture.ReactiveCallCountingTransactionManager
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext
/**
* @author Sebastien Deleuze
@ -118,6 +121,36 @@ class CoroutinesAnnotationTransactionInterceptorTests {
assertReactiveGetTransactionAndCommitCount(1)
}
@Test
fun suspendingValueSuccessWithContext() {
val proxyFactory = ProxyFactory()
proxyFactory.setTarget(TestWithCoroutines())
proxyFactory.addAdvice(TransactionInterceptor(rtm, source))
val proxy = proxyFactory.proxy as TestWithCoroutines
assertThat(runBlocking {
withExampleContext("context") {
proxy.suspendingValueSuccessWithContext()
}
}).isEqualTo("context")
assertReactiveGetTransactionAndCommitCount(1)
}
@Test
fun suspendingValueFailureWithContext() {
val proxyFactory = ProxyFactory()
proxyFactory.setTarget(TestWithCoroutines())
proxyFactory.addAdvice(TransactionInterceptor(rtm, source))
val proxy = proxyFactory.proxy as TestWithCoroutines
assertThatIllegalStateException().isThrownBy {
runBlocking {
withExampleContext("context") {
proxy.suspendingValueFailureWithContext()
}
}
}.withMessage("context")
assertReactiveGetTransactionAndRollbackCount(1)
}
private fun assertReactiveGetTransactionAndCommitCount(expectedCount: Int) {
assertThat(rtm.begun).isEqualTo(expectedCount)
assertThat(rtm.commits).isEqualTo(expectedCount)
@ -166,5 +199,27 @@ class CoroutinesAnnotationTransactionInterceptorTests {
emit("foo")
}
}
open suspend fun suspendingValueSuccessWithContext(): String {
delay(10)
return coroutineContext[ExampleContext.Key].toString()
}
open suspend fun suspendingValueFailureWithContext(): String {
delay(10)
throw IllegalStateException(coroutineContext[ExampleContext.Key].toString())
}
}
}
data class ExampleContext(val value: String) : AbstractCoroutineContextElement(ExampleContext) {
companion object Key : CoroutineContext.Key<ExampleContext>
override fun toString(): String = value
}
private suspend fun withExampleContext(inputValue: String, f: suspend () -> String) =
withContext(ExampleContext(inputValue)) {
f()
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
package org.springframework.transaction.reactive
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.toList
@ -23,6 +24,8 @@ import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.springframework.transaction.support.DefaultTransactionDefinition
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
class TransactionalOperatorExtensionsTests {
@ -107,4 +110,50 @@ class TransactionalOperatorExtensionsTests {
}
}
}
@Test
fun coroutineContextWithSuspendingFunction() {
val operator = TransactionalOperator.create(tm, DefaultTransactionDefinition())
runBlocking(User(role = "admin")) {
try {
operator.executeAndAwait {
delay(1)
val currentUser = currentCoroutineContext()[User]
assertThat(currentUser).isNotNull()
assertThat(currentUser!!.role).isEqualTo("admin")
throw IllegalStateException()
}
} catch (e: IllegalStateException) {
assertThat(tm.commit).isFalse()
assertThat(tm.rollback).isTrue()
return@runBlocking
}
}
}
@Test
fun coroutineContextWithFlow() {
val operator = TransactionalOperator.create(tm, DefaultTransactionDefinition())
val flow = flow<Int> {
delay(1)
val currentUser = currentCoroutineContext()[User]
assertThat(currentUser).isNotNull()
assertThat(currentUser!!.role).isEqualTo("admin")
throw IllegalStateException()
}
runBlocking(User(role = "admin")) {
try {
flow.transactional(operator, coroutineContext).toList()
} catch (e: IllegalStateException) {
assertThat(tm.commit).isFalse()
assertThat(tm.rollback).isTrue()
return@runBlocking
}
}
}
private data class User(val role: String) : AbstractCoroutineContextElement(User) {
companion object Key : CoroutineContext.Key<User>
}
}