From ede0c94aaae3bbc9f3e0ac6bb0b0e6c8b9df39d7 Mon Sep 17 00:00:00 2001 From: Justine Olshan Date: Thu, 26 Sep 2024 09:37:11 -0700 Subject: [PATCH] KAFKA-14562 [1/2]: Implement epoch bump after every transaction (#16719) Implement server side changes for epoch bump but keep EndTxn as an unstable API until the client side changes are implemented. EndTxnResponse will return the producer ID and epoch for the transaction. Introduces new tagged fields to the TransactionLogValue to persist the clientTransactionVersion, previousProducerId, and nextProducerId to the log so that the state can be reloaded. See KIP-890 for more details. Small updates to naming of lastProducerId -> PreviousProducerId. Also cleans up the many TransactionMetadata constructors. Reviewers: Artem Livshits , David Jacot --- .../kafka/common/requests/EndTxnRequest.java | 6 +- .../common/message/EndTxnRequest.json | 5 +- .../common/message/EndTxnResponse.json | 10 +- .../transaction/TransactionCoordinator.scala | 100 +++- .../transaction/TransactionLog.scala | 16 +- .../transaction/TransactionMetadata.scala | 157 +++-- .../transaction/TransactionStateManager.scala | 8 +- .../main/scala/kafka/server/KafkaApis.scala | 10 +- ...ransactionCoordinatorConcurrencyTest.scala | 13 +- .../TransactionCoordinatorTest.scala | 538 ++++++++++++------ .../transaction/TransactionLogTest.scala | 48 +- .../TransactionMarkerChannelManagerTest.scala | 10 +- ...onMarkerRequestCompletionHandlerTest.scala | 5 +- .../transaction/TransactionMetadataTest.scala | 226 ++++++-- .../TransactionStateManagerTest.scala | 47 +- .../unit/kafka/server/KafkaApisTest.scala | 9 +- .../server/common/TransactionVersion.java | 14 + .../common/message/TransactionLogValue.json | 8 +- 18 files changed, 850 insertions(+), 380 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java index c9ea98005fd..5cd346a8fe0 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java @@ -32,7 +32,11 @@ public class EndTxnRequest extends AbstractRequest { public final EndTxnRequestData data; public Builder(EndTxnRequestData data) { - super(ApiKeys.END_TXN); + this(data, false); + } + + public Builder(EndTxnRequestData data, boolean enableUnstableLastVersion) { + super(ApiKeys.END_TXN, enableUnstableLastVersion); this.data = data; } diff --git a/clients/src/main/resources/common/message/EndTxnRequest.json b/clients/src/main/resources/common/message/EndTxnRequest.json index bc66adcf50a..aea50d6b16d 100644 --- a/clients/src/main/resources/common/message/EndTxnRequest.json +++ b/clients/src/main/resources/common/message/EndTxnRequest.json @@ -25,7 +25,10 @@ // Version 3 enables flexible versions. // // Version 4 adds support for new error code TRANSACTION_ABORTABLE (KIP-890). - "validVersions": "0-4", + // + // Version 5 enables bumping epoch on every transaction (KIP-890 Part 2) + "latestVersionUnstable": true, + "validVersions": "0-5", "flexibleVersions": "3+", "fields": [ { "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId", diff --git a/clients/src/main/resources/common/message/EndTxnResponse.json b/clients/src/main/resources/common/message/EndTxnResponse.json index 08ac6cddd38..fb958a74033 100644 --- a/clients/src/main/resources/common/message/EndTxnResponse.json +++ b/clients/src/main/resources/common/message/EndTxnResponse.json @@ -24,12 +24,18 @@ // Version 3 enables flexible versions. // // Version 4 adds support for new error code TRANSACTION_ABORTABLE (KIP-890). - "validVersions": "0-4", + // + // Version 5 enables bumping epoch on every transaction (KIP-890 Part 2), so producer ID and epoch are included in the response. + "validVersions": "0-5", "flexibleVersions": "3+", "fields": [ { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, { "name": "ErrorCode", "type": "int16", "versions": "0+", - "about": "The error code, or 0 if there was no error." } + "about": "The error code, or 0 if there was no error." }, + { "name": "ProducerId", "type": "int64", "versions": "5+", "entityType": "producerId", "default": "-1", "ignorable": "true", + "about": "The producer ID." }, + { "name": "ProducerEpoch", "type": "int16", "versions": "5+", "default": "-1", "ignorable": "true", + "about": "The current epoch associated with the producer." } ] } diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala index cd268f3a166..72a196eca60 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala @@ -29,7 +29,7 @@ import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult} import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time} -import org.apache.kafka.server.common.RequestLocal +import org.apache.kafka.server.common.{RequestLocal, TransactionVersion} import org.apache.kafka.server.util.Scheduler import scala.jdk.CollectionConverters._ @@ -98,7 +98,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, private type InitProducerIdCallback = InitProducerIdResult => Unit private type AddPartitionsCallback = Errors => Unit private type VerifyPartitionsCallback = AddPartitionsToTxnResult => Unit - private type EndTxnCallback = Errors => Unit + private type EndTxnCallback = (Errors, Long, Short) => Unit private type ApiResult[T] = Either[Errors, T] /* Active flag of the coordinator */ @@ -135,13 +135,15 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case Success(producerId) => val createdMetadata = new TransactionMetadata(transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = transactionTimeoutMs, state = Empty, topicPartitions = collection.mutable.Set.empty[TopicPartition], - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TransactionVersion.TV_0) txnManager.putTransactionStateIfNotExists(createdMetadata) case Failure(exception) => @@ -169,7 +171,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case Right((coordinatorEpoch, newMetadata)) => if (newMetadata.txnState == PrepareEpochFence) { // abort the ongoing transaction and then return CONCURRENT_TRANSACTIONS to let client wait and retry - def sendRetriableErrorCallback(error: Errors): Unit = { + def sendRetriableErrorCallback(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = { if (error != Errors.NONE) { responseCallback(initTransactionError(error)) } else { @@ -182,6 +184,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, newMetadata.producerEpoch, TransactionResult.ABORT, isFromClient = false, + clientTransactionVersion = txnManager.transactionVersionLevel(), // Since this is not from client, use server TV sendRetriableErrorCallback, requestLocal) } else { @@ -221,7 +224,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, // could be a retry after a valid epoch bump that the producer never received the response for txnMetadata.producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || producerIdAndEpoch.producerId == txnMetadata.producerId || - (producerIdAndEpoch.producerId == txnMetadata.lastProducerId && TransactionMetadata.isEpochExhausted(producerIdAndEpoch.epoch)) + (producerIdAndEpoch.producerId == txnMetadata.previousProducerId && TransactionMetadata.isEpochExhausted(producerIdAndEpoch.epoch)) } if (txnMetadata.pendingTransitionInProgress) { @@ -487,6 +490,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, producerId: Long, producerEpoch: Short, txnMarkerResult: TransactionResult, + clientTransactionVersion: TransactionVersion, responseCallback: EndTxnCallback, requestLocal: RequestLocal = RequestLocal.noCaching): Unit = { endTransaction(transactionalId, @@ -494,6 +498,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, producerEpoch, txnMarkerResult, isFromClient = true, + clientTransactionVersion, responseCallback, requestLocal) } @@ -503,12 +508,15 @@ class TransactionCoordinator(txnConfig: TransactionConfig, producerEpoch: Short, txnMarkerResult: TransactionResult, isFromClient: Boolean, + clientTransactionVersion: TransactionVersion, responseCallback: EndTxnCallback, requestLocal: RequestLocal): Unit = { var isEpochFence = false if (transactionalId == null || transactionalId.isEmpty) - responseCallback(Errors.INVALID_REQUEST) + responseCallback(Errors.INVALID_REQUEST, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH) else { + var producerIdCopy = RecordBatch.NO_PRODUCER_ID + var producerEpochCopy = RecordBatch.NO_PRODUCER_EPOCH val preAppendResult: ApiResult[(Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId).flatMap { case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING) @@ -518,10 +526,39 @@ class TransactionCoordinator(txnConfig: TransactionConfig, val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch txnMetadata.inLock { - if (txnMetadata.producerId != producerId) + producerIdCopy = txnMetadata.producerId + producerEpochCopy = txnMetadata.producerEpoch + // PrepareEpochFence has slightly different epoch bumping logic so don't include it here. + val currentTxnMetadataIsAtLeastTransactionsV2 = !txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.clientTransactionVersion.supportsEpochBump() + // True if the client used TV_2 and retried a request that had overflowed the epoch, and a new producer ID is stored in the txnMetadata + val retryOnOverflow = currentTxnMetadataIsAtLeastTransactionsV2 && + txnMetadata.previousProducerId == producerId && producerEpoch == Short.MaxValue - 1 && txnMetadata.producerEpoch == 0 + // True if the client used TV_2 and retried an endTxn request, and the bumped producer epoch is stored in the txnMetadata. + val retryOnEpochBump = endTxnEpochBumped(txnMetadata, producerEpoch) + + val isValidEpoch = { + if (currentTxnMetadataIsAtLeastTransactionsV2) { + // With transactions V2, state + same epoch is not sufficient to determine if a retry transition is valid. If the epoch is the + // same it actually indicates the next endTransaction call. Instead, we want to check the epoch matches with the epoch in the retry conditions. + // Return producer fenced even in the cases where the epoch is higher and could indicate an invalid state transition. + // Use the following criteria to determine if a v2 retry is valid: + txnMetadata.state match { + case Ongoing | Empty | Dead | PrepareEpochFence => + producerEpoch == txnMetadata.producerEpoch + case PrepareCommit | PrepareAbort => + retryOnEpochBump + case CompleteCommit | CompleteAbort => + retryOnEpochBump || retryOnOverflow + } + } else { + // For transactions V1 strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch without server knowledge. + (!isFromClient || producerEpoch == txnMetadata.producerEpoch) && producerEpoch >= txnMetadata.producerEpoch + } + } + + if (txnMetadata.producerId != producerId && !retryOnOverflow) Left(Errors.INVALID_PRODUCER_ID_MAPPING) - // Strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch. - else if ((isFromClient && producerEpoch != txnMetadata.producerEpoch) || producerEpoch < txnMetadata.producerEpoch) + else if (!isValidEpoch) Left(Errors.PRODUCER_FENCED) else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence) Left(Errors.CONCURRENT_TRANSACTIONS) @@ -532,6 +569,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig, else PrepareAbort + // Maybe allocate new producer ID if we are bumping epoch and epoch is exhausted + val nextProducerIdOrErrors = + if (clientTransactionVersion.supportsEpochBump() && !txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.isProducerEpochExhausted) { + producerIdManager.generateProducerId() match { + case Success(newProducerId) => + Right(newProducerId) + case Failure(exception) => + Left(Errors.forException(exception)) + } + } else { + Right(RecordBatch.NO_PRODUCER_ID) + } + if (nextState == PrepareAbort && txnMetadata.pendingState.contains(PrepareEpochFence)) { // We should clear the pending state to make way for the transition to PrepareAbort and also bump // the epoch in the transaction metadata we are about to append. @@ -541,7 +591,10 @@ class TransactionCoordinator(txnConfig: TransactionConfig, txnMetadata.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH } - Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, time.milliseconds())) + nextProducerIdOrErrors.flatMap { + nextProducerId => + Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, clientTransactionVersion, nextProducerId, time.milliseconds())) + } case CompleteCommit => if (txnMarkerResult == TransactionResult.COMMIT) Left(Errors.NONE) @@ -576,8 +629,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig, preAppendResult match { case Left(err) => + if (err == Errors.NONE) { + responseCallback(err, producerIdCopy, producerEpochCopy) + } else { debug(s"Aborting append of $txnMarkerResult to transaction log with coordinator and returning $err error to client for $transactionalId's EndTransaction request") - responseCallback(err) + responseCallback(err, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH) + } case Right((coordinatorEpoch, newMetadata)) => def sendTxnMarkersCallback(error: Errors): Unit = { @@ -595,7 +652,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, txnMetadata.inLock { if (txnMetadata.producerId != producerId) Left(Errors.INVALID_PRODUCER_ID_MAPPING) - else if (txnMetadata.producerEpoch != producerEpoch) + else if (txnMetadata.producerEpoch != producerEpoch && !endTxnEpochBumped(txnMetadata, producerEpoch)) Left(Errors.PRODUCER_FENCED) else if (txnMetadata.pendingTransitionInProgress) Left(Errors.CONCURRENT_TRANSACTIONS) @@ -630,12 +687,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig, preSendResult match { case Left(err) => info(s"Aborting sending of transaction markers after appended $txnMarkerResult to transaction log and returning $err error to client for $transactionalId's EndTransaction request") - responseCallback(err) + responseCallback(err, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH) case Right((txnMetadata, newPreSendMetadata)) => // we can respond to the client immediately and continue to write the txn markers if // the log append was successful - responseCallback(Errors.NONE) + responseCallback(Errors.NONE, txnMetadata.producerId, txnMetadata.producerEpoch) txnMarkerChannelManager.addTxnMarkersToSend(coordinatorEpoch, txnMarkerResult, txnMetadata, newPreSendMetadata) } @@ -659,7 +716,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } } - responseCallback(error) + responseCallback(error, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH) } } @@ -669,11 +726,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } } + // When a client and server support V2, every endTransaction call bumps the producer epoch. When checking epoch, we want to + // check epoch + 1. Epoch bumps from PrepareEpochFence state are handled separately, so this method should not be used to check that case. + // Returns true if the transaction state epoch is the specified producer epoch + 1 and epoch bump on every transaction is expected. + private def endTxnEpochBumped(txnMetadata: TransactionMetadata, producerEpoch: Short): Boolean = { + !txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.clientTransactionVersion.supportsEpochBump() && + txnMetadata.producerEpoch == producerEpoch + 1 + } + def transactionTopicConfigs: Properties = txnManager.transactionTopicConfigs def partitionFor(transactionalId: String): Int = txnManager.partitionFor(transactionalId) - private def onEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = { + private def onEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = { error match { case Errors.NONE => info("Completed rollback of ongoing transaction for transactionalId " + @@ -721,6 +786,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, txnTransitMetadata.producerEpoch, TransactionResult.ABORT, isFromClient = false, + clientTransactionVersion = txnManager.transactionVersionLevel(), // Since this is not from client, use server TV onComplete(txnIdAndPidEpoch), RequestLocal.noCaching) } diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala index 2d4465d9a2c..623f88c7a38 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala @@ -25,6 +25,7 @@ import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} import org.apache.kafka.common.record.{Record, RecordBatch} import org.apache.kafka.common.{MessageFormatter, TopicPartition} import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue} +import org.apache.kafka.server.common.TransactionVersion import scala.collection.mutable import scala.jdk.CollectionConverters._ @@ -63,7 +64,7 @@ object TransactionLog { * @return value payload bytes */ private[transaction] def valueToBytes(txnMetadata: TxnTransitMetadata, - usesFlexibleRecords: Boolean): Array[Byte] = { + transactionVersionLevel: TransactionVersion): Array[Byte] = { if (txnMetadata.txnState == Empty && txnMetadata.topicPartitions.nonEmpty) throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata") @@ -78,9 +79,7 @@ object TransactionLog { // Serialize with version 0 (highest non-flexible version) until transaction.version 1 is enabled // which enables flexible fields in records. - val version: Short = - if (usesFlexibleRecords) 1 else 0 - MessageUtil.toVersionPrefixedBytes(version, + MessageUtil.toVersionPrefixedBytes(transactionVersionLevel.transactionLogValueVersion(), new TransactionLogValue() .setProducerId(txnMetadata.producerId) .setProducerEpoch(txnMetadata.producerEpoch) @@ -88,7 +87,8 @@ object TransactionLog { .setTransactionStatus(txnMetadata.txnState.id) .setTransactionLastUpdateTimestampMs(txnMetadata.txnLastUpdateTimestamp) .setTransactionStartTimestampMs(txnMetadata.txnStartTimestamp) - .setTransactionPartitions(transactionPartitions)) + .setTransactionPartitions(transactionPartitions) + .setClientTransactionVersion(txnMetadata.clientTransactionVersion.featureLevel())) } /** @@ -124,14 +124,16 @@ object TransactionLog { val transactionMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = value.producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = value.previousProducerId, + nextProducerId = value.nextProducerId, producerEpoch = value.producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = value.transactionTimeoutMs, state = TransactionState.fromId(value.transactionStatus), topicPartitions = mutable.Set.empty[TopicPartition], txnStartTimestamp = value.transactionStartTimestampMs, - txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs) + txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs, + clientTransactionVersion = TransactionVersion.fromFeatureLevel(value.clientTransactionVersion)) if (!transactionMetadata.state.equals(Empty)) value.transactionPartitions.forEach(partitionsSchema => diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala index 9fab77f30b6..ec3fb66e355 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala @@ -17,11 +17,11 @@ package kafka.coordinator.transaction import java.util.concurrent.locks.ReentrantLock - import kafka.utils.{CoreUtils, Logging, nonthreadsafe} import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.server.common.TransactionVersion import scala.collection.{immutable, mutable} @@ -163,70 +163,64 @@ private[transaction] case object PrepareEpochFence extends TransactionState { } private[transaction] object TransactionMetadata { - def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, timestamp: Long) = - new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp) - - def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, - state: TransactionState, timestamp: Long) = - new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp) - - def apply(transactionalId: String, producerId: Long, lastProducerId: Long, producerEpoch: Short, - lastProducerEpoch: Short, txnTimeoutMs: Int, state: TransactionState, timestamp: Long) = - new TransactionMetadata(transactionalId, producerId, lastProducerId, producerEpoch, lastProducerEpoch, - txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp) - def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1 } // this is a immutable object representing the target transition of the transaction metadata private[transaction] case class TxnTransitMetadata(producerId: Long, - lastProducerId: Long, + prevProducerId: Long, + nextProducerId: Long, producerEpoch: Short, lastProducerEpoch: Short, txnTimeoutMs: Int, txnState: TransactionState, topicPartitions: immutable.Set[TopicPartition], txnStartTimestamp: Long, - txnLastUpdateTimestamp: Long) { + txnLastUpdateTimestamp: Long, + clientTransactionVersion: TransactionVersion) { override def toString: String = { "TxnTransitMetadata(" + s"producerId=$producerId, " + - s"lastProducerId=$lastProducerId, " + + s"previousProducerId=$prevProducerId, " + + s"nextProducerId=$nextProducerId, " + s"producerEpoch=$producerEpoch, " + s"lastProducerEpoch=$lastProducerEpoch, " + s"txnTimeoutMs=$txnTimeoutMs, " + s"txnState=$txnState, " + s"topicPartitions=$topicPartitions, " + s"txnStartTimestamp=$txnStartTimestamp, " + - s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp)" + s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp, " + + s"clientTransactionVersion=$clientTransactionVersion)" } } /** * - * @param producerId producer id - * @param lastProducerId last producer id assigned to the producer - * @param producerEpoch current epoch of the producer - * @param lastProducerEpoch last epoch of the producer - * @param txnTimeoutMs timeout to be used to abort long running transactions - * @param state current state of the transaction - * @param topicPartitions current set of partitions that are part of this transaction - * @param txnStartTimestamp time the transaction was started, i.e., when first partition is added - * @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration + * @param producerId producer id + * @param previousProducerId producer id for the last committed transaction with this transactional ID + * @param nextProducerId Latest producer ID sent to the producer for the given transactional ID + * @param producerEpoch current epoch of the producer + * @param lastProducerEpoch last epoch of the producer + * @param txnTimeoutMs timeout to be used to abort long running transactions + * @param state current state of the transaction + * @param topicPartitions current set of partitions that are part of this transaction + * @param txnStartTimestamp time the transaction was started, i.e., when first partition is added + * @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration + * @param clientTransactionVersion TransactionVersion used by the client when the state was transitioned */ @nonthreadsafe private[transaction] class TransactionMetadata(val transactionalId: String, var producerId: Long, - var lastProducerId: Long, + var previousProducerId: Long, + var nextProducerId: Long, var producerEpoch: Short, var lastProducerEpoch: Short, var txnTimeoutMs: Int, var state: TransactionState, val topicPartitions: mutable.Set[TopicPartition], @volatile var txnStartTimestamp: Long = -1, - @volatile var txnLastUpdateTimestamp: Long) extends Logging { + @volatile var txnLastUpdateTimestamp: Long, + var clientTransactionVersion: TransactionVersion) extends Logging { // pending state is used to indicate the state that this transaction is going to // transit to, and for blocking future attempts to transit it again if it is not legal; @@ -256,8 +250,8 @@ private[transaction] class TransactionMetadata(val transactionalId: String, // this is visible for test only def prepareNoTransit(): TxnTransitMetadata = { // do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state - TxnTransitMetadata(producerId, lastProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.toSet, - txnStartTimestamp, txnLastUpdateTimestamp) + TxnTransitMetadata(producerId, previousProducerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.toSet, + txnStartTimestamp, txnLastUpdateTimestamp, TransactionVersion.TV_0) } def prepareFenceProducerEpoch(): TxnTransitMetadata = { @@ -335,9 +329,16 @@ private[transaction] class TransactionMetadata(val transactionalId: String, (topicPartitions ++ addedTopicPartitions).toSet, newTxnStartTimestamp, updateTimestamp) } - def prepareAbortOrCommit(newState: TransactionState, updateTimestamp: Long): TxnTransitMetadata = { - prepareTransitionTo(newState, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, topicPartitions.toSet, - txnStartTimestamp, updateTimestamp) + def prepareAbortOrCommit(newState: TransactionState, clientTransactionVersion: TransactionVersion, nextProducerId: Long, updateTimestamp: Long): TxnTransitMetadata = { + val (updatedProducerEpoch, updatedLastProducerEpoch) = if (clientTransactionVersion.supportsEpochBump()) { + // We already ensured that we do not overflow here. MAX_SHORT is the highest possible value. + ((producerEpoch + 1).toShort, producerEpoch) + } else { + (producerEpoch, lastProducerEpoch) + } + + prepareTransitionTo(newState, producerId, nextProducerId, updatedProducerEpoch, updatedLastProducerEpoch, txnTimeoutMs, topicPartitions.toSet, + txnStartTimestamp, updateTimestamp, clientTransactionVersion) } def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = { @@ -345,8 +346,15 @@ private[transaction] class TransactionMetadata(val transactionalId: String, // Since the state change was successfully written to the log, unset the flag for a failed epoch fence hasFailedEpochFence = false - prepareTransitionTo(newState, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, Set.empty[TopicPartition], - txnStartTimestamp, updateTimestamp) + val (updatedProducerId, updatedProducerEpoch) = + // If we overflowed on epoch bump, we have to set it as the producer ID now the marker has been written. + if (clientTransactionVersion.supportsEpochBump() && nextProducerId != RecordBatch.NO_PRODUCER_ID) { + (nextProducerId, 0.toShort) + } else { + (producerId, producerEpoch) + } + prepareTransitionTo(newState, updatedProducerId, RecordBatch.NO_PRODUCER_ID, updatedProducerEpoch, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition], + txnStartTimestamp, updateTimestamp, clientTransactionVersion) } def prepareDead(): TxnTransitMetadata = { @@ -367,37 +375,50 @@ private[transaction] class TransactionMetadata(val transactionalId: String, } } - private def prepareTransitionTo(newState: TransactionState, - newProducerId: Long, - newEpoch: Short, - newLastEpoch: Short, - newTxnTimeoutMs: Int, - newTopicPartitions: immutable.Set[TopicPartition], - newTxnStartTimestamp: Long, + private def prepareTransitionTo(updatedState: TransactionState, + updatedProducerId: Long, + updatedEpoch: Short, + updatedLastEpoch: Short, + updatedTxnTimeoutMs: Int, + updatedTopicPartitions: immutable.Set[TopicPartition], + updatedTxnStartTimestamp: Long, updateTimestamp: Long): TxnTransitMetadata = { + prepareTransitionTo(updatedState, updatedProducerId, RecordBatch.NO_PRODUCER_ID, updatedEpoch, updatedLastEpoch, updatedTxnTimeoutMs, updatedTopicPartitions, updatedTxnStartTimestamp, updateTimestamp, TransactionVersion.TV_0) + } + + private def prepareTransitionTo(updatedState: TransactionState, + updatedProducerId: Long, + nextProducerId: Long, + updatedEpoch: Short, + updatedLastEpoch: Short, + updatedTxnTimeoutMs: Int, + updatedTopicPartitions: immutable.Set[TopicPartition], + updatedTxnStartTimestamp: Long, + updateTimestamp: Long, + clientTransactionVersion: TransactionVersion): TxnTransitMetadata = { if (pendingState.isDefined) - throw new IllegalStateException(s"Preparing transaction state transition to $newState " + + throw new IllegalStateException(s"Preparing transaction state transition to $updatedState " + s"while it already a pending state ${pendingState.get}") - if (newProducerId < 0) - throw new IllegalArgumentException(s"Illegal new producer id $newProducerId") + if (updatedProducerId < 0) + throw new IllegalArgumentException(s"Illegal new producer id $updatedProducerId") // The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata // is created for the first time and it could stay like this until transitioning // to Dead. - if (newState != Dead && newEpoch < 0) - throw new IllegalArgumentException(s"Illegal new producer epoch $newEpoch") + if (updatedState != Dead && updatedEpoch < 0) + throw new IllegalArgumentException(s"Illegal new producer epoch $updatedEpoch") // check that the new state transition is valid and update the pending state if necessary - if (newState.validPreviousStates.contains(state)) { - val transitMetadata = TxnTransitMetadata(newProducerId, producerId, newEpoch, newLastEpoch, newTxnTimeoutMs, newState, - newTopicPartitions, newTxnStartTimestamp, updateTimestamp) + if (updatedState.validPreviousStates.contains(state)) { + val transitMetadata = TxnTransitMetadata(updatedProducerId, producerId, nextProducerId, updatedEpoch, updatedLastEpoch, updatedTxnTimeoutMs, updatedState, + updatedTopicPartitions, updatedTxnStartTimestamp, updateTimestamp, clientTransactionVersion) debug(s"TransactionalId $transactionalId prepare transition from $state to $transitMetadata") - pendingState = Some(newState) + pendingState = Some(updatedState) transitMetadata } else { - throw new IllegalStateException(s"Preparing transaction state transition to $newState failed since the target state" + - s" $newState is not a valid previous state of the current state $state") + throw new IllegalStateException(s"Preparing transaction state transition to $updatedState failed since the target state" + + s" $updatedState is not a valid previous state of the current state $state") } } @@ -436,7 +457,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, producerEpoch = transitMetadata.producerEpoch lastProducerEpoch = transitMetadata.lastProducerEpoch producerId = transitMetadata.producerId - lastProducerId = transitMetadata.lastProducerId + previousProducerId = transitMetadata.prevProducerId } case Ongoing => // from addPartitions @@ -457,6 +478,10 @@ private[transaction] class TransactionMetadata(val transactionalId: String, txnStartTimestamp != transitMetadata.txnStartTimestamp) { throwStateTransitionFailure(transitMetadata) + } else if (transitMetadata.clientTransactionVersion.supportsEpochBump()) { + producerEpoch = transitMetadata.producerEpoch + lastProducerEpoch = transitMetadata.lastProducerEpoch + nextProducerId = transitMetadata.nextProducerId } case CompleteAbort | CompleteCommit => // from write markers @@ -468,6 +493,13 @@ private[transaction] class TransactionMetadata(val transactionalId: String, } else { txnStartTimestamp = transitMetadata.txnStartTimestamp topicPartitions.clear() + if (transitMetadata.clientTransactionVersion.supportsEpochBump()) { + producerEpoch = transitMetadata.producerEpoch + lastProducerEpoch = transitMetadata.lastProducerEpoch + previousProducerId = transitMetadata.prevProducerId + producerId = transitMetadata.producerId + nextProducerId = transitMetadata.nextProducerId + } } case PrepareEpochFence => @@ -487,6 +519,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, } debug(s"TransactionalId $transactionalId complete transition from $state to $transitMetadata") + clientTransactionVersion = transitMetadata.clientTransactionVersion txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp pendingState = None state = toState @@ -494,8 +527,14 @@ private[transaction] class TransactionMetadata(val transactionalId: String, } private def validProducerEpoch(transitMetadata: TxnTransitMetadata): Boolean = { - val transitEpoch = transitMetadata.producerEpoch - val transitProducerId = transitMetadata.producerId + val isAtLeastTransactionsV2 = transitMetadata.clientTransactionVersion.supportsEpochBump() + val isOverflowComplete = isAtLeastTransactionsV2 && (transitMetadata.txnState == CompleteCommit || transitMetadata.txnState == CompleteAbort) && transitMetadata.producerEpoch == 0 + val transitEpoch = + if (isOverflowComplete || (isAtLeastTransactionsV2 && (transitMetadata.txnState == PrepareCommit || transitMetadata.txnState == PrepareAbort))) + transitMetadata.lastProducerEpoch + else + transitMetadata.producerEpoch + val transitProducerId = if (isOverflowComplete) transitMetadata.prevProducerId else transitMetadata.producerId transitEpoch == producerEpoch && transitProducerId == producerId } @@ -518,6 +557,8 @@ private[transaction] class TransactionMetadata(val transactionalId: String, "TransactionMetadata(" + s"transactionalId=$transactionalId, " + s"producerId=$producerId, " + + s"previousProducerId=$previousProducerId, " + s"nextProducerId=$nextProducerId, " s"producerEpoch=$producerEpoch, " + s"txnTimeoutMs=$txnTimeoutMs, " + s"state=$state, " + diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala index d66bab06f68..01912724587 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala @@ -101,8 +101,10 @@ class TransactionStateManager(brokerId: Int, TransactionStateManagerConfig.METRICS_GROUP, "The avg time it took to load the partitions in the last 30sec"), new Avg()) - private[transaction] def usesFlexibleRecords(): Boolean = { - metadataCache.features().finalizedFeatures().getOrDefault(TransactionVersion.FEATURE_NAME, 0.toShort) > 0 + private[transaction] def transactionVersionLevel(): TransactionVersion = { + val version = TransactionVersion.fromFeatureLevel(metadataCache.features().finalizedFeatures().getOrDefault( + TransactionVersion.FEATURE_NAME, 0.toShort)) + version } // visible for testing only @@ -624,7 +626,7 @@ class TransactionStateManager(brokerId: Int, // generate the message for this transaction metadata val keyBytes = TransactionLog.keyToBytes(transactionalId) - val valueBytes = TransactionLog.valueToBytes(newMetadata, usesFlexibleRecords()) + val valueBytes = TransactionLog.valueToBytes(newMetadata, transactionVersionLevel()) val timestamp = time.milliseconds() val records = MemoryRecords.withRecords(TransactionLog.EnforcedCompression, new SimpleRecord(timestamp, keyBytes, valueBytes)) diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index 2586c562cf5..840f163132a 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -73,7 +73,7 @@ import org.apache.kafka.coordinator.group.{Group, GroupCoordinator} import org.apache.kafka.coordinator.share.ShareCoordinator import org.apache.kafka.server.ClientMetricsManager import org.apache.kafka.server.authorizer._ -import org.apache.kafka.server.common.{GroupVersion, MetadataVersion, RequestLocal} +import org.apache.kafka.server.common.{GroupVersion, MetadataVersion, RequestLocal, TransactionVersion} import org.apache.kafka.server.common.MetadataVersion.{IBP_0_11_0_IV0, IBP_2_3_IV0} import org.apache.kafka.server.record.BrokerCompressionType import org.apache.kafka.server.share.context.ShareFetchContext @@ -2299,7 +2299,7 @@ class KafkaApis(val requestChannel: RequestChannel, val transactionalId = endTxnRequest.data.transactionalId if (authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) { - def sendResponseCallback(error: Errors): Unit = { + def sendResponseCallback(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = { def createResponse(requestThrottleMs: Int): AbstractResponse = { val finalError = if (endTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) { @@ -2311,6 +2311,8 @@ class KafkaApis(val requestChannel: RequestChannel, } val responseBody = new EndTxnResponse(new EndTxnResponseData() .setErrorCode(finalError.code) + .setProducerId(newProducerId) + .setProducerEpoch(newProducerEpoch) .setThrottleTimeMs(requestThrottleMs)) trace(s"Completed ${endTxnRequest.data.transactionalId}'s EndTxnRequest " + s"with committed: ${endTxnRequest.data.committed}, " + @@ -2320,10 +2322,14 @@ class KafkaApis(val requestChannel: RequestChannel, requestHelper.sendResponseMaybeThrottle(request, createResponse) } + // If the request is version 4, we know the client supports transaction version 2. + val clientTransactionVersion = if (endTxnRequest.version() > 4) TransactionVersion.TV_2 else TransactionVersion.TV_0 + txnCoordinator.handleEndTransaction(endTxnRequest.data.transactionalId, endTxnRequest.data.producerId, endTxnRequest.data.producerEpoch, endTxnRequest.result(), + clientTransactionVersion, sendResponseCallback, requestLocal) } else diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala index 12531d289db..323dbe958e1 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala @@ -464,10 +464,10 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren addPartitionsOp.awaitAndVerify(txn) val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn")) - txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2)) txnMetadata.state = PrepareCommit - txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2)) prepareTxnLog(partitionId) } @@ -506,13 +506,15 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren private def prepareExhaustedEpochTxnMetadata(txn: Transaction): TransactionMetadata = { new TransactionMetadata(transactionalId = txn.transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = (Short.MaxValue - 1).toShort, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 60000, state = Empty, topicPartitions = collection.mutable.Set.empty[TopicPartition], - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TransactionVersion.TV_0) } abstract class TxnOperation[R] extends Operation { @@ -562,7 +564,8 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren txnMetadata.producerId, txnMetadata.producerEpoch, transactionResult(txn), - resultCallback, + TransactionVersion.TV_2, + (r, _, _) => resultCallback(r), RequestLocal.withThreadConfinedCaching) } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala index 584a49bbc2c..f32a924fb68 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala @@ -23,9 +23,13 @@ import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult} import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch} import org.apache.kafka.coordinator.transaction.TransactionStateManagerConfig +import org.apache.kafka.server.common.TransactionVersion +import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.apache.kafka.server.util.MockScheduler import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource import org.mockito.{ArgumentCaptor, ArgumentMatchers} import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mockito.{mock, times, verify, when} @@ -51,6 +55,7 @@ class TransactionCoordinatorTest { private val producerId = 10L private val producerEpoch: Short = 1 private val txnTimeoutMs = 1 + private val producerId2 = 11L private val partitions = mutable.Set[TopicPartition](new TopicPartition("topic1", 0)) private val scheduler = new MockScheduler(time) @@ -66,6 +71,8 @@ class TransactionCoordinatorTest { val transactionStatePartitionCount = 1 var result: InitProducerIdResult = _ var error: Errors = Errors.NONE + var newProducerId: Long = RecordBatch.NO_PRODUCER_ID + var newEpoch: Short = RecordBatch.NO_PRODUCER_EPOCH private def mockPidGenerator(): Unit = { when(pidGenerator.generateProducerId()).thenAnswer(_ => { @@ -155,8 +162,8 @@ class TransactionCoordinatorTest { def shouldGenerateNewProducerIdIfEpochsExhausted(): Unit = { initPidGenericMocks(transactionalId) - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort, + (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -245,7 +252,8 @@ class TransactionCoordinatorTest { errors = AddPartitionsToTxnResponse.errorsForTransaction(result.topicResults()).asScala.toMap } // If producer ID is not the same, return INVALID_PRODUCER_ID_MAPPING - val wrongPidTxnMetadata = new TransactionMetadata(transactionalId, 1, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0) + val wrongPidTxnMetadata = new TransactionMetadata(transactionalId, 1, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, wrongPidTxnMetadata)))) @@ -253,10 +261,10 @@ class TransactionCoordinatorTest { errors.foreach { case (_, error) => assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) } - // If producer epoch is not equal, return PRODUCER_FENCED - val oldEpochTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0) + val oldEpochTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, oldEpochTxnMetadata)))) @@ -266,7 +274,8 @@ class TransactionCoordinatorTest { } // If the txn state is Prepare or AbortCommit, we return CONCURRENT_TRANSACTIONS - val emptyTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0) + val emptyTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, emptyTxnMetadata)))) @@ -276,7 +285,8 @@ class TransactionCoordinatorTest { } // Pending state does not matter, we will just check if the partitions are in the txnMetadata. - val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, mutable.Set.empty, 0, 0) + val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, mutable.Set.empty, 0, 0, TV_0) ongoingTxnMetadata.pendingState = Some(CompleteCommit) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, ongoingTxnMetadata)))) @@ -298,9 +308,11 @@ class TransactionCoordinatorTest { } def validateConcurrentTransactions(state: TransactionState): Unit = { + // Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort. when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0))))) + new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0, TV_2))))) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) @@ -308,9 +320,11 @@ class TransactionCoordinatorTest { @Test def shouldRespondWithProducerFencedOnAddPartitionsWhenEpochsAreDifferent(): Unit = { + // Since the clientTransactionVersion doesn't matter, use 2 since the state is PrepareCommit. when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 0, 0, 10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0))))) + new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0, TV_2))))) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) assertEquals(Errors.PRODUCER_FENCED, error) @@ -318,27 +332,30 @@ class TransactionCoordinatorTest { @Test def shouldAppendNewMetadataToLogOnAddPartitionsWhenPartitionsAdded(): Unit = { - validateSuccessfulAddPartitions(Empty) + validateSuccessfulAddPartitions(Empty, 0) } @Test def shouldRespondWithSuccessOnAddPartitionsWhenStateIsOngoing(): Unit = { - validateSuccessfulAddPartitions(Ongoing) + validateSuccessfulAddPartitions(Ongoing, 0) } - @Test - def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(): Unit = { - validateSuccessfulAddPartitions(CompleteCommit) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(clientTransactionVersion: Short): Unit = { + validateSuccessfulAddPartitions(CompleteCommit, clientTransactionVersion) } - @Test - def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(): Unit = { - validateSuccessfulAddPartitions(CompleteAbort) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(clientTransactionVersion: Short): Unit = { + validateSuccessfulAddPartitions(CompleteAbort, clientTransactionVersion) } - def validateSuccessfulAddPartitions(previousState: TransactionState): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, - txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds()) + def validateSuccessfulAddPartitions(previousState: TransactionState, transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -360,7 +377,8 @@ class TransactionCoordinatorTest { def shouldRespondWithErrorsNoneOnAddPartitionWhenNoErrorsAndPartitionsTheSame(): Unit = { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0))))) + new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0))))) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) assertEquals(Errors.NONE, error) @@ -376,7 +394,8 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0))))) + new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0, TV_0))))) coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, partitions, verifyPartitionsInTxnCallback) errors.foreach { case (_, error) => @@ -394,7 +413,8 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0))))) + new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0))))) val extraPartitions = partitions ++ Set(new TopicPartition("topic2", 0)) @@ -404,107 +424,227 @@ class TransactionCoordinatorTest { verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(None)) - coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 10, 10, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + new TransactionMetadata(transactionalId, 10, 10, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) - coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, + (producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) - coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.PRODUCER_FENCED, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(): Unit ={ + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, + (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.NONE, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(): Unit ={ - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.NONE, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_TXN_STATE, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_TXN_STATE, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnConcurrentTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnConcurrentTransactionsOnEndTxnRequestWhenStatusIsPrepareCommit(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) - .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) - .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, 1, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_TXN_STATE, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } @Test - def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(): Unit = { - mockPrepare(PrepareCommit) + def shouldReturnWhenTransactionVersionDowngraded(): Unit = { + // State was written when transactions V2 + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) - coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, errorsCallback) + // Return CONCURRENT_TRANSACTIONS as the transaction is still completing + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_0, endTxnCallback) + assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) + assertEquals(RecordBatch.NO_PRODUCER_ID, newProducerId) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, newEpoch) + verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) + + // Recognize the retry and return NONE + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_0, endTxnCallback) + assertEquals(Errors.NONE, error) + assertEquals(producerId, newProducerId) + assertEquals((producerEpoch + 1).toShort, newEpoch) // epoch is bumped since we started as V2 + verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId)) + } + + @Test + def shouldReturnCorrectlyWhenTransactionVersionUpgraded(): Unit = { + // State was written when transactions V0 + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) + + // Transactions V0 throws the concurrent transactions error here. + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) + verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) + + // When the transaction is completed, return and do not throw an error. + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.NONE, error) + assertEquals(producerId, newProducerId) + assertEquals(producerEpoch, newEpoch) // epoch is not bumped since this started as V1 + verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId)) + } + + @Test + def shouldReturnInvalidTxnRequestOnEndTxnV2IfNotEndTxnV2Retry(): Unit = { + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + + // If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED. + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.PRODUCER_FENCED, error) + verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) + + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + + // If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED. + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.PRODUCER_FENCED, error) + verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId)) + } + + @Test + def shouldReturnOkOnEndTxnV2IfEndTxnV2RetryEpochOverflow(): Unit = { + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + + // Return CONCURRENT_TRANSACTIONS while transaction is still completing + coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) + verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) + + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId2, producerId, + RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + + coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.NONE, error) + assertNotEquals(RecordBatch.NO_PRODUCER_ID, newProducerId) + assertNotEquals(producerId, newProducerId) + assertEquals(0, newEpoch) + verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId)) + } + + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + mockPrepare(PrepareCommit, clientTransactionVersion) + + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).appendTransactionToLog( ArgumentMatchers.eq(transactionalId), @@ -515,11 +655,13 @@ class TransactionCoordinatorTest { any()) } - @Test - def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(): Unit = { - mockPrepare(PrepareAbort) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + mockPrepare(PrepareAbort, clientTransactionVersion) - coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).appendTransactionToLog( ArgumentMatchers.eq(transactionalId), @@ -530,90 +672,106 @@ class TransactionCoordinatorTest { any()) } - @Test - def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(): Unit = { - coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, errorsCallback) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_REQUEST, error) } - @Test - def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Left(Errors.NOT_COORDINATOR)) - coordinator.handleEndTransaction("", 0, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction("", 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_REQUEST, error) } - @Test - def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Left(Errors.NOT_COORDINATOR)) - coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.NOT_COORDINATOR, error) } - @Test - def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)) - coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, error) } - @Test - def shouldReturnInvalidEpochOnEndTxnWhenEpochIsLarger(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnInvalidEpochOnEndTxnWhenEpochIsLarger(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) val serverProducerEpoch = 1.toShort - verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch + 1).toShort) + verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch + 1).toShort, clientTransactionVersion) } - @Test - def shouldReturnInvalidEpochOnEndTxnWhenEpochIsSmaller(): Unit = { - val serverProducerEpoch = 1.toShort - verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch - 1).toShort) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnInvalidEpochOnEndTxnWhenEpochIsSmaller(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + val serverProducerEpoch = 2.toShort + // Since we bump epoch in transactionV2 the request should be one producer ID older + verifyEndTxnEpoch(serverProducerEpoch, requestEpoch(clientTransactionVersion), clientTransactionVersion) } - private def verifyEndTxnEpoch(metadataEpoch: Short, requestEpoch: Short): Unit = { + private def verifyEndTxnEpoch(metadataEpoch: Short, requestEpoch: Short, clientTransactionVersion: TransactionVersion): Unit = { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, producerId, producerId, metadataEpoch, 0, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, metadataEpoch, 1, + 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) - coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.PRODUCER_FENCED, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } @Test def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingEmptyTransaction(): Unit = { - validateIncrementEpochAndUpdateMetadata(Empty) + validateIncrementEpochAndUpdateMetadata(Empty, 0) + } + + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(clientTransactionVersion: Short): Unit = { + validateIncrementEpochAndUpdateMetadata(CompleteAbort, clientTransactionVersion) + } + + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(clientTransactionVersion: Short): Unit = { + validateIncrementEpochAndUpdateMetadata(CompleteCommit, clientTransactionVersion) } @Test - def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(): Unit = { - validateIncrementEpochAndUpdateMetadata(CompleteAbort) - } - - @Test - def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(): Unit = { - validateIncrementEpochAndUpdateMetadata(CompleteCommit) - } - - @Test - def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareCommitState(): Unit ={ + def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareCommitState(): Unit = { validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareCommit) } @Test - def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareAbortState(): Unit ={ + def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareAbortState(): Unit = { validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareAbort) } @Test def shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -621,8 +779,10 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) + + val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), @@ -640,7 +800,7 @@ class TransactionCoordinatorTest { verify(transactionManager).appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), - ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())), + ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())), any(), any(), any()) @@ -648,14 +808,14 @@ class TransactionCoordinatorTest { @Test def shouldFailToAbortTransactionOnHandleInitPidWhenProducerEpochIsSmaller(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) - val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 2).toShort, - (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + (producerEpoch + 2).toShort, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -671,8 +831,8 @@ class TransactionCoordinatorTest { @Test def shouldNotRepeatedlyBumpEpochDueToInitPidDuringOngoingTxnIfAppendToLogFails(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -683,9 +843,11 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenAnswer(_ => Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) - val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds()) + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) + + val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds()) when(transactionManager.appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), @@ -740,33 +902,38 @@ class TransactionCoordinatorTest { @Test def shouldUseLastEpochToFenceWhenEpochsAreExhausted(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) - val postFenceTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, Short.MaxValue, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds()) + val postFenceTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + Short.MaxValue, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, postFenceTxnMetadata)))) + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) + + // InitProducerId uses FenceProducerEpoch so clientTransactionVersion is 0. when(transactionManager.appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(TxnTransitMetadata( producerId = producerId, - lastProducerId = producerId, + prevProducerId = producerId, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = Short.MaxValue, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = txnTimeoutMs, txnState = PrepareAbort, topicPartitions = partitions.toSet, txnStartTimestamp = time.milliseconds(), - txnLastUpdateTimestamp = time.milliseconds())), + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0)), capturedErrorsCallback.capture(), any(), any()) @@ -783,14 +950,16 @@ class TransactionCoordinatorTest { ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(TxnTransitMetadata( producerId = producerId, - lastProducerId = producerId, + prevProducerId = producerId, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = Short.MaxValue, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = txnTimeoutMs, txnState = PrepareAbort, topicPartitions = partitions.toSet, txnStartTimestamp = time.milliseconds(), - txnLastUpdateTimestamp = time.milliseconds())), + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0)), any(), any(), any()) @@ -800,8 +969,8 @@ class TransactionCoordinatorTest { def testInitProducerIdWithNoLastProducerData(): Unit = { // If the metadata doesn't include the previous producer data (for example, if it was written to the log by a broker // on an old version), the retry case should fail - val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -817,8 +986,8 @@ class TransactionCoordinatorTest { @Test def testFenceProducerWhenMappingExistsWithDifferentProducerId(): Unit = { // Existing transaction ID maps to new producer ID - val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId, producerEpoch, - (producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId, + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -835,8 +1004,8 @@ class TransactionCoordinatorTest { def testInitProducerIdWithCurrentEpochProvided(): Unit = { mockPidGenerator() - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, 10, - 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -870,8 +1039,8 @@ class TransactionCoordinatorTest { def testInitProducerIdStaleCurrentEpochProvided(): Unit = { mockPidGenerator() - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, 10, - 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -906,8 +1075,8 @@ class TransactionCoordinatorTest { @Test def testRetryInitProducerIdAfterProducerIdRotation(): Unit = { // Existing transaction ID maps to new producer ID - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(pidGenerator.generateProducerId()) .thenReturn(Success(producerId + 1)) @@ -928,7 +1097,7 @@ class TransactionCoordinatorTest { capturedErrorsCallback.getValue.apply(Errors.NONE) txnMetadata.pendingState = None txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId - txnMetadata.lastProducerId = capturedTxnTransitMetadata.getValue.lastProducerId + txnMetadata.previousProducerId = capturedTxnTransitMetadata.getValue.prevProducerId txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch }) @@ -947,8 +1116,8 @@ class TransactionCoordinatorTest { @Test def testInitProducerIdWithInvalidEpochAfterProducerIdRotation(): Unit = { // Existing transaction ID maps to new producer ID - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(pidGenerator.generateProducerId()) .thenReturn(Success(producerId + 1)) @@ -969,7 +1138,7 @@ class TransactionCoordinatorTest { capturedErrorsCallback.getValue.apply(Errors.NONE) txnMetadata.pendingState = None txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId - txnMetadata.lastProducerId = capturedTxnTransitMetadata.getValue.lastProducerId + txnMetadata.previousProducerId = capturedTxnTransitMetadata.getValue.prevProducerId txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch }) @@ -995,16 +1164,20 @@ class TransactionCoordinatorTest { @Test def shouldAbortExpiredTransactionsInOngoingStateAndBumpEpoch(): Unit = { val now = time.milliseconds() - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - val expectedTransition = TxnTransitMetadata(producerId, producerId, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs, PrepareAbort, partitions.toSet, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT) + // Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0. + val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.toSet, now, + now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0) + + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) when(transactionManager.appendTransactionToLog(ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), @@ -1030,20 +1203,22 @@ class TransactionCoordinatorTest { @Test def shouldNotAcceptSmallerEpochDuringTransactionExpiration(): Unit = { val now = time.milliseconds() - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 2).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) + + val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 2).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, bumpedTxnMetadata)))) - def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = { + def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors, producerId: Long, producerEpoch: Short): Unit = { assertEquals(Errors.PRODUCER_FENCED, error) } coordinator.abortTimedOutTransactions(checkOnEndTransactionComplete) @@ -1054,9 +1229,9 @@ class TransactionCoordinatorTest { @Test def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = { - val metadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) - metadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds()) + val metadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + metadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds()) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) @@ -1073,22 +1248,26 @@ class TransactionCoordinatorTest { @Test def shouldNotBumpEpochWhenAbortingExpiredTransactionIfAppendToLogFails(): Unit = { val now = time.milliseconds() - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) - val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadataAfterAppendFailure)))) + // Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0. val bumpedEpoch = (producerEpoch + 1).toShort - val expectedTransition = TxnTransitMetadata(producerId, producerId, bumpedEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, - PrepareAbort, partitions.toSet, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT) + val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, bumpedEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.toSet, now, + now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0) + + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) when(transactionManager.appendTransactionToLog(ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), @@ -1117,9 +1296,9 @@ class TransactionCoordinatorTest { @Test def shouldNotBumpEpochWithPendingTransaction(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) - txnMetadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds()) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -1146,9 +1325,9 @@ class TransactionCoordinatorTest { def testDescribeTransactionsWithExpiringTransactionalId(): Unit = { coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false) - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Dead, mutable.Set.empty, time.milliseconds(), - time.milliseconds()) + time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -1172,8 +1351,8 @@ class TransactionCoordinatorTest { @Test def testDescribeTransactions(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -1199,7 +1378,9 @@ class TransactionCoordinatorTest { when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) - val metadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0) + // Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort. + val metadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0, TV_2) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))) @@ -1208,14 +1389,16 @@ class TransactionCoordinatorTest { assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result) } - private def validateIncrementEpochAndUpdateMetadata(state: TransactionState): Unit = { + private def validateIncrementEpochAndUpdateMetadata(state: TransactionState, transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(pidGenerator.generateProducerId()) .thenReturn(Success(producerId)) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) - val metadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds()) + val metadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, + producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))) @@ -1242,13 +1425,13 @@ class TransactionCoordinatorTest { assertEquals(producerId, metadata.producerId) } - private def mockPrepare(transactionState: TransactionState, runCallback: Boolean = false): TransactionMetadata = { + private def mockPrepare(transactionState: TransactionState, clientTransactionVersion: TransactionVersion, runCallback: Boolean = false): TransactionMetadata = { val now = time.milliseconds() - val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs, Ongoing, partitions, now, now) + val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, + producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) - val transition = TxnTransitMetadata(producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, - transactionState, partitions.toSet, now, now) + val transition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions.toSet, now, now, clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata)))) @@ -1264,8 +1447,8 @@ class TransactionCoordinatorTest { capturedErrorsCallback.getValue.apply(Errors.NONE) }) - new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds()) + new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds(), clientTransactionVersion) } def initProducerIdMockCallback(ret: InitProducerIdResult): Unit = { @@ -1275,4 +1458,17 @@ class TransactionCoordinatorTest { def errorsCallback(ret: Errors): Unit = { error = ret } + + def endTxnCallback(ret: Errors, producerId: Long, epoch: Short): Unit = { + error = ret + newProducerId = producerId + newEpoch = epoch + } + + def requestEpoch(clientTransactionVersion: TransactionVersion): Short = { + if (clientTransactionVersion.supportsEpochBump()) + (producerEpoch - 1).toShort + else + producerEpoch + } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala index 3fdf42c6b33..fd5f1e37a65 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala @@ -23,8 +23,9 @@ import org.apache.kafka.common.compress.Compression import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type} -import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord} +import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, SimpleRecord} import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue} +import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue} import org.junit.jupiter.api.Test @@ -48,10 +49,11 @@ class TransactionLogTest { val transactionalId = "transactionalId" val producerId = 23423L - val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs, 0) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) txnMetadata.addPartitions(topicPartitions) - assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)) + assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)) } @Test @@ -72,14 +74,14 @@ class TransactionLogTest { // generate transaction log messages val txnRecords = pidMappings.map { case (transactionalId, producerId) => - val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs, - transactionStates(producerId), 0) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) if (!txnMetadata.state.equals(Empty)) txnMetadata.addPartitions(topicPartitions) val keyBytes = TransactionLog.keyToBytes(transactionalId) - val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true) + val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2) new SimpleRecord(keyBytes, valueBytes) }.toSeq @@ -114,12 +116,12 @@ class TransactionLogTest { val producerId = 1334L val topicPartition = new TopicPartition("topic", 0) - val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, - transactionTimeoutMs, Ongoing, 0) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) txnMetadata.addPartitions(Set(topicPartition)) val keyBytes = TransactionLog.keyToBytes(transactionalId) - val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true) + val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2) val transactionMetadataRecord = TestUtils.records(Seq( new SimpleRecord(keyBytes, valueBytes) )).records.asScala.head @@ -144,15 +146,15 @@ class TransactionLogTest { @Test def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = { - val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500) - val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, false)) + val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500, TV_0) + val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0)) assertEquals(0, txnLogValueBuffer.getShort) } @Test def testSerializeTransactionLogValueToFlexibleVersion(): Unit = { - val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500) - val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, true)) + val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500, TV_2) + val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2)) assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort) } @@ -194,8 +196,8 @@ class TransactionLogTest { new Field("topic", Type.COMPACT_STRING, ""), new Field("partition_ids", new CompactArrayOf(Type.INT32), ""), TaggedFieldsSection.of( - Int.box(0), new Field("partition_foo", Type.STRING, ""), - Int.box(1), new Field("partition_foo", Type.INT32, "") + Int.box(100), new Field("partition_foo", Type.STRING, ""), + Int.box(101), new Field("partition_foo", Type.INT32, "") ) ) @@ -204,8 +206,8 @@ class TransactionLogTest { txnPartitions.set("topic", "topic") txnPartitions.set("partition_ids", Array(Integer.valueOf(1))) val txnPartitionsTaggedFields = new java.util.TreeMap[Integer, Any]() - txnPartitionsTaggedFields.put(0, "foo") - txnPartitionsTaggedFields.put(1, 4000) + txnPartitionsTaggedFields.put(100, "foo") + txnPartitionsTaggedFields.put(101, 4000) txnPartitions.set("_tagged_fields", txnPartitionsTaggedFields) // Copy of TransactionLogValue.SCHEMA_1 with a few @@ -219,8 +221,8 @@ class TransactionLogTest { new Field("transaction_last_update_timestamp_ms", Type.INT64, ""), new Field("transaction_start_timestamp_ms", Type.INT64, ""), TaggedFieldsSection.of( - Int.box(0), new Field("txn_foo", Type.STRING, ""), - Int.box(1), new Field("txn_bar", Type.INT32, "") + Int.box(100), new Field("txn_foo", Type.STRING, ""), + Int.box(101), new Field("txn_bar", Type.INT32, "") ) ) @@ -234,8 +236,8 @@ class TransactionLogTest { transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L) transactionLogValue.set("transaction_start_timestamp_ms", 3000L) val txnLogValueTaggedFields = new java.util.TreeMap[Integer, Any]() - txnLogValueTaggedFields.put(0, "foo") - txnLogValueTaggedFields.put(1, 4000) + txnLogValueTaggedFields.put(100, "foo") + txnLogValueTaggedFields.put(101, 4000) transactionLogValue.set("_tagged_fields", txnLogValueTaggedFields) // Prepare the buffer. @@ -249,8 +251,8 @@ class TransactionLogTest { // fields were read but ignored. buffer.getShort() // Skip version. val value = new TransactionLogValue(new ByteBufferAccessor(buffer), 1.toShort) - assertEquals(Seq(0, 1), value.unknownTaggedFields().asScala.map(_.tag)) - assertEquals(Seq(0, 1), value.transactionPartitions().get(0).unknownTaggedFields().asScala.map(_.tag)) + assertEquals(Seq(100, 101), value.unknownTaggedFields().asScala.map(_.tag)) + assertEquals(Seq(100, 101), value.transactionPartitions().get(0).unknownTaggedFields().asScala.map(_.tag)) // Read the buffer with readTxnRecordValue. buffer.rewind() diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala index d40932f3226..3f1c4c67a06 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala @@ -28,7 +28,7 @@ import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} import org.apache.kafka.common.utils.MockTime import org.apache.kafka.common.{Node, TopicPartition} -import org.apache.kafka.server.common.MetadataVersion +import org.apache.kafka.server.common.{MetadataVersion, TransactionVersion} import org.apache.kafka.server.metrics.{KafkaMetricsGroup, KafkaYammerMetrics} import org.apache.kafka.server.util.RequestAndCompletionHandler import org.junit.jupiter.api.Assertions._ @@ -63,10 +63,10 @@ class TransactionMarkerChannelManagerTest { private val coordinatorEpoch2 = 1 private val txnTimeoutMs = 0 private val txnResult = TransactionResult.COMMIT - private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, producerEpoch, lastProducerEpoch, - txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L) - private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, producerEpoch, lastProducerEpoch, - txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L) + private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, RecordBatch.NO_PRODUCER_ID, + producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L, TransactionVersion.TV_2) + private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, RecordBatch.NO_PRODUCER_ID, + producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L, TransactionVersion.TV_2) private val capturedErrorsCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit]) private val time = new MockTime diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala index 1004915f46c..72ffa5629c0 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala @@ -23,6 +23,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.{ApiKeys, Errors} import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} +import org.apache.kafka.server.common.TransactionVersion import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test import org.mockito.ArgumentMatchers @@ -42,8 +43,8 @@ class TransactionMarkerRequestCompletionHandlerTest { private val coordinatorEpoch = 0 private val txnResult = TransactionResult.COMMIT private val topicPartition = new TopicPartition("topic1", 0) - private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, lastProducerEpoch, - txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L) + private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L, TransactionVersion.TV_2) private val pendingCompleteTxnAndMarkers = asList( PendingCompleteTxnAndMarkerEntry( PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, txnMetadata.prepareComplete(42)), diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala index 5d40c90a4b3..4c56e639d34 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala @@ -19,9 +19,13 @@ package kafka.coordinator.transaction import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.server.common.TransactionVersion +import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.apache.kafka.server.util.MockTime import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource import scala.collection.mutable @@ -38,13 +42,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None) txnMetadata.completeTransitionTo(transitMetadata) @@ -60,13 +66,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None) txnMetadata.completeTransitionTo(transitMetadata) @@ -82,13 +90,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareIncrementProducerEpoch(30000, @@ -101,14 +111,16 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) // let new time be smaller val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Option(producerEpoch), @@ -127,14 +139,16 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) // let new time be smaller val transitMetadata = txnMetadata.prepareProducerIdRotation(producerId + 1, 30000, time.milliseconds() - 1, recordLastEpoch = true) @@ -152,14 +166,16 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, txnStartTimestamp = time.milliseconds(), - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) // let new time be smaller; when transiting from Empty the start time would be updated to the update-time var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1) @@ -188,17 +204,19 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Ongoing, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) // let new time be smaller - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds() - 1) + val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(PrepareCommit, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) @@ -214,17 +232,19 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Ongoing, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) // let new time be smaller - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds() - 1) + val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(PrepareAbort, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) @@ -234,53 +254,65 @@ class TransactionMetadataTest { assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) } - @Test - def testTolerateTimeShiftDuringCompleteCommit(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def testTolerateTimeShiftDuringCompleteCommit(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) val producerEpoch: Short = 1 val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = PrepareCommit, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = clientTransactionVersion) // let new time be smaller val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1) txnMetadata.completeTransitionTo(transitMetadata) + + val lastEpoch = if (clientTransactionVersion.supportsEpochBump()) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH assertEquals(CompleteCommit, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) - assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + assertEquals(lastEpoch, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch) assertEquals(1L, txnMetadata.txnStartTimestamp) assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) } - @Test - def testTolerateTimeShiftDuringCompleteAbort(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def testTolerateTimeShiftDuringCompleteAbort(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) val producerEpoch: Short = 1 val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = PrepareAbort, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = clientTransactionVersion) // let new time be smaller val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1) txnMetadata.completeTransitionTo(transitMetadata) + + val lastEpoch = if (clientTransactionVersion.supportsEpochBump()) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH assertEquals(CompleteAbort, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) - assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + assertEquals(lastEpoch, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch) assertEquals(1L, txnMetadata.txnStartTimestamp) assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) @@ -293,13 +325,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Ongoing, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch() @@ -310,7 +344,7 @@ class TransactionMetadataTest { // We should reset the pending state to make way for the abort transition. txnMetadata.pendingState = None - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds()) + val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds()) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, transitMetadata.producerId) } @@ -322,13 +356,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Ongoing, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch()) } @@ -340,36 +376,108 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val newProducerId = 9893L val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = true) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(newProducerId, txnMetadata.producerId) - assertEquals(producerId, txnMetadata.lastProducerId) + assertEquals(producerId, txnMetadata.previousProducerId) assertEquals(0, txnMetadata.producerEpoch) assertEquals(producerEpoch, txnMetadata.lastProducerEpoch) } + @Test + def testEpochBumpOnEndTxn(): Unit = { + time.sleep(100) + val producerEpoch = 10.toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Ongoing, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = time.milliseconds(), + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) + + var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch) + assertEquals(TV_2, txnMetadata.clientTransactionVersion) + + transitMetadata = txnMetadata.prepareComplete(time.milliseconds()) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch) + assertEquals(TV_2, txnMetadata.clientTransactionVersion) + } + + @Test + def testEpochBumpOnEndTxnOverflow(): Unit = { + time.sleep(100) + val producerEpoch = (Short.MaxValue - 1).toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Ongoing, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = time.milliseconds(), + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) + assertTrue(txnMetadata.isProducerEpochExhausted) + + val newProducerId = 9893L + var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, newProducerId, time.milliseconds() - 1) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(Short.MaxValue, txnMetadata.producerEpoch) + assertEquals(TV_2, txnMetadata.clientTransactionVersion) + + transitMetadata = txnMetadata.prepareComplete(time.milliseconds()) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(newProducerId, txnMetadata.producerId) + assertEquals(0, txnMetadata.producerEpoch) + assertEquals(TV_2, txnMetadata.clientTransactionVersion) + } + @Test def testRotateProducerIdInOngoingState(): Unit = { - assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing)) + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing, TV_0)) } - @Test - def testRotateProducerIdInPrepareAbortState(): Unit = { - assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort)) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def testRotateProducerIdInPrepareAbortState(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort, clientTransactionVersion)) } - @Test - def testRotateProducerIdInPrepareCommitState(): Unit = { - assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareCommit)) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def testRotateProducerIdInPrepareCommitState(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareCommit, clientTransactionVersion)) } @Test @@ -379,13 +487,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch)) txnMetadata.completeTransitionTo(transitMetadata) @@ -401,13 +511,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch)) txnMetadata.completeTransitionTo(transitMetadata) @@ -424,13 +536,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = lastProducerEpoch, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(lastProducerEpoch)) txnMetadata.completeTransitionTo(transitMetadata) @@ -447,13 +561,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = producerId, + previousProducerId = producerId, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = lastProducerEpoch, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val result = txnMetadata.prepareIncrementProducerEpoch(30000, Some((lastProducerEpoch - 1).toShort), time.milliseconds()) @@ -503,19 +619,21 @@ class TransactionMetadataTest { assertEquals(Set.empty, unmatchedStates) } - private def testRotateProducerIdInOngoingState(state: TransactionState): Unit = { + private def testRotateProducerIdInOngoingState(state: TransactionState, clientTransactionVersion: TransactionVersion): Unit = { val producerEpoch = (Short.MaxValue - 1).toShort val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = producerId, + previousProducerId = producerId, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = state, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = clientTransactionVersion) val newProducerId = 9893L txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = false) } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala index acaffea536f..a9783684c13 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala @@ -35,6 +35,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests.TransactionResult import org.apache.kafka.common.utils.MockTime import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion} +import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.apache.kafka.coordinator.transaction.generated.TransactionLogKey import org.apache.kafka.server.util.MockScheduler import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchDataInfo, FetchIsolation, LogConfig, LogOffsetMetadata} @@ -181,7 +182,7 @@ class TransactionStateManagerTest { new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) val records = MemoryRecords.withRecords(startOffset, Compression.NONE, - new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true))) + new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))) // We create a latch which is awaited while the log is loading. This ensures that the deletion // is triggered before the loading returns @@ -225,19 +226,19 @@ class TransactionStateManagerTest { txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid1's transaction adds three more partitions txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0), new TopicPartition("topic2", 1), new TopicPartition("topic2", 2))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid1's transaction is preparing to commit txnMetadata1.state = PrepareCommit - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid2's transaction started with three partitions txnMetadata2.state = Ongoing @@ -245,23 +246,23 @@ class TransactionStateManagerTest { new TopicPartition("topic3", 1), new TopicPartition("topic3", 2))) - txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's transaction is preparing to abort txnMetadata2.state = PrepareAbort - txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's transaction has aborted txnMetadata2.state = CompleteAbort - txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's epoch has advanced, with no ongoing transaction yet txnMetadata2.state = Empty txnMetadata2.topicPartitions.clear() - txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) val startOffset = 15L // it should work for any start offset val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) @@ -796,14 +797,9 @@ class TransactionStateManagerTest { // write the change. If the write fails (e.g. under min isr), the TransactionMetadata // is left at it is. If the transactional id is never reused, the TransactionMetadata // will be expired and it should succeed. - val txnMetadata = TransactionMetadata( - transactionalId = transactionalId, - producerId = 1, - producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = transactionTimeoutMs, - state = Empty, - timestamp = time.milliseconds() - ) + val timestamp = time.milliseconds() + val txnMetadata = new TransactionMetadata(transactionalId, 1, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0) transactionManager.putTransactionStateIfNotExists(txnMetadata) time.sleep(txnConfig.transactionalIdExpirationMs + 1) @@ -890,7 +886,7 @@ class TransactionStateManagerTest { txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) val startOffset = 0L val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) @@ -1053,7 +1049,7 @@ class TransactionStateManagerTest { txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) val startOffset = 0L val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) @@ -1081,7 +1077,9 @@ class TransactionStateManagerTest { producerId: Long, state: TransactionState = Empty, txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = { - TransactionMetadata(transactionalId, producerId, 0.toShort, txnTimeout, state, time.milliseconds()) + val timestamp = time.milliseconds() + new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, 0.toShort, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeout, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0) } private def prepareTxnLog(topicPartition: TopicPartition, @@ -1159,7 +1157,7 @@ class TransactionStateManagerTest { txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 1), new TopicPartition("topic1", 1))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) val startOffset = 15L val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) @@ -1178,7 +1176,7 @@ class TransactionStateManagerTest { txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) val startOffset = 0L val unknownKey = new TransactionLogKey() @@ -1199,7 +1197,7 @@ class TransactionStateManagerTest { val txnMetadata = txnMetadataPool.get(transactionalId1) assertEquals(txnMetadata1.transactionalId, txnMetadata.transactionalId) assertEquals(txnMetadata1.producerId, txnMetadata.producerId) - assertEquals(txnMetadata1.lastProducerId, txnMetadata.lastProducerId) + assertEquals(txnMetadata1.previousProducerId, txnMetadata.previousProducerId) assertEquals(txnMetadata1.producerEpoch, txnMetadata.producerEpoch) assertEquals(txnMetadata1.lastProducerEpoch, txnMetadata.lastProducerEpoch) assertEquals(txnMetadata1.txnTimeoutMs, txnMetadata.txnTimeoutMs) @@ -1210,7 +1208,7 @@ class TransactionStateManagerTest { @ParameterizedTest @EnumSource(classOf[TransactionVersion]) - def testUsesFlexibleRecords(transactionVersion: TransactionVersion): Unit = { + def testTransactionVersionInTransactionManager(transactionVersion: TransactionVersion): Unit = { val metadataCache = mock(classOf[MetadataCache]) when(metadataCache.features()).thenReturn { new FinalizedFeatures( @@ -1223,7 +1221,6 @@ class TransactionStateManagerTest { val transactionManager = new TransactionStateManager(0, scheduler, replicaManager, metadataCache, txnConfig, time, metrics) - val expectFlexibleRecords = transactionVersion.featureLevel > 0 - assertEquals(expectFlexibleRecords, transactionManager.usesFlexibleRecords()) + assertEquals(transactionVersion, transactionManager.transactionVersionLevel()) } } diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala index 7972e9c5355..371bc0384a7 100644 --- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala @@ -85,7 +85,7 @@ import org.apache.kafka.security.authorizer.AclEntry import org.apache.kafka.server.ClientMetricsManager import org.apache.kafka.server.authorizer.{Action, AuthorizationResult, Authorizer} import org.apache.kafka.server.common.MetadataVersion.{IBP_0_10_2_IV0, IBP_2_2_IV1} -import org.apache.kafka.server.common.{FeatureVersion, FinalizedFeatures, GroupVersion, KRaftVersion, MetadataVersion, RequestLocal} +import org.apache.kafka.server.common.{FeatureVersion, FinalizedFeatures, GroupVersion, KRaftVersion, MetadataVersion, RequestLocal, TransactionVersion} import org.apache.kafka.server.config.{ConfigType, KRaftConfigs, ReplicationConfigs, ServerConfigs, ServerLogConfigs, ShareGroupConfig} import org.apache.kafka.server.metrics.ClientMetricsTestUtils import org.apache.kafka.server.share.{CachedSharePartition, ErroneousAndValidPartitionData} @@ -2572,7 +2572,7 @@ class KafkaApisTest extends Logging { reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) val capturedResponse: ArgumentCaptor[EndTxnResponse] = ArgumentCaptor.forClass(classOf[EndTxnResponse]) - val responseCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit]) + val responseCallback: ArgumentCaptor[(Errors, Long, Short) => Unit] = ArgumentCaptor.forClass(classOf[(Errors, Long, Short) => Unit]) val transactionalId = "txnId" val producerId = 15L @@ -2587,15 +2587,18 @@ class KafkaApisTest extends Logging { ).build(version.toShort) val request = buildRequest(endTxnRequest) + val clientTransactionVersion = if (version > 4) TransactionVersion.TV_2 else TransactionVersion.TV_0 + val requestLocal = RequestLocal.withThreadConfinedCaching when(txnCoordinator.handleEndTransaction( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(producerId), ArgumentMatchers.eq(epoch), ArgumentMatchers.eq(TransactionResult.COMMIT), + ArgumentMatchers.eq(clientTransactionVersion), responseCallback.capture(), ArgumentMatchers.eq(requestLocal) - )).thenAnswer(_ => responseCallback.getValue.apply(Errors.PRODUCER_FENCED)) + )).thenAnswer(_ => responseCallback.getValue.apply(Errors.PRODUCER_FENCED, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH)) val kafkaApis = createKafkaApis() try { kafkaApis.handleEndTxnRequest(request, requestLocal) diff --git a/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java b/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java index 64163d5a0b4..36dadb5cf11 100644 --- a/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java +++ b/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java @@ -49,6 +49,10 @@ public enum TransactionVersion implements FeatureVersion { return featureLevel; } + public static TransactionVersion fromFeatureLevel(short version) { + return (TransactionVersion) Features.TRANSACTION_VERSION.fromFeatureLevel(version, true); + } + @Override public String featureName() { return FEATURE_NAME; @@ -63,4 +67,14 @@ public enum TransactionVersion implements FeatureVersion { public Map dependencies() { return dependencies; } + + // Transactions V1 enables log version 0 (flexible fields) + public short transactionLogValueVersion() { + return (short) (featureLevel >= 1 ? 1 : 0); + } + + // Transactions V2 enables epoch bump on commit/abort. + public boolean supportsEpochBump() { + return featureLevel >= 2; + } } diff --git a/transaction-coordinator/src/main/resources/common/message/TransactionLogValue.json b/transaction-coordinator/src/main/resources/common/message/TransactionLogValue.json index c6efc772d58..c9a1cbf66e3 100644 --- a/transaction-coordinator/src/main/resources/common/message/TransactionLogValue.json +++ b/transaction-coordinator/src/main/resources/common/message/TransactionLogValue.json @@ -24,6 +24,10 @@ "fields": [ { "name": "ProducerId", "type": "int64", "versions": "0+", "about": "Producer id in use by the transactional id"}, + { "name": "PreviousProducerId", "type": "int64", "taggedVersions": "1+", "tag": 0, "default": -1, + "about": "Producer id used by the last committed transaction"}, + { "name": "NextProducerId", "type": "int64", "taggedVersions": "1+", "tag": 1, "default": -1, + "about": "Latest producer ID sent to the producer for the given transactional ID"}, { "name": "ProducerEpoch", "type": "int16", "versions": "0+", "about": "Epoch associated with the producer id"}, { "name": "TransactionTimeoutMs", "type": "int32", "versions": "0+", @@ -37,6 +41,8 @@ { "name": "TransactionLastUpdateTimestampMs", "type": "int64", "versions": "0+", "about": "Time the transaction was last updated"}, { "name": "TransactionStartTimestampMs", "type": "int64", "versions": "0+", - "about": "Time the transaction was started"} + "about": "Time the transaction was started"}, + { "name": "ClientTransactionVersion", "type": "int16", "default": 0, "taggedVersions": "1+", "tag": 2, + "about": "The transaction version used by the client"} ] }