diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java index 71d201b71f9..c4ace64b0e5 100644 --- a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java +++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java @@ -78,6 +78,7 @@ import org.apache.kafka.common.telemetry.internals.ClientTelemetryReporter; import org.apache.kafka.common.telemetry.internals.ClientTelemetryUtils; import org.apache.kafka.common.utils.AppInfoParser; import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.ProducerIdAndEpoch; import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Timer; import org.apache.kafka.common.utils.Utils; @@ -804,7 +805,8 @@ public class KafkaProducer implements Producer { flush(); transactionManager.prepareTransaction(); producerMetrics.recordPrepareTxn(time.nanoseconds() - now); - return transactionManager.preparedTransactionState(); + ProducerIdAndEpoch producerIdAndEpoch = transactionManager.preparedTransactionState(); + return new PreparedTxnState(producerIdAndEpoch.producerId, producerIdAndEpoch.epoch); } /** @@ -908,7 +910,8 @@ public class KafkaProducer implements Producer { } // Get the current prepared transaction state - PreparedTxnState currentPreparedState = transactionManager.preparedTransactionState(); + ProducerIdAndEpoch currentProducerIdAndEpoch = transactionManager.preparedTransactionState(); + PreparedTxnState currentPreparedState = new PreparedTxnState(currentProducerIdAndEpoch.producerId, currentProducerIdAndEpoch.epoch); // Compare the prepared transaction state token and commit or abort accordingly if (currentPreparedState.equals(preparedTxnState)) { diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java index 5d83cbc0b1b..969085809e6 100644 --- a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java @@ -23,7 +23,6 @@ import org.apache.kafka.clients.RequestCompletionHandler; import org.apache.kafka.clients.consumer.CommitFailedException; import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; import org.apache.kafka.clients.consumer.OffsetAndMetadata; -import org.apache.kafka.clients.producer.PreparedTxnState; import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.clients.producer.ProducerRecord; import org.apache.kafka.common.KafkaException; @@ -147,7 +146,7 @@ public class TransactionManager { private volatile long latestFinalizedFeaturesEpoch = -1; private volatile boolean isTransactionV2Enabled = false; private final boolean enable2PC; - private volatile PreparedTxnState preparedTxnState; + private volatile ProducerIdAndEpoch preparedTxnState = ProducerIdAndEpoch.NONE; private enum State { UNINITIALIZED, @@ -230,7 +229,6 @@ public class TransactionManager { this.txnPartitionMap = new TxnPartitionMap(logContext); this.apiVersions = apiVersions; this.enable2PC = enable2PC; - this.preparedTxnState = new PreparedTxnState(); } /** @@ -348,8 +346,8 @@ public class TransactionManager { throwIfPendingState("prepareTransaction"); maybeFailWithError(); transitionTo(State.PREPARED_TRANSACTION); - this.preparedTxnState = new PreparedTxnState( - this.producerIdAndEpoch.producerId + ":" + + this.preparedTxnState = new ProducerIdAndEpoch( + this.producerIdAndEpoch.producerId, this.producerIdAndEpoch.epoch ); } @@ -1343,7 +1341,7 @@ public class TransactionManager { newPartitionsInTransaction.clear(); pendingPartitionsInTransaction.clear(); partitionsInTransaction.clear(); - preparedTxnState = new PreparedTxnState(); + preparedTxnState = ProducerIdAndEpoch.NONE; } abstract class TxnRequestHandler implements RequestCompletionHandler { @@ -1500,8 +1498,21 @@ public class TransactionManager { ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(initProducerIdResponse.data().producerId(), initProducerIdResponse.data().producerEpoch()); setProducerIdAndEpoch(producerIdAndEpoch); - // TO_DO Add code to handle transition to prepared_txn when keepPrepared = true - transitionTo(State.READY); + // If this is a transaction with keepPreparedTxn=true, transition directly + // to PREPARED_TRANSACTION state IFF there is an ongoing transaction. + if (builder.data.keepPreparedTxn() && + initProducerIdResponse.data().ongoingTxnProducerId() != RecordBatch.NO_PRODUCER_ID + ) { + transitionTo(State.PREPARED_TRANSACTION); + // Update the preparedTxnState with the ongoing pid and epoch from the response. + // This will be used to complete the transaction later. + TransactionManager.this.preparedTxnState = new ProducerIdAndEpoch( + initProducerIdResponse.data().ongoingTxnProducerId(), + initProducerIdResponse.data().ongoingTxnProducerEpoch() + ); + } else { + transitionTo(State.READY); + } lastError = null; if (this.isEpochBump) { resetSequenceNumbers(); @@ -1958,13 +1969,13 @@ public class TransactionManager { } /** - * Returns a PreparedTxnState object containing the producer ID and epoch + * Returns a ProducerIdAndEpoch object containing the producer ID and epoch * of the ongoing transaction. * This is used when preparing a transaction for a two-phase commit. * - * @return a PreparedTxnState with the current producer ID and epoch + * @return a ProducerIdAndEpoch with the current producer ID and epoch. */ - public PreparedTxnState preparedTransactionState() { + public ProducerIdAndEpoch preparedTransactionState() { return this.preparedTxnState; } } diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java index 24c1574f3f9..a52d3c6f0b2 100644 --- a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java @@ -94,6 +94,7 @@ import org.apache.kafka.common.telemetry.internals.ClientTelemetrySender; import org.apache.kafka.common.utils.LogCaptureAppender; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.ProducerIdAndEpoch; import org.apache.kafka.common.utils.Time; import org.apache.kafka.test.MockMetricsReporter; import org.apache.kafka.test.MockPartitioner; @@ -154,7 +155,6 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -1453,12 +1453,15 @@ public class KafkaProducerTest { doNothing().when(ctx.transactionManager).prepareTransaction(); - PreparedTxnState expectedState = mock(PreparedTxnState.class); - when(ctx.transactionManager.preparedTransactionState()).thenReturn(expectedState); + long expectedProducerId = 12345L; + short expectedEpoch = 5; + ProducerIdAndEpoch expectedProducerIdAndEpoch = new ProducerIdAndEpoch(expectedProducerId, expectedEpoch); + when(ctx.transactionManager.preparedTransactionState()).thenReturn(expectedProducerIdAndEpoch); try (KafkaProducer producer = ctx.newKafkaProducer()) { PreparedTxnState returned = producer.prepareTransaction(); - assertSame(expectedState, returned); + assertEquals(expectedProducerId, returned.producerId()); + assertEquals(expectedEpoch, returned.epoch()); verify(ctx.transactionManager).prepareTransaction(); verify(ctx.accumulator).beginFlush(); @@ -1612,11 +1615,11 @@ public class KafkaProducerTest { // Create prepared states with matching values long producerId = 12345L; short epoch = 5; - PreparedTxnState currentState = new PreparedTxnState(producerId, epoch); PreparedTxnState inputState = new PreparedTxnState(producerId, epoch); + ProducerIdAndEpoch currentProducerIdAndEpoch = new ProducerIdAndEpoch(producerId, epoch); // Set up the transaction manager to return the prepared state - when(ctx.transactionManager.preparedTransactionState()).thenReturn(currentState); + when(ctx.transactionManager.preparedTransactionState()).thenReturn(currentProducerIdAndEpoch); // Should trigger commit when states match TransactionalRequestResult commitResult = mock(TransactionalRequestResult.class); @@ -1650,11 +1653,11 @@ public class KafkaProducerTest { // Create txn prepared states with different values long producerId = 12345L; short epoch = 5; - PreparedTxnState currentState = new PreparedTxnState(producerId, epoch); PreparedTxnState inputState = new PreparedTxnState(producerId + 1, epoch); + ProducerIdAndEpoch currentProducerIdAndEpoch = new ProducerIdAndEpoch(producerId, epoch); // Set up the transaction manager to return the prepared state - when(ctx.transactionManager.preparedTransactionState()).thenReturn(currentState); + when(ctx.transactionManager.preparedTransactionState()).thenReturn(currentProducerIdAndEpoch); // Should trigger abort when states don't match TransactionalRequestResult abortResult = mock(TransactionalRequestResult.class); diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java index 4668a91ed04..494c715df79 100644 --- a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java @@ -23,7 +23,6 @@ import org.apache.kafka.clients.NodeApiVersions; import org.apache.kafka.clients.consumer.CommitFailedException; import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; import org.apache.kafka.clients.consumer.OffsetAndMetadata; -import org.apache.kafka.clients.producer.PreparedTxnState; import org.apache.kafka.clients.producer.RecordMetadata; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.Node; @@ -4068,11 +4067,92 @@ public class TransactionManagerTest { transactionManager.prepareTransaction(); assertTrue(transactionManager.isPrepared()); - PreparedTxnState preparedState = transactionManager.preparedTransactionState(); - // Validate the state contains the correct serialized producer ID and epoch - assertEquals(producerId + ":" + epoch, preparedState.toString()); - assertEquals(producerId, preparedState.producerId()); - assertEquals(epoch, preparedState.epoch()); + ProducerIdAndEpoch preparedState = transactionManager.preparedTransactionState(); + // Validate the state contains the correct producer ID and epoch + assertEquals(producerId, preparedState.producerId); + assertEquals(epoch, preparedState.epoch); + } + + @Test + public void testInitPidResponseWithKeepPreparedTrueAndOngoingTransaction() { + // Initialize transaction manager with 2PC enabled + initializeTransactionManager(Optional.of(transactionalId), true, true); + + // Start initializeTransactions with keepPreparedTxn=true + TransactionalRequestResult result = transactionManager.initializeTransactions(true); + + // Prepare coordinator response + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + + // Simulate InitProducerId response with ongoing transaction + long ongoingPid = 12345L; + short ongoingEpoch = 5; + prepareInitPidResponse( + Errors.NONE, + false, + producerId, + epoch, + true, + true, + ongoingPid, + ongoingEpoch + ); + + runUntil(transactionManager::hasProducerId); + transactionManager.maybeUpdateTransactionV2Enabled(true); + + result.await(); + assertTrue(result.isSuccessful()); + + // Verify transaction manager transitioned to PREPARED_TRANSACTION state + assertTrue(transactionManager.isPrepared()); + + // Verify preparedTxnState was set with ongoing producer ID and epoch + ProducerIdAndEpoch preparedState = transactionManager.preparedTransactionState(); + assertNotNull(preparedState); + assertEquals(ongoingPid, preparedState.producerId); + assertEquals(ongoingEpoch, preparedState.epoch); + } + + @Test + public void testInitPidResponseWithKeepPreparedTrueAndNoOngoingTransaction() { + // Initialize transaction manager without 2PC enabled + // keepPrepared can be true even when enable2Pc is false, and we expect the same behavior + initializeTransactionManager(Optional.of(transactionalId), true, false); + + // Start initializeTransactions with keepPreparedTxn=true + TransactionalRequestResult result = transactionManager.initializeTransactions(true); + + // Prepare coordinator response + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + + // Simulate InitProducerId response without ongoing transaction + prepareInitPidResponse( + Errors.NONE, + false, + producerId, + epoch, + true, + false, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH + ); + + runUntil(transactionManager::hasProducerId); + transactionManager.maybeUpdateTransactionV2Enabled(true); + + result.await(); + assertTrue(result.isSuccessful()); + + // Verify transaction manager transitioned to READY state (not PREPARED_TRANSACTION) + assertFalse(transactionManager.isPrepared()); + assertTrue(transactionManager.isReady()); + + // Verify preparedTxnState was not set or is empty + ProducerIdAndEpoch preparedState = transactionManager.preparedTransactionState(); + assertEquals(ProducerIdAndEpoch.NONE, preparedState); } private void prepareAddPartitionsToTxn(final Map errors) {