diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java index f3edbf16b9a..38ef488cae5 100644 --- a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java @@ -209,7 +209,7 @@ public class R2dbcTransactionManager extends AbstractReactiveTransactionManager connectionMono = Mono.just(txObject.getConnectionHolder().getConnection()); } - return connectionMono.flatMap(con -> doBegin(definition, con) + return connectionMono.flatMap(con -> doBegin(con, txObject, definition) .then(prepareTransactionalConnection(con, definition)) .doOnSuccess(v -> { txObject.getConnectionHolder().setTransactionActive(true); @@ -233,7 +233,10 @@ public class R2dbcTransactionManager extends AbstractReactiveTransactionManager }).then(); } - private Mono doBegin(TransactionDefinition definition, Connection con) { + private Mono doBegin( + Connection con, ConnectionFactoryTransactionObject transaction, TransactionDefinition definition) { + + transaction.setMustRestoreAutoCommit(con.isAutoCommit()); io.r2dbc.spi.TransactionDefinition transactionDefinition = createTransactionDefinition(definition); if (logger.isDebugEnabled()) { logger.debug("Starting R2DBC transaction on Connection [" + con + "] using [" + transactionDefinition + "]"); @@ -354,12 +357,22 @@ public class R2dbcTransactionManager extends AbstractReactiveTransactionManager if (logger.isDebugEnabled()) { logger.debug("Releasing R2DBC Connection [" + con + "] after transaction"); } + Mono restoreMono = Mono.empty(); + if (txObject.isMustRestoreAutoCommit() && !con.isAutoCommit()) { + restoreMono = Mono.from(con.setAutoCommit(true)); + if (logger.isDebugEnabled()) { + restoreMono = restoreMono.doOnError(ex -> + logger.debug(String.format("Error ignored during auto-commit restore: %s", ex))); + } + restoreMono = restoreMono.onErrorComplete(); + } Mono releaseMono = ConnectionFactoryUtils.releaseConnection(con, obtainConnectionFactory()); if (logger.isDebugEnabled()) { - releaseMono = releaseMono.doOnError( - ex -> logger.debug(String.format("Error ignored during cleanup: %s", ex))); + releaseMono = releaseMono.doOnError(ex -> + logger.debug(String.format("Error ignored during connection release: %s", ex))); } - return releaseMono.onErrorComplete(); + releaseMono = releaseMono.onErrorComplete(); + return restoreMono.then(releaseMono); } } finally { @@ -482,6 +495,8 @@ public class R2dbcTransactionManager extends AbstractReactiveTransactionManager private boolean newConnectionHolder; + private boolean mustRestoreAutoCommit; + @Nullable private String savepointName; @@ -507,6 +522,14 @@ public class R2dbcTransactionManager extends AbstractReactiveTransactionManager return (this.connectionHolder != null); } + public void setMustRestoreAutoCommit(boolean mustRestoreAutoCommit) { + this.mustRestoreAutoCommit = mustRestoreAutoCommit; + } + + public boolean isMustRestoreAutoCommit() { + return this.mustRestoreAutoCommit; + } + public boolean isTransactionActive() { return (this.connectionHolder != null && this.connectionHolder.isTransactionActive()); } diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java index 05cc75cfc04..2dd2341cc71 100644 --- a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java @@ -81,11 +81,12 @@ class R2dbcTransactionManagerUnitTests { @Test void testSimpleTransaction() { - TestTransactionSynchronization sync = new TestTransactionSynchronization( - TransactionSynchronization.STATUS_COMMITTED); + when(connectionMock.isAutoCommit()).thenReturn(false); AtomicInteger commits = new AtomicInteger(); when(connectionMock.commitTransaction()).thenReturn( Mono.fromRunnable(commits::incrementAndGet)); + TestTransactionSynchronization sync = new TestTransactionSynchronization( + TransactionSynchronization.STATUS_COMMITTED); TransactionalOperator operator = TransactionalOperator.create(tm); @@ -98,6 +99,7 @@ class R2dbcTransactionManagerUnitTests { .verifyComplete(); assertThat(commits).hasValue(1); + verify(connectionMock).isAutoCommit(); verify(connectionMock).beginTransaction(any(io.r2dbc.spi.TransactionDefinition.class)); verify(connectionMock).commitTransaction(); verify(connectionMock).close(); @@ -131,8 +133,10 @@ class R2dbcTransactionManagerUnitTests { } @Test - void appliesTransactionDefinition() { + void appliesTransactionDefinitionAndAutoCommit() { + when(connectionMock.isAutoCommit()).thenReturn(true, false); when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); + when(connectionMock.setAutoCommit(true)).thenReturn(Mono.empty()); DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); definition.setName("my-transaction"); @@ -152,6 +156,7 @@ class R2dbcTransactionManagerUnitTests { verify(connectionMock).beginTransaction(txCaptor.capture()); verify(connectionMock, never()).setTransactionIsolationLevel(any()); verify(connectionMock).commitTransaction(); + verify(connectionMock).setAutoCommit(true); verify(connectionMock).close(); io.r2dbc.spi.TransactionDefinition def = txCaptor.getValue(); @@ -162,29 +167,8 @@ class R2dbcTransactionManagerUnitTests { } @Test - void doesNotSetIsolationLevelIfMatch() { - when(connectionMock.getTransactionIsolationLevel()).thenReturn( - IsolationLevel.READ_COMMITTED); - when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); - - DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); - definition.setIsolationLevel(TransactionDefinition.ISOLATION_READ_COMMITTED); - - TransactionalOperator operator = TransactionalOperator.create(tm, definition); - - ConnectionFactoryUtils.getConnection(connectionFactoryMock) - .as(operator::transactional) - .as(StepVerifier::create) - .expectNextCount(1) - .verifyComplete(); - - verify(connectionMock).beginTransaction(any(io.r2dbc.spi.TransactionDefinition.class)); - verify(connectionMock, never()).setTransactionIsolationLevel(any()); - verify(connectionMock).commitTransaction(); - } - - @Test - void doesNotSetAutoCommitDisabled() { + void doesNotSetAutoCommitIfRestoredByDriver() { + when(connectionMock.isAutoCommit()).thenReturn(true, true); when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); @@ -204,6 +188,7 @@ class R2dbcTransactionManagerUnitTests { @Test void appliesReadOnly() { + when(connectionMock.isAutoCommit()).thenReturn(false); when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); when(connectionMock.setTransactionIsolationLevel(any())).thenReturn(Mono.empty()); Statement statement = mock(); @@ -222,6 +207,7 @@ class R2dbcTransactionManagerUnitTests { .expectNextCount(1) .verifyComplete(); + verify(connectionMock).isAutoCommit(); verify(connectionMock).beginTransaction(any(io.r2dbc.spi.TransactionDefinition.class)); verify(connectionMock).createStatement("SET TRANSACTION READ ONLY"); verify(connectionMock).commitTransaction(); @@ -231,7 +217,9 @@ class R2dbcTransactionManagerUnitTests { @Test void testCommitFails() { - when(connectionMock.commitTransaction()).thenReturn(Mono.defer(() -> Mono.error(new R2dbcBadGrammarException("Commit should fail")))); + when(connectionMock.isAutoCommit()).thenReturn(false); + when(connectionMock.commitTransaction()).thenReturn(Mono.defer(() -> + Mono.error(new R2dbcBadGrammarException("Commit should fail")))); when(connectionMock.rollbackTransaction()).thenReturn(Mono.empty()); TransactionalOperator operator = TransactionalOperator.create(tm); @@ -242,6 +230,7 @@ class R2dbcTransactionManagerUnitTests { .as(StepVerifier::create) .verifyError(BadSqlGrammarException.class); + verify(connectionMock).isAutoCommit(); verify(connectionMock).beginTransaction(any(io.r2dbc.spi.TransactionDefinition.class)); verify(connectionMock).createStatement("foo"); verify(connectionMock).commitTransaction(); @@ -252,6 +241,7 @@ class R2dbcTransactionManagerUnitTests { @Test void testRollback() { + when(connectionMock.isAutoCommit()).thenReturn(false); AtomicInteger commits = new AtomicInteger(); when(connectionMock.commitTransaction()).thenReturn( Mono.fromRunnable(commits::incrementAndGet)); @@ -269,6 +259,7 @@ class R2dbcTransactionManagerUnitTests { assertThat(commits).hasValue(0); assertThat(rollbacks).hasValue(1); + verify(connectionMock).isAutoCommit(); verify(connectionMock).beginTransaction(any(io.r2dbc.spi.TransactionDefinition.class)); verify(connectionMock).rollbackTransaction(); verify(connectionMock).close(); @@ -278,7 +269,8 @@ class R2dbcTransactionManagerUnitTests { @Test @SuppressWarnings("unchecked") void testRollbackFails() { - when(connectionMock.rollbackTransaction()).thenReturn(Mono.defer(() -> Mono.error(new R2dbcBadGrammarException("Commit should fail"))), Mono.empty()); + when(connectionMock.rollbackTransaction()).thenReturn(Mono.defer(() -> + Mono.error(new R2dbcBadGrammarException("Commit should fail"))), Mono.empty()); TransactionalOperator operator = TransactionalOperator.create(tm); operator.execute(reactiveTransaction -> { @@ -287,6 +279,7 @@ class R2dbcTransactionManagerUnitTests { .doOnNext(connection -> connection.createStatement("foo")).then(); }).as(StepVerifier::create).verifyError(BadSqlGrammarException.class); + verify(connectionMock).isAutoCommit(); verify(connectionMock).beginTransaction(any(io.r2dbc.spi.TransactionDefinition.class)); verify(connectionMock).createStatement("foo"); verify(connectionMock, never()).commitTransaction(); @@ -298,7 +291,8 @@ class R2dbcTransactionManagerUnitTests { @Test @SuppressWarnings("unchecked") void testConnectionReleasedWhenRollbackFails() { - when(connectionMock.rollbackTransaction()).thenReturn(Mono.defer(() -> Mono.error(new R2dbcBadGrammarException("Rollback should fail"))), Mono.empty()); + when(connectionMock.rollbackTransaction()).thenReturn(Mono.defer(() -> + Mono.error(new R2dbcBadGrammarException("Rollback should fail"))), Mono.empty()); when(connectionMock.setTransactionIsolationLevel(any())).thenReturn(Mono.empty()); TransactionalOperator operator = TransactionalOperator.create(tm); @@ -319,6 +313,7 @@ class R2dbcTransactionManagerUnitTests { @Test void testTransactionSetRollbackOnly() { + when(connectionMock.isAutoCommit()).thenReturn(false); when(connectionMock.rollbackTransaction()).thenReturn(Mono.empty()); TestTransactionSynchronization sync = new TestTransactionSynchronization( TransactionSynchronization.STATUS_ROLLED_BACK); @@ -334,6 +329,7 @@ class R2dbcTransactionManagerUnitTests { }).then(); }).as(StepVerifier::create).verifyComplete(); + verify(connectionMock).isAutoCommit(); verify(connectionMock).beginTransaction(any(io.r2dbc.spi.TransactionDefinition.class)); verify(connectionMock).rollbackTransaction(); verify(connectionMock).close();