diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 7bc3c03391c..67d3d3d3624 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -391,6 +391,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
                                        producerEpoch: Short,
                                        partitions: collection.Set[TopicPartition],
                                        responseCallback: AddPartitionsCallback,
+                                       clientTransactionVersion: TransactionVersion,
                                        requestLocal: RequestLocal = RequestLocal.noCaching): Unit = {
     if (transactionalId == null || transactionalId.isEmpty) {
       debug(s"Returning ${Errors.INVALID_REQUEST} error code to client for $transactionalId's AddPartitions request")
@@ -420,7 +421,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
           // this is an optimization: if the partitions are already in the metadata reply OK immediately
           Left(Errors.NONE)
         } else {
-          Right(coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds()))
+          Right(coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds(), clientTransactionVersion))
         }
       }
     }
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
index dc52ea13402..31daebac763 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala
@@ -255,7 +255,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
   def prepareNoTransit(): TxnTransitMetadata = {
     // do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state
     TxnTransitMetadata(producerId, previousProducerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.toSet,
-      txnStartTimestamp, txnLastUpdateTimestamp, TransactionVersion.TV_0)
+      txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion)
   }
 
   def prepareFenceProducerEpoch(): TxnTransitMetadata = {
@@ -267,7 +267,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
     val bumpedEpoch = if (hasFailedEpochFence) producerEpoch else (producerEpoch + 1).toShort
 
     prepareTransitionTo(PrepareEpochFence, producerId, bumpedEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs,
-      topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp)
+      topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion)
   }
 
   def prepareIncrementProducerEpoch(newTxnTimeoutMs: Int,
@@ -306,7 +306,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
     epochBumpResult match {
       case Right((nextEpoch, lastEpoch)) => Right(prepareTransitionTo(Empty, producerId, nextEpoch, lastEpoch, newTxnTimeoutMs,
-        immutable.Set.empty[TopicPartition], -1, updateTimestamp))
+        immutable.Set.empty[TopicPartition], -1, updateTimestamp, clientTransactionVersion))
 
       case Left(err) => Left(err)
     }
@@ -320,17 +320,17 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
       throw new IllegalStateException("Cannot rotate producer ids while a transaction is still pending")
 
     prepareTransitionTo(Empty, newProducerId, 0, if (recordLastEpoch) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH,
-      newTxnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp)
+      newTxnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp, clientTransactionVersion)
   }
 
-  def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition], updateTimestamp: Long): TxnTransitMetadata = {
+  def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition], updateTimestamp: Long, clientTransactionVersion: TransactionVersion): TxnTransitMetadata = {
     val newTxnStartTimestamp = state match {
       case Empty | CompleteAbort | CompleteCommit => updateTimestamp
       case _ => txnStartTimestamp
     }
 
     prepareTransitionTo(Ongoing, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs,
-      (topicPartitions ++ addedTopicPartitions).toSet, newTxnStartTimestamp, updateTimestamp)
+      (topicPartitions ++ addedTopicPartitions).toSet, newTxnStartTimestamp, updateTimestamp, clientTransactionVersion)
   }
 
   def prepareAbortOrCommit(newState: TransactionState, clientTransactionVersion: TransactionVersion, nextProducerId: Long, updateTimestamp: Long, noPartitionAdded: Boolean): TxnTransitMetadata = {
@@ -371,7 +371,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
 
   def prepareDead(): TxnTransitMetadata = {
     prepareTransitionTo(Dead, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, Set.empty[TopicPartition],
-      txnStartTimestamp, txnLastUpdateTimestamp)
+      txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion)
   }
 
   /**
@@ -394,8 +394,9 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
                             updatedTxnTimeoutMs: Int,
                             updatedTopicPartitions: immutable.Set[TopicPartition],
                             updatedTxnStartTimestamp: Long,
-                            updateTimestamp: Long): TxnTransitMetadata = {
-    prepareTransitionTo(updatedState, updatedProducerId, RecordBatch.NO_PRODUCER_ID, updatedEpoch, updatedLastEpoch, updatedTxnTimeoutMs, updatedTopicPartitions, updatedTxnStartTimestamp, updateTimestamp, TransactionVersion.TV_0)
+                            updateTimestamp: Long,
+                            clientTransactionVersion: TransactionVersion): TxnTransitMetadata = {
+    prepareTransitionTo(updatedState, updatedProducerId, RecordBatch.NO_PRODUCER_ID, updatedEpoch, updatedLastEpoch, updatedTxnTimeoutMs, updatedTopicPartitions, updatedTxnStartTimestamp, updateTimestamp, clientTransactionVersion)
   }
 
   private def prepareTransitionTo(updatedState: TransactionState,
@@ -613,7 +614,8 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
       s"pendingState=$pendingState, " +
       s"topicPartitions=$topicPartitions, " +
       s"txnStartTimestamp=$txnStartTimestamp, " +
-      s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp)"
+      s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp, " +
+      s"clientTransactionVersion=$clientTransactionVersion)"
   }
 
   override def equals(that: Any): Boolean = that match {
@@ -626,13 +628,14 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
       state.equals(other.state) &&
       topicPartitions.equals(other.topicPartitions) &&
       txnStartTimestamp == other.txnStartTimestamp &&
-      txnLastUpdateTimestamp == other.txnLastUpdateTimestamp
+      txnLastUpdateTimestamp == other.txnLastUpdateTimestamp &&
+      clientTransactionVersion == other.clientTransactionVersion
     case _ => false
   }
 
   override def hashCode(): Int = {
     val fields = Seq(transactionalId, producerId, producerEpoch, txnTimeoutMs, state, topicPartitions,
-      txnStartTimestamp, txnLastUpdateTimestamp)
+      txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion)
     fields.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
   }
 }
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index 7a5ac37f0a3..ba5eef40e5c 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -2328,14 +2328,11 @@ class KafkaApis(val requestChannel: RequestChannel,
         requestHelper.sendResponseMaybeThrottle(request, createResponse)
       }
 
-      // If the request is greater than version 4, we know the client supports transaction version 2.
-      val clientTransactionVersion = if (endTxnRequest.version() > 4) TransactionVersion.TV_2 else TransactionVersion.TV_0
-
       txnCoordinator.handleEndTransaction(endTxnRequest.data.transactionalId,
         endTxnRequest.data.producerId,
         endTxnRequest.data.producerEpoch,
         endTxnRequest.result(),
-        clientTransactionVersion,
+        TransactionVersion.transactionVersionForEndTxn(endTxnRequest),
         sendResponseCallback,
         requestLocal)
     } else
@@ -2614,6 +2611,7 @@ class KafkaApis(val requestChannel: RequestChannel,
             transaction.producerEpoch,
             authorizedPartitions,
             sendResponseCallback,
+            TransactionVersion.transactionVersionForAddPartitionsToTxn(addPartitionsToTxnRequest),
             requestLocal)
         } else {
           txnCoordinator.handleVerifyPartitionsInTransaction(transactionalId,
@@ -2673,6 +2671,7 @@ class KafkaApis(val requestChannel: RequestChannel,
       addOffsetsToTxnRequest.data.producerEpoch,
       Set(offsetTopicPartition),
       sendResponseCallback,
+      TransactionVersion.TV_0, // This request will always come from the client not using TV 2.
      requestLocal)
   }
 }
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index f446eb2bfb2..28019efc0c6 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -547,6 +547,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
         txnMetadata.producerEpoch,
         partitions,
         resultCallback,
+        TransactionVersion.TV_2,
         RequestLocal.withThreadConfinedCaching)
       replicaManager.tryCompleteActions()
     }
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index f3302b12935..ab5ff72cd98 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -209,19 +209,19 @@ class TransactionCoordinatorTest {
     when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(None))
 
-    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback, TV_0)
     assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error)
   }
 
   @Test
   def shouldRespondWithInvalidRequestAddPartitionsToTransactionWhenTransactionalIdIsEmpty(): Unit = {
-    coordinator.handleAddPartitionsToTransaction("", 0L, 1, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction("", 0L, 1, partitions, errorsCallback, TV_0)
     assertEquals(Errors.INVALID_REQUEST, error)
   }
 
   @Test
   def shouldRespondWithInvalidRequestAddPartitionsToTransactionWhenTransactionalIdIsNull(): Unit = {
-    coordinator.handleAddPartitionsToTransaction(null, 0L, 1, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(null, 0L, 1, partitions, errorsCallback, TV_0)
     assertEquals(Errors.INVALID_REQUEST, error)
   }
 
@@ -230,7 +230,7 @@ class TransactionCoordinatorTest {
     when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Left(Errors.NOT_COORDINATOR))
 
-    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback, TV_0)
     assertEquals(Errors.NOT_COORDINATOR, error)
   }
 
@@ -239,7 +239,7 @@ class TransactionCoordinatorTest {
     when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
 
-    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback, TV_0)
     assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, error)
   }
 
@@ -313,7 +313,7 @@ class TransactionCoordinatorTest {
       new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0, TV_2)))))
 
-    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_2)
     assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
   }
 
@@ -325,7 +325,7 @@ class TransactionCoordinatorTest {
       new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, 10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0, TV_2)))))
 
-    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_2)
     assertEquals(Errors.PRODUCER_FENCED, error)
   }
 
@@ -359,7 +359,7 @@ class TransactionCoordinatorTest {
     when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
       .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
 
-    coordinator.handleAddPartitionsToTransaction(transactionalId, producerId, producerEpoch, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(transactionalId, producerId, producerEpoch, partitions, errorsCallback, clientTransactionVersion)
 
     verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
     verify(transactionManager).appendTransactionToLog(
@@ -379,7 +379,7 @@ class TransactionCoordinatorTest {
       new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0)))))
 
-    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
+    coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_0)
 
     assertEquals(Errors.NONE, error)
     verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
   }
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala
index 2da9c96fa20..6b2d20e69eb 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala
@@ -253,7 +253,7 @@ class TransactionMetadataTest {
       clientTransactionVersion = TV_0)
 
     // let new time be smaller; when transiting from Empty the start time would be updated to the update-time
-    var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1)
+    var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1, TV_0)
     txnMetadata.completeTransitionTo(transitMetadata)
     assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0)), txnMetadata.topicPartitions)
     assertEquals(producerId, txnMetadata.producerId)
@@ -263,7 +263,7 @@ class TransactionMetadataTest {
     assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp)
 
     // add another partition, check that in Ongoing state the start timestamp would not change to update time
-    transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds() - 2)
+    transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds() - 2, TV_0)
     txnMetadata.completeTransitionTo(transitMetadata)
     assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)), txnMetadata.topicPartitions)
     assertEquals(producerId, txnMetadata.producerId)
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
index 78da50f782b..d12df190c8b 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala
@@ -389,7 +389,7 @@ class TransactionStateManagerTest {
 
     // update the metadata to ongoing with two partitions
     val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
-      new TopicPartition("topic1", 1)), time.milliseconds())
+      new TopicPartition("topic1", 1)), time.milliseconds(), TV_0)
 
     // append the new metadata into log
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch, newMetadata, assertCallback, requestLocal = RequestLocal.withThreadConfinedCaching)
@@ -404,7 +404,7 @@ class TransactionStateManagerTest {
     transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     expectedError = Errors.COORDINATOR_NOT_AVAILABLE
-    var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
+    var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
 
     prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION)
     val requestLocal = RequestLocal.withThreadConfinedCaching
@@ -412,19 +412,19 @@ class TransactionStateManagerTest {
     assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
     assertTrue(txnMetadata1.pendingState.isEmpty)
 
-    failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
+    failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
     prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal)
     assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
     assertTrue(txnMetadata1.pendingState.isEmpty)
 
-    failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
+    failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
     prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal)
     assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
     assertTrue(txnMetadata1.pendingState.isEmpty)
 
-    failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
+    failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
     prepareForTxnMessageAppend(Errors.REQUEST_TIMED_OUT)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal)
     assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
@@ -437,7 +437,7 @@ class TransactionStateManagerTest {
     transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     expectedError = Errors.NOT_COORDINATOR
-    var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
+    var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
 
     prepareForTxnMessageAppend(Errors.NOT_LEADER_OR_FOLLOWER)
     val requestLocal = RequestLocal.withThreadConfinedCaching
@@ -445,7 +445,7 @@ class TransactionStateManagerTest {
     assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
     assertTrue(txnMetadata1.pendingState.isEmpty)
 
-    failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
+    failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
     prepareForTxnMessageAppend(Errors.NONE)
     transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal)
@@ -468,7 +468,7 @@ class TransactionStateManagerTest {
     transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     expectedError = Errors.COORDINATOR_LOAD_IN_PROGRESS
-    val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
+    val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
 
     prepareForTxnMessageAppend(Errors.NONE)
     transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch)
@@ -482,7 +482,7 @@ class TransactionStateManagerTest {
     transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     expectedError = Errors.UNKNOWN_SERVER_ERROR
-    var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
+    var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
 
     prepareForTxnMessageAppend(Errors.MESSAGE_TOO_LARGE)
     val requestLocal = RequestLocal.withThreadConfinedCaching
@@ -490,7 +490,7 @@ class TransactionStateManagerTest {
     assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
     assertTrue(txnMetadata1.pendingState.isEmpty)
 
-    failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
+    failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
     prepareForTxnMessageAppend(Errors.RECORD_LIST_TOO_LARGE)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal)
     assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
@@ -503,7 +503,7 @@ class TransactionStateManagerTest {
     transactionManager.putTransactionStateIfNotExists(txnMetadata1)
 
     expectedError = Errors.COORDINATOR_NOT_AVAILABLE
-    val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds())
+    val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
 
     prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION)
     transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, _ => true, RequestLocal.withThreadConfinedCaching)
@@ -522,7 +522,7 @@ class TransactionStateManagerTest {
     expectedError = Errors.NOT_COORDINATOR
 
     val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
-      new TopicPartition("topic1", 1)), time.milliseconds())
+      new TopicPartition("topic1", 1)), time.milliseconds(), TV_0)
 
     // modify the cache while trying to append the new metadata
     txnMetadata1.producerEpoch = (txnMetadata1.producerEpoch + 1).toShort
@@ -541,7 +541,7 @@ class TransactionStateManagerTest {
     expectedError = Errors.INVALID_PRODUCER_EPOCH
 
     val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
-      new TopicPartition("topic1", 1)), time.milliseconds())
+      new TopicPartition("topic1", 1)), time.milliseconds(), TV_0)
 
     // modify the cache while trying to append the new metadata
     txnMetadata1.pendingState = None
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index a009b2f10a2..a36598a5eeb 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -2301,6 +2301,7 @@ class KafkaApisTest extends Logging {
       ArgumentMatchers.eq(epoch),
       ArgumentMatchers.eq(Set(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, partition))),
       responseCallback.capture(),
+      ArgumentMatchers.eq(TransactionVersion.TV_0),
      ArgumentMatchers.eq(requestLocal)
    )).thenAnswer(_ => responseCallback.getValue.apply(Errors.PRODUCER_FENCED))
    val kafkaApis = createKafkaApis()
@@ -2359,6 +2360,7 @@ class KafkaApisTest extends Logging {
      ArgumentMatchers.eq(epoch),
      ArgumentMatchers.eq(Set(topicPartition)),
      responseCallback.capture(),
+      ArgumentMatchers.eq(TransactionVersion.TV_0),
      ArgumentMatchers.eq(requestLocal)
    )).thenAnswer(_ => responseCallback.getValue.apply(Errors.PRODUCER_FENCED))
    val kafkaApis = createKafkaApis()
@@ -2434,6 +2436,7 @@ class KafkaApisTest extends Logging {
      ArgumentMatchers.eq(epoch),
      ArgumentMatchers.eq(Set(tp0)),
      responseCallback.capture(),
+      any[TransactionVersion],
      ArgumentMatchers.eq(requestLocal)
    )).thenAnswer(_ => responseCallback.getValue.apply(Errors.NONE))
diff --git a/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java b/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java
index 069440d35c9..45546c447b0 100644
--- a/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java
+++ b/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java
@@ -16,6 +16,9 @@
  */
 package org.apache.kafka.server.common;
 
+import org.apache.kafka.common.requests.AddPartitionsToTxnRequest;
+import org.apache.kafka.common.requests.EndTxnRequest;
+
 import java.util.Collections;
 import java.util.Map;
 
@@ -55,6 +58,16 @@ public enum TransactionVersion implements FeatureVersion {
         return (TransactionVersion) Feature.TRANSACTION_VERSION.fromFeatureLevel(version, true);
     }
 
+    public static TransactionVersion transactionVersionForAddPartitionsToTxn(AddPartitionsToTxnRequest request) {
+        // If the request is greater than version 3, we know the client supports transaction version 2.
+        return request.version() > 3 ? TV_2 : TV_0;
+    }
+
+    public static TransactionVersion transactionVersionForEndTxn(EndTxnRequest request) {
+        // If the request is greater than version 4, we know the client supports transaction version 2.
+        return request.version() > 4 ? TV_2 : TV_0;
+    }
+
     @Override
     public String featureName() {
         return FEATURE_NAME;
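
Note (illustration only, not part of the patch): the two helpers added to TransactionVersion key off raw request versions, with AddPartitionsToTxn v4+ and EndTxn v5+ implying a transaction-version-2 client and everything older falling back to TV_0. The standalone sketch below mirrors those cutoffs so they can be sanity-checked without constructing real AddPartitionsToTxnRequest or EndTxnRequest objects; the class and method names here are hypothetical, not Kafka APIs.

// Hypothetical sketch: mirrors the request-version cutoffs encoded by
// transactionVersionForAddPartitionsToTxn and transactionVersionForEndTxn.
public final class TransactionVersionCutoffSketch {

    // AddPartitionsToTxn: versions 4 and above imply feature level 2 (TV_2), older imply 0 (TV_0).
    static short addPartitionsToTxnFeatureLevel(short requestVersion) {
        return requestVersion > 3 ? (short) 2 : (short) 0;
    }

    // EndTxn: versions 5 and above imply feature level 2 (TV_2), older imply 0 (TV_0).
    static short endTxnFeatureLevel(short requestVersion) {
        return requestVersion > 4 ? (short) 2 : (short) 0;
    }

    public static void main(String[] args) {
        // Boundary checks around each cutoff.
        if (addPartitionsToTxnFeatureLevel((short) 3) != 0) throw new AssertionError("AddPartitionsToTxn v3 should map to TV_0");
        if (addPartitionsToTxnFeatureLevel((short) 4) != 2) throw new AssertionError("AddPartitionsToTxn v4 should map to TV_2");
        if (endTxnFeatureLevel((short) 4) != 0) throw new AssertionError("EndTxn v4 should map to TV_0");
        if (endTxnFeatureLevel((short) 5) != 2) throw new AssertionError("EndTxn v5 should map to TV_2");
        System.out.println("AddPartitionsToTxn cutoff at v4, EndTxn cutoff at v5");
    }
}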