diff --git a/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java index 06c6150978..51b861a8c2 100644 --- a/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java +++ b/spring-tx/src/main/java/org/springframework/transaction/support/TransactionSynchronizationManager.java @@ -158,13 +158,73 @@ public abstract class TransactionSynchronizationManager { /** * Bind the given resource for the given key to the current thread. + *

Note: Any bound resource needs to get explicitly unbound through + * {@link #unbindResource}. For automatic unbinding after transaction + * completion, use {@link #bindSynchronizedResource} instead. * @param key the key to bind the value to (usually the resource factory) * @param value the value to bind (usually the active resource object) * @throws IllegalStateException if there is already a value bound to the thread * @see ResourceTransactionManager#getResourceFactory() + * @see #bindSynchronizedResource */ public static void bindResource(Object key, Object value) throws IllegalStateException { Object actualKey = TransactionSynchronizationUtils.unwrapResourceIfNecessary(key); + Object oldValue = doBindResource(actualKey, value); + if (oldValue != null) { + throw new IllegalStateException( + "Already value [" + oldValue + "] for key [" + actualKey + "] bound to thread"); + } + } + + /** + * Bind the given resource for the given key to the current thread, + * synchronizing it with the current transaction for automatic unbinding + * after transaction completion. + *

This is effectively a programmatic way to register a transaction-scoped + * resource, similar to the BeanFactory-driven {@link SimpleTransactionScope}. + *

An existing value bound for the given key will be preserved and re-bound + * after transaction completion, restoring the state before this bind call. + * @param key the key to bind the value to (usually the resource factory) + * @param value the value to bind (usually the active resource object) + * @throws IllegalStateException if transaction synchronization is not active + * @since 7.0 + * @see #bindResource + * @see #registerSynchronization + */ + public static void bindSynchronizedResource(Object key, Object value) throws IllegalStateException { + Set synchs = synchronizations.get(); + if (synchs == null) { + throw new IllegalStateException("Transaction synchronization is not active"); + } + Object actualKey = TransactionSynchronizationUtils.unwrapResourceIfNecessary(key); + Object oldValue = doBindResource(actualKey, value); + synchs.add(new TransactionSynchronization() { + @Override + public void suspend() { + doUnbindResource(actualKey); + } + @Override + public void resume() { + Object existingValue = doBindResource(actualKey, value); + if (existingValue != null) { + throw new IllegalStateException( + "Unexpected value [" + existingValue + "] for key [" + actualKey + "] bound on resume"); + } + } + @Override + public void afterCompletion(int status) { + doUnbindResource(actualKey); + if (oldValue != null) { + doBindResource(actualKey, oldValue); + } + } + }); + } + + /** + * Actually bind the given resource for the given key to the current thread. + */ + private static @Nullable Object doBindResource(Object actualKey, Object value) { Assert.notNull(value, "Value must not be null"); Map map = resources.get(); // set ThreadLocal Map if none found @@ -177,18 +237,19 @@ public abstract class TransactionSynchronizationManager { if (oldValue instanceof ResourceHolder resourceHolder && resourceHolder.isVoid()) { oldValue = null; } - if (oldValue != null) { - throw new IllegalStateException( - "Already value [" + oldValue + "] for key [" + actualKey + "] bound to thread"); - } + return oldValue; } /** * Unbind a resource for the given key from the current thread. + *

This explicit step is only necessary with {@link #bindResource}. + * For automatic unbinding, consider {@link #bindSynchronizedResource}. * @param key the key to unbind (usually the resource factory) * @return the previously bound value (usually the active resource object) * @throws IllegalStateException if there is no value bound to the thread * @see ResourceTransactionManager#getResourceFactory() + * @see #bindResource + * @see #unbindResourceIfPossible */ public static Object unbindResource(Object key) throws IllegalStateException { Object actualKey = TransactionSynchronizationUtils.unwrapResourceIfNecessary(key); @@ -201,8 +262,12 @@ public abstract class TransactionSynchronizationManager { /** * Unbind a resource for the given key from the current thread. + *

This explicit step is only necessary with {@link #bindResource}. + * For automatic unbinding, consider {@link #bindSynchronizedResource}. * @param key the key to unbind (usually the resource factory) * @return the previously bound value, or {@code null} if none bound + * @see #bindResource + * @see #unbindResource */ public static @Nullable Object unbindResourceIfPossible(Object key) { Object actualKey = TransactionSynchronizationUtils.unwrapResourceIfNecessary(key); diff --git a/spring-tx/src/test/java/org/springframework/transaction/support/SimpleTransactionScopeTests.java b/spring-tx/src/test/java/org/springframework/transaction/support/SimpleTransactionScopeTests.java index 5e1215846a..24894b460e 100644 --- a/spring-tx/src/test/java/org/springframework/transaction/support/SimpleTransactionScopeTests.java +++ b/spring-tx/src/test/java/org/springframework/transaction/support/SimpleTransactionScopeTests.java @@ -30,6 +30,7 @@ import org.springframework.transaction.testfixture.CallCountingTransactionManage import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; /** * @author Juergen Hoeller @@ -54,13 +55,11 @@ class SimpleTransactionScopeTests { context.refresh(); - assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> - context.getBean(TestBean.class)) - .withCauseInstanceOf(IllegalStateException.class); + assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> context.getBean(TestBean.class)) + .withCauseInstanceOf(IllegalStateException.class); - assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> - context.getBean(DerivedTestBean.class)) - .withCauseInstanceOf(IllegalStateException.class); + assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> context.getBean(DerivedTestBean.class)) + .withCauseInstanceOf(IllegalStateException.class); TestBean bean1; DerivedTestBean bean2; @@ -99,13 +98,11 @@ class SimpleTransactionScopeTests { assertThat(bean2b.wasDestroyed()).isTrue(); assertThat(TransactionSynchronizationManager.getResourceMap()).isEmpty(); - assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> - context.getBean(TestBean.class)) - .withCauseInstanceOf(IllegalStateException.class); + assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> context.getBean(TestBean.class)) + .withCauseInstanceOf(IllegalStateException.class); - assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> - context.getBean(DerivedTestBean.class)) - .withCauseInstanceOf(IllegalStateException.class); + assertThatExceptionOfType(BeanCreationException.class).isThrownBy(() -> context.getBean(DerivedTestBean.class)) + .withCauseInstanceOf(IllegalStateException.class); } @Test @@ -175,4 +172,46 @@ class SimpleTransactionScopeTests { } } + @Test + void bindSynchronizedResource() { + CallCountingTransactionManager tm = new CallCountingTransactionManager(); + TransactionTemplate tt = new TransactionTemplate(tm); + + tt.execute(status -> { + TestBean tb = new TestBean(); + TransactionSynchronizationManager.bindSynchronizedResource("tb", tb); + assertThat(TransactionSynchronizationManager.hasResource("tb")).isTrue(); + assertThat(TransactionSynchronizationManager.getResource("tb")).isSameAs(tb); + return null; + }); + assertThat(TransactionSynchronizationManager.hasResource("tb")).isFalse(); + } + + @Test + void bindSynchronizedResourceWithOldValue() { + CallCountingTransactionManager tm = new CallCountingTransactionManager(); + TransactionTemplate tt = new TransactionTemplate(tm); + + TestBean oldValue = new TestBean(); + TransactionSynchronizationManager.bindResource("tb", oldValue); + + tt.execute(status -> { + TestBean tb = new TestBean(); + TransactionSynchronizationManager.bindSynchronizedResource("tb", tb); + assertThat(TransactionSynchronizationManager.hasResource("tb")).isTrue(); + assertThat(TransactionSynchronizationManager.getResource("tb")).isSameAs(tb); + return null; + }); + assertThat(TransactionSynchronizationManager.hasResource("tb")).isTrue(); + assertThat(TransactionSynchronizationManager.getResource("tb")).isSameAs(oldValue); + TransactionSynchronizationManager.unbindResource("tb"); + } + + @Test + void bindSynchronizedResourceWithoutTransaction() { + assertThatIllegalStateException().isThrownBy( + () -> TransactionSynchronizationManager.bindSynchronizedResource("tb", new TestBean())); + assertThat(TransactionSynchronizationManager.hasResource("tb")).isFalse(); + } + }