diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 30f7fb6cf86..2764de5cd6c 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -819,11 +819,9 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
                   }
 
                 if (nextState == TransactionState.PREPARE_ABORT && isEpochFence) {
-                  // We should clear the pending state to make way for the transition to PrepareAbort and also bump
-                  // the epoch in the transaction metadata we are about to append.
+                  // We should clear the pending state to make way for the transition to PrepareAbort
                   txnMetadata.pendingState = None
-                  txnMetadata.producerEpoch = producerEpoch
-                  txnMetadata.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH
+                  // For TV2+, don't manually set the epoch - let prepareAbortOrCommit handle it naturally.
                 }
 
                 nextProducerIdOrErrors.flatMap {
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index eea5db86bc6..12c36f61761 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -1267,6 +1267,144 @@ class TransactionCoordinatorTest {
       any())
   }
 
+  @Test
+  def shouldNotCauseEpochOverflowWhenInitPidDuringOngoingTxnV2(): Unit = {
+    // When InitProducerId is called with an ongoing transaction at epoch 32766 (Short.MaxValue - 1),
+    // it should not cause an epoch overflow by incrementing twice.
+    // The only true increment happens in prepareAbortOrCommit
+    val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
+      (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_2)
+
+    when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
+      .thenReturn(true)
+    when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+      .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
+    when(transactionManager.transactionVersionLevel()).thenReturn(TV_2)
+
+    // Capture the transition metadata to verify epoch increments
+    val capturedTxnTransitMetadata: ArgumentCaptor[TxnTransitMetadata] = ArgumentCaptor.forClass(classOf[TxnTransitMetadata])
+    when(transactionManager.appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      capturedTxnTransitMetadata.capture(),
+      capturedErrorsCallback.capture(),
+      any(),
+      any())
+    ).thenAnswer(invocation => {
+      val transitMetadata = invocation.getArgument[TxnTransitMetadata](2)
+      // Simulate the metadata update that would happen in the real appendTransactionToLog
+      txnMetadata.completeTransitionTo(transitMetadata)
+      capturedErrorsCallback.getValue.apply(Errors.NONE)
+    })
+
+    // Handle InitProducerId with ongoing transaction at epoch 32766
+    coordinator.handleInitProducerId(
+      transactionalId,
+      txnTimeoutMs,
+      enableTwoPCFlag = false,
+      keepPreparedTxn = false,
+      None,
+      initProducerIdMockCallback
+    )
+
+    // Verify that the epoch did not overflow (should be Short.MaxValue = 32767, not negative)
+    assertEquals(Short.MaxValue, txnMetadata.producerEpoch)
+    assertEquals(TransactionState.PREPARE_ABORT, txnMetadata.state)
+    // The transition that was actually appended must itself carry the singly-bumped epoch
+    assertEquals(Short.MaxValue, capturedTxnTransitMetadata.getValue.producerEpoch)
+
+    verify(transactionManager).validateTransactionTimeoutMs(anyBoolean(), anyInt())
+    verify(transactionManager, times(3)).getTransactionState(ArgumentMatchers.eq(transactionalId))
+    verify(transactionManager).appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      any[TxnTransitMetadata],
+      any(),
+      any(),
+      any())
+  }
+
+  @Test
+  def shouldHandleTimeoutAtEpochOverflowBoundaryCorrectlyTV2(): Unit = {
+    // Test the scenario where we have an ongoing transaction at epoch 32766 (Short.MaxValue - 1)
+    // and the producer crashes/times out. This test verifies that the timeout handling
+    // correctly manages the epoch overflow scenario without causing failures.
+
+    val epochAtMaxBoundary = (Short.MaxValue - 1).toShort // 32766
+    val now = time.milliseconds()
+
+    // Create transaction metadata at the epoch boundary that would cause overflow if (and only if) double-incremented
+    val txnMetadata = new TransactionMetadata(
+      transactionalId = transactionalId,
+      producerId = producerId,
+      prevProducerId = RecordBatch.NO_PRODUCER_ID,
+      nextProducerId = RecordBatch.NO_PRODUCER_ID,
+      producerEpoch = epochAtMaxBoundary,
+      lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
+      txnTimeoutMs = txnTimeoutMs,
+      state = TransactionState.ONGOING,
+      topicPartitions = partitions,
+      txnStartTimestamp = now,
+      txnLastUpdateTimestamp = now,
+      clientTransactionVersion = TV_2
+    )
+    assertTrue(txnMetadata.isProducerEpochExhausted)
+
+    // Mock the transaction manager to return our test transaction as timed out
+    when(transactionManager.timedOutTransactions())
+      .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, epochAtMaxBoundary)))
+    when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+      .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
+    when(transactionManager.transactionVersionLevel()).thenReturn(TV_2)
+
+    // Mock the append operation to simulate successful write and update the metadata
+    when(transactionManager.appendTransactionToLog(
+      ArgumentMatchers.eq(transactionalId),
+      ArgumentMatchers.eq(coordinatorEpoch),
+      any[TxnTransitMetadata],
+      capturedErrorsCallback.capture(),
+      any(),
+      any())
+    ).thenAnswer(invocation => {
+      val transitMetadata = invocation.getArgument[TxnTransitMetadata](2)
+      // Simulate the metadata update that would happen in the real appendTransactionToLog
+      txnMetadata.completeTransitionTo(transitMetadata)
+      capturedErrorsCallback.getValue.apply(Errors.NONE)
+    })
+
+    // Track the actual behavior
+    var callbackInvoked = false
+    var resultError: Errors = null
+    var resultProducerId: Long = -1
+    var resultEpoch: Short = -1
+
+    def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)
+      (error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = {
+      callbackInvoked = true
+      resultError = error
+      resultProducerId = newProducerId
+      resultEpoch = newProducerEpoch
+    }
+
+    // Execute the timeout abort process
+    coordinator.abortTimedOutTransactions(checkOnEndTransactionComplete)
+
+    assertTrue(callbackInvoked, "Callback should have been invoked")
+    assertEquals(Errors.NONE, resultError, "Expected no errors in the callback")
+    assertEquals(producerId, resultProducerId, "Expected producer ID to match")
+    assertEquals(Short.MaxValue, resultEpoch, "Expected producer epoch to be Short.MaxValue (32767) single epoch bump")
+
+    // Verify the transaction metadata was correctly updated to the final epoch
+    assertEquals(Short.MaxValue, txnMetadata.producerEpoch,
+      s"Expected transaction metadata producer epoch to be ${Short.MaxValue} " +
+      s"after timeout handling, but was ${txnMetadata.producerEpoch}"
+    )
+
+    // Verify the basic flow was attempted
+    verify(transactionManager).timedOutTransactions()
+    verify(transactionManager, atLeast(1)).getTransactionState(ArgumentMatchers.eq(transactionalId))
+  }
+
   @Test
   def testInitProducerIdWithNoLastProducerData(): Unit = {
     // If the metadata doesn't include the previous producer data (for example, if it was written to the log by a broker