diff --git a/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java index c9ea98005fd..5cd346a8fe0 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java @@ -32,7 +32,11 @@ public class EndTxnRequest extends AbstractRequest { public final EndTxnRequestData data; public Builder(EndTxnRequestData data) { - super(ApiKeys.END_TXN); + this(data, false); + } + + public Builder(EndTxnRequestData data, boolean enableUnstableLastVersion) { + super(ApiKeys.END_TXN, enableUnstableLastVersion); this.data = data; } diff --git a/clients/src/main/resources/common/message/EndTxnRequest.json b/clients/src/main/resources/common/message/EndTxnRequest.json index bc66adcf50a..aea50d6b16d 100644 --- a/clients/src/main/resources/common/message/EndTxnRequest.json +++ b/clients/src/main/resources/common/message/EndTxnRequest.json @@ -25,7 +25,10 @@ // Version 3 enables flexible versions. // // Version 4 adds support for new error code TRANSACTION_ABORTABLE (KIP-890). - "validVersions": "0-4", + // + // Version 5 enables bumping epoch on every transaction (KIP-890 Part 2) + "latestVersionUnstable": true, + "validVersions": "0-5", "flexibleVersions": "3+", "fields": [ { "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId", diff --git a/clients/src/main/resources/common/message/EndTxnResponse.json b/clients/src/main/resources/common/message/EndTxnResponse.json index 08ac6cddd38..fb958a74033 100644 --- a/clients/src/main/resources/common/message/EndTxnResponse.json +++ b/clients/src/main/resources/common/message/EndTxnResponse.json @@ -24,12 +24,18 @@ // Version 3 enables flexible versions. // // Version 4 adds support for new error code TRANSACTION_ABORTABLE (KIP-890). - "validVersions": "0-4", + // + // Version 5 enables bumping epoch on every transaction (KIP-890 Part 2), so producer ID and epoch are included in the response. + "validVersions": "0-5", "flexibleVersions": "3+", "fields": [ { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, { "name": "ErrorCode", "type": "int16", "versions": "0+", - "about": "The error code, or 0 if there was no error." } + "about": "The error code, or 0 if there was no error." }, + { "name": "ProducerId", "type": "int64", "versions": "5+", "entityType": "producerId", "default": "-1", "ignorable": "true", + "about": "The producer ID." }, + { "name": "ProducerEpoch", "type": "int16", "versions": "5+", "default": "-1", "ignorable": "true", + "about": "The current epoch associated with the producer." } ] } diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala index cd268f3a166..72a196eca60 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala @@ -29,7 +29,7 @@ import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult} import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time} -import org.apache.kafka.server.common.RequestLocal +import org.apache.kafka.server.common.{RequestLocal, TransactionVersion} import org.apache.kafka.server.util.Scheduler import scala.jdk.CollectionConverters._ @@ -98,7 +98,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, private type InitProducerIdCallback = InitProducerIdResult => Unit private type AddPartitionsCallback = Errors => Unit private type VerifyPartitionsCallback = AddPartitionsToTxnResult => Unit - private type EndTxnCallback = Errors => Unit + private type EndTxnCallback = (Errors, Long, Short) => Unit private type ApiResult[T] = Either[Errors, T] /* Active flag of the coordinator */ @@ -135,13 +135,15 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case Success(producerId) => val createdMetadata = new TransactionMetadata(transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = transactionTimeoutMs, state = Empty, topicPartitions = collection.mutable.Set.empty[TopicPartition], - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TransactionVersion.TV_0) txnManager.putTransactionStateIfNotExists(createdMetadata) case Failure(exception) => @@ -169,7 +171,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case Right((coordinatorEpoch, newMetadata)) => if (newMetadata.txnState == PrepareEpochFence) { // abort the ongoing transaction and then return CONCURRENT_TRANSACTIONS to let client wait and retry - def sendRetriableErrorCallback(error: Errors): Unit = { + def sendRetriableErrorCallback(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = { if (error != Errors.NONE) { responseCallback(initTransactionError(error)) } else { @@ -182,6 +184,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, newMetadata.producerEpoch, TransactionResult.ABORT, isFromClient = false, + clientTransactionVersion = txnManager.transactionVersionLevel(), // Since this is not from client, use server TV sendRetriableErrorCallback, requestLocal) } else { @@ -221,7 +224,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, // could be a retry after a valid epoch bump that the producer never received the response for txnMetadata.producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || producerIdAndEpoch.producerId == txnMetadata.producerId || - (producerIdAndEpoch.producerId == txnMetadata.lastProducerId && TransactionMetadata.isEpochExhausted(producerIdAndEpoch.epoch)) + (producerIdAndEpoch.producerId == txnMetadata.previousProducerId && TransactionMetadata.isEpochExhausted(producerIdAndEpoch.epoch)) } if (txnMetadata.pendingTransitionInProgress) { @@ -487,6 +490,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, producerId: Long, producerEpoch: Short, txnMarkerResult: TransactionResult, + clientTransactionVersion: TransactionVersion, responseCallback: EndTxnCallback, requestLocal: RequestLocal = RequestLocal.noCaching): Unit = { endTransaction(transactionalId, @@ -494,6 +498,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, producerEpoch, txnMarkerResult, isFromClient = true, + clientTransactionVersion, responseCallback, requestLocal) } @@ -503,12 +508,15 @@ class TransactionCoordinator(txnConfig: TransactionConfig, producerEpoch: Short, txnMarkerResult: TransactionResult, isFromClient: Boolean, + clientTransactionVersion: TransactionVersion, responseCallback: EndTxnCallback, requestLocal: RequestLocal): Unit = { var isEpochFence = false if (transactionalId == null || transactionalId.isEmpty) - responseCallback(Errors.INVALID_REQUEST) + responseCallback(Errors.INVALID_REQUEST, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH) else { + var producerIdCopy = RecordBatch.NO_PRODUCER_ID + var producerEpochCopy = RecordBatch.NO_PRODUCER_EPOCH val preAppendResult: ApiResult[(Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId).flatMap { case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING) @@ -518,10 +526,39 @@ class TransactionCoordinator(txnConfig: TransactionConfig, val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch txnMetadata.inLock { - if (txnMetadata.producerId != producerId) + producerIdCopy = txnMetadata.producerId + producerEpochCopy = txnMetadata.producerEpoch + // PrepareEpochFence has slightly different epoch bumping logic so don't include it here. + val currentTxnMetadataIsAtLeastTransactionsV2 = !txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.clientTransactionVersion.supportsEpochBump() + // True if the client used TV_2 and retried a request that had overflowed the epoch, and a new producer ID is stored in the txnMetadata + val retryOnOverflow = currentTxnMetadataIsAtLeastTransactionsV2 && + txnMetadata.previousProducerId == producerId && producerEpoch == Short.MaxValue - 1 && txnMetadata.producerEpoch == 0 + // True if the client used TV_2 and retried an endTxn request, and the bumped producer epoch is stored in the txnMetadata. + val retryOnEpochBump = endTxnEpochBumped(txnMetadata, producerEpoch) + + val isValidEpoch = { + if (currentTxnMetadataIsAtLeastTransactionsV2) { + // With transactions V2, state + same epoch is not sufficient to determine if a retry transition is valid. If the epoch is the + // same it actually indicates the next endTransaction call. Instead, we want to check the epoch matches with the epoch in the retry conditions. + // Return producer fenced even in the cases where the epoch is higher and could indicate an invalid state transition. + // Use the following criteria to determine if a v2 retry is valid: + txnMetadata.state match { + case Ongoing | Empty | Dead | PrepareEpochFence => + producerEpoch == txnMetadata.producerEpoch + case PrepareCommit | PrepareAbort => + retryOnEpochBump + case CompleteCommit | CompleteAbort => + retryOnEpochBump || retryOnOverflow + } + } else { + // For transactions V1 strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch without server knowledge. + (!isFromClient || producerEpoch == txnMetadata.producerEpoch) && producerEpoch >= txnMetadata.producerEpoch + } + } + + if (txnMetadata.producerId != producerId && !retryOnOverflow) Left(Errors.INVALID_PRODUCER_ID_MAPPING) - // Strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch. - else if ((isFromClient && producerEpoch != txnMetadata.producerEpoch) || producerEpoch < txnMetadata.producerEpoch) + else if (!isValidEpoch) Left(Errors.PRODUCER_FENCED) else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence) Left(Errors.CONCURRENT_TRANSACTIONS) @@ -532,6 +569,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig, else PrepareAbort + // Maybe allocate new producer ID if we are bumping epoch and epoch is exhausted + val nextProducerIdOrErrors = + if (clientTransactionVersion.supportsEpochBump() && !txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.isProducerEpochExhausted) { + producerIdManager.generateProducerId() match { + case Success(newProducerId) => + Right(newProducerId) + case Failure(exception) => + Left(Errors.forException(exception)) + } + } else { + Right(RecordBatch.NO_PRODUCER_ID) + } + if (nextState == PrepareAbort && txnMetadata.pendingState.contains(PrepareEpochFence)) { // We should clear the pending state to make way for the transition to PrepareAbort and also bump // the epoch in the transaction metadata we are about to append. @@ -541,7 +591,10 @@ class TransactionCoordinator(txnConfig: TransactionConfig, txnMetadata.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH } - Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, time.milliseconds())) + nextProducerIdOrErrors.flatMap { + nextProducerId => + Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, clientTransactionVersion, nextProducerId, time.milliseconds())) + } case CompleteCommit => if (txnMarkerResult == TransactionResult.COMMIT) Left(Errors.NONE) @@ -576,8 +629,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig, preAppendResult match { case Left(err) => + if (err == Errors.NONE) { + responseCallback(err, producerIdCopy, producerEpochCopy) + } else { debug(s"Aborting append of $txnMarkerResult to transaction log with coordinator and returning $err error to client for $transactionalId's EndTransaction request") - responseCallback(err) + responseCallback(err, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH) + } case Right((coordinatorEpoch, newMetadata)) => def sendTxnMarkersCallback(error: Errors): Unit = { @@ -595,7 +652,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, txnMetadata.inLock { if (txnMetadata.producerId != producerId) Left(Errors.INVALID_PRODUCER_ID_MAPPING) - else if (txnMetadata.producerEpoch != producerEpoch) + else if (txnMetadata.producerEpoch != producerEpoch && !endTxnEpochBumped(txnMetadata, producerEpoch)) Left(Errors.PRODUCER_FENCED) else if (txnMetadata.pendingTransitionInProgress) Left(Errors.CONCURRENT_TRANSACTIONS) @@ -630,12 +687,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig, preSendResult match { case Left(err) => info(s"Aborting sending of transaction markers after appended $txnMarkerResult to transaction log and returning $err error to client for $transactionalId's EndTransaction request") - responseCallback(err) + responseCallback(err, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH) case Right((txnMetadata, newPreSendMetadata)) => // we can respond to the client immediately and continue to write the txn markers if // the log append was successful - responseCallback(Errors.NONE) + responseCallback(Errors.NONE, txnMetadata.producerId, txnMetadata.producerEpoch) txnMarkerChannelManager.addTxnMarkersToSend(coordinatorEpoch, txnMarkerResult, txnMetadata, newPreSendMetadata) } @@ -659,7 +716,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } } - responseCallback(error) + responseCallback(error, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH) } } @@ -669,11 +726,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } } + // When a client and server support V2, every endTransaction call bumps the producer epoch. When checking epoch, we want to + // check epoch + 1. Epoch bumps from PrepareEpochFence state are handled separately, so this method should not be used to check that case. + // Returns true if the transaction state epoch is the specified producer epoch + 1 and epoch bump on every transaction is expected. + private def endTxnEpochBumped(txnMetadata: TransactionMetadata, producerEpoch: Short): Boolean = { + !txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.clientTransactionVersion.supportsEpochBump() && + txnMetadata.producerEpoch == producerEpoch + 1 + } + def transactionTopicConfigs: Properties = txnManager.transactionTopicConfigs def partitionFor(transactionalId: String): Int = txnManager.partitionFor(transactionalId) - private def onEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = { + private def onEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = { error match { case Errors.NONE => info("Completed rollback of ongoing transaction for transactionalId " + @@ -721,6 +786,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, txnTransitMetadata.producerEpoch, TransactionResult.ABORT, isFromClient = false, + clientTransactionVersion = txnManager.transactionVersionLevel(), // Since this is not from client, use server TV onComplete(txnIdAndPidEpoch), RequestLocal.noCaching) } diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala index 2d4465d9a2c..623f88c7a38 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala @@ -25,6 +25,7 @@ import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} import org.apache.kafka.common.record.{Record, RecordBatch} import org.apache.kafka.common.{MessageFormatter, TopicPartition} import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue} +import org.apache.kafka.server.common.TransactionVersion import scala.collection.mutable import scala.jdk.CollectionConverters._ @@ -63,7 +64,7 @@ object TransactionLog { * @return value payload bytes */ private[transaction] def valueToBytes(txnMetadata: TxnTransitMetadata, - usesFlexibleRecords: Boolean): Array[Byte] = { + transactionVersionLevel: TransactionVersion): Array[Byte] = { if (txnMetadata.txnState == Empty && txnMetadata.topicPartitions.nonEmpty) throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata") @@ -78,9 +79,7 @@ object TransactionLog { // Serialize with version 0 (highest non-flexible version) until transaction.version 1 is enabled // which enables flexible fields in records. - val version: Short = - if (usesFlexibleRecords) 1 else 0 - MessageUtil.toVersionPrefixedBytes(version, + MessageUtil.toVersionPrefixedBytes(transactionVersionLevel.transactionLogValueVersion(), new TransactionLogValue() .setProducerId(txnMetadata.producerId) .setProducerEpoch(txnMetadata.producerEpoch) @@ -88,7 +87,8 @@ object TransactionLog { .setTransactionStatus(txnMetadata.txnState.id) .setTransactionLastUpdateTimestampMs(txnMetadata.txnLastUpdateTimestamp) .setTransactionStartTimestampMs(txnMetadata.txnStartTimestamp) - .setTransactionPartitions(transactionPartitions)) + .setTransactionPartitions(transactionPartitions) + .setClientTransactionVersion(txnMetadata.clientTransactionVersion.featureLevel())) } /** @@ -124,14 +124,16 @@ object TransactionLog { val transactionMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = value.producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = value.previousProducerId, + nextProducerId = value.nextProducerId, producerEpoch = value.producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = value.transactionTimeoutMs, state = TransactionState.fromId(value.transactionStatus), topicPartitions = mutable.Set.empty[TopicPartition], txnStartTimestamp = value.transactionStartTimestampMs, - txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs) + txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs, + clientTransactionVersion = TransactionVersion.fromFeatureLevel(value.clientTransactionVersion)) if (!transactionMetadata.state.equals(Empty)) value.transactionPartitions.forEach(partitionsSchema => diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala index 9fab77f30b6..ec3fb66e355 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala @@ -17,11 +17,11 @@ package kafka.coordinator.transaction import java.util.concurrent.locks.ReentrantLock - import kafka.utils.{CoreUtils, Logging, nonthreadsafe} import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.server.common.TransactionVersion import scala.collection.{immutable, mutable} @@ -163,70 +163,64 @@ private[transaction] case object PrepareEpochFence extends TransactionState { } private[transaction] object TransactionMetadata { - def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, timestamp: Long) = - new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp) - - def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, - state: TransactionState, timestamp: Long) = - new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp) - - def apply(transactionalId: String, producerId: Long, lastProducerId: Long, producerEpoch: Short, - lastProducerEpoch: Short, txnTimeoutMs: Int, state: TransactionState, timestamp: Long) = - new TransactionMetadata(transactionalId, producerId, lastProducerId, producerEpoch, lastProducerEpoch, - txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp) - def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1 } // this is a immutable object representing the target transition of the transaction metadata private[transaction] case class TxnTransitMetadata(producerId: Long, - lastProducerId: Long, + prevProducerId: Long, + nextProducerId: Long, producerEpoch: Short, lastProducerEpoch: Short, txnTimeoutMs: Int, txnState: TransactionState, topicPartitions: immutable.Set[TopicPartition], txnStartTimestamp: Long, - txnLastUpdateTimestamp: Long) { + txnLastUpdateTimestamp: Long, + clientTransactionVersion: TransactionVersion) { override def toString: String = { "TxnTransitMetadata(" + s"producerId=$producerId, " + - s"lastProducerId=$lastProducerId, " + + s"previousProducerId=$prevProducerId, " + + s"nextProducerId=$nextProducerId, " + s"producerEpoch=$producerEpoch, " + s"lastProducerEpoch=$lastProducerEpoch, " + s"txnTimeoutMs=$txnTimeoutMs, " + s"txnState=$txnState, " + s"topicPartitions=$topicPartitions, " + s"txnStartTimestamp=$txnStartTimestamp, " + - s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp)" + s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp, " + + s"clientTransactionVersion=$clientTransactionVersion)" } } /** * - * @param producerId producer id - * @param lastProducerId last producer id assigned to the producer - * @param producerEpoch current epoch of the producer - * @param lastProducerEpoch last epoch of the producer - * @param txnTimeoutMs timeout to be used to abort long running transactions - * @param state current state of the transaction - * @param topicPartitions current set of partitions that are part of this transaction - * @param txnStartTimestamp time the transaction was started, i.e., when first partition is added - * @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration + * @param producerId producer id + * @param previousProducerId producer id for the last committed transaction with this transactional ID + * @param nextProducerId Latest producer ID sent to the producer for the given transactional ID + * @param producerEpoch current epoch of the producer + * @param lastProducerEpoch last epoch of the producer + * @param txnTimeoutMs timeout to be used to abort long running transactions + * @param state current state of the transaction + * @param topicPartitions current set of partitions that are part of this transaction + * @param txnStartTimestamp time the transaction was started, i.e., when first partition is added + * @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration + * @param clientTransactionVersion TransactionVersion used by the client when the state was transitioned */ @nonthreadsafe private[transaction] class TransactionMetadata(val transactionalId: String, var producerId: Long, - var lastProducerId: Long, + var previousProducerId: Long, + var nextProducerId: Long, var producerEpoch: Short, var lastProducerEpoch: Short, var txnTimeoutMs: Int, var state: TransactionState, val topicPartitions: mutable.Set[TopicPartition], @volatile var txnStartTimestamp: Long = -1, - @volatile var txnLastUpdateTimestamp: Long) extends Logging { + @volatile var txnLastUpdateTimestamp: Long, + var clientTransactionVersion: TransactionVersion) extends Logging { // pending state is used to indicate the state that this transaction is going to // transit to, and for blocking future attempts to transit it again if it is not legal; @@ -256,8 +250,8 @@ private[transaction] class TransactionMetadata(val transactionalId: String, // this is visible for test only def prepareNoTransit(): TxnTransitMetadata = { // do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state - TxnTransitMetadata(producerId, lastProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.toSet, - txnStartTimestamp, txnLastUpdateTimestamp) + TxnTransitMetadata(producerId, previousProducerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.toSet, + txnStartTimestamp, txnLastUpdateTimestamp, TransactionVersion.TV_0) } def prepareFenceProducerEpoch(): TxnTransitMetadata = { @@ -335,9 +329,16 @@ private[transaction] class TransactionMetadata(val transactionalId: String, (topicPartitions ++ addedTopicPartitions).toSet, newTxnStartTimestamp, updateTimestamp) } - def prepareAbortOrCommit(newState: TransactionState, updateTimestamp: Long): TxnTransitMetadata = { - prepareTransitionTo(newState, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, topicPartitions.toSet, - txnStartTimestamp, updateTimestamp) + def prepareAbortOrCommit(newState: TransactionState, clientTransactionVersion: TransactionVersion, nextProducerId: Long, updateTimestamp: Long): TxnTransitMetadata = { + val (updatedProducerEpoch, updatedLastProducerEpoch) = if (clientTransactionVersion.supportsEpochBump()) { + // We already ensured that we do not overflow here. MAX_SHORT is the highest possible value. + ((producerEpoch + 1).toShort, producerEpoch) + } else { + (producerEpoch, lastProducerEpoch) + } + + prepareTransitionTo(newState, producerId, nextProducerId, updatedProducerEpoch, updatedLastProducerEpoch, txnTimeoutMs, topicPartitions.toSet, + txnStartTimestamp, updateTimestamp, clientTransactionVersion) } def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = { @@ -345,8 +346,15 @@ private[transaction] class TransactionMetadata(val transactionalId: String, // Since the state change was successfully written to the log, unset the flag for a failed epoch fence hasFailedEpochFence = false - prepareTransitionTo(newState, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, Set.empty[TopicPartition], - txnStartTimestamp, updateTimestamp) + val (updatedProducerId, updatedProducerEpoch) = + // If we overflowed on epoch bump, we have to set it as the producer ID now the marker has been written. + if (clientTransactionVersion.supportsEpochBump() && nextProducerId != RecordBatch.NO_PRODUCER_ID) { + (nextProducerId, 0.toShort) + } else { + (producerId, producerEpoch) + } + prepareTransitionTo(newState, updatedProducerId, RecordBatch.NO_PRODUCER_ID, updatedProducerEpoch, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition], + txnStartTimestamp, updateTimestamp, clientTransactionVersion) } def prepareDead(): TxnTransitMetadata = { @@ -367,37 +375,50 @@ private[transaction] class TransactionMetadata(val transactionalId: String, } } - private def prepareTransitionTo(newState: TransactionState, - newProducerId: Long, - newEpoch: Short, - newLastEpoch: Short, - newTxnTimeoutMs: Int, - newTopicPartitions: immutable.Set[TopicPartition], - newTxnStartTimestamp: Long, + private def prepareTransitionTo(updatedState: TransactionState, + updatedProducerId: Long, + updatedEpoch: Short, + updatedLastEpoch: Short, + updatedTxnTimeoutMs: Int, + updatedTopicPartitions: immutable.Set[TopicPartition], + updatedTxnStartTimestamp: Long, updateTimestamp: Long): TxnTransitMetadata = { + prepareTransitionTo(updatedState, updatedProducerId, RecordBatch.NO_PRODUCER_ID, updatedEpoch, updatedLastEpoch, updatedTxnTimeoutMs, updatedTopicPartitions, updatedTxnStartTimestamp, updateTimestamp, TransactionVersion.TV_0) + } + + private def prepareTransitionTo(updatedState: TransactionState, + updatedProducerId: Long, + nextProducerId: Long, + updatedEpoch: Short, + updatedLastEpoch: Short, + updatedTxnTimeoutMs: Int, + updatedTopicPartitions: immutable.Set[TopicPartition], + updatedTxnStartTimestamp: Long, + updateTimestamp: Long, + clientTransactionVersion: TransactionVersion): TxnTransitMetadata = { if (pendingState.isDefined) - throw new IllegalStateException(s"Preparing transaction state transition to $newState " + + throw new IllegalStateException(s"Preparing transaction state transition to $updatedState " + s"while it already a pending state ${pendingState.get}") - if (newProducerId < 0) - throw new IllegalArgumentException(s"Illegal new producer id $newProducerId") + if (updatedProducerId < 0) + throw new IllegalArgumentException(s"Illegal new producer id $updatedProducerId") // The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata // is created for the first time and it could stay like this until transitioning // to Dead. - if (newState != Dead && newEpoch < 0) - throw new IllegalArgumentException(s"Illegal new producer epoch $newEpoch") + if (updatedState != Dead && updatedEpoch < 0) + throw new IllegalArgumentException(s"Illegal new producer epoch $updatedEpoch") // check that the new state transition is valid and update the pending state if necessary - if (newState.validPreviousStates.contains(state)) { - val transitMetadata = TxnTransitMetadata(newProducerId, producerId, newEpoch, newLastEpoch, newTxnTimeoutMs, newState, - newTopicPartitions, newTxnStartTimestamp, updateTimestamp) + if (updatedState.validPreviousStates.contains(state)) { + val transitMetadata = TxnTransitMetadata(updatedProducerId, producerId, nextProducerId, updatedEpoch, updatedLastEpoch, updatedTxnTimeoutMs, updatedState, + updatedTopicPartitions, updatedTxnStartTimestamp, updateTimestamp, clientTransactionVersion) debug(s"TransactionalId $transactionalId prepare transition from $state to $transitMetadata") - pendingState = Some(newState) + pendingState = Some(updatedState) transitMetadata } else { - throw new IllegalStateException(s"Preparing transaction state transition to $newState failed since the target state" + - s" $newState is not a valid previous state of the current state $state") + throw new IllegalStateException(s"Preparing transaction state transition to $updatedState failed since the target state" + + s" $updatedState is not a valid previous state of the current state $state") } } @@ -436,7 +457,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, producerEpoch = transitMetadata.producerEpoch lastProducerEpoch = transitMetadata.lastProducerEpoch producerId = transitMetadata.producerId - lastProducerId = transitMetadata.lastProducerId + previousProducerId = transitMetadata.prevProducerId } case Ongoing => // from addPartitions @@ -457,6 +478,10 @@ private[transaction] class TransactionMetadata(val transactionalId: String, txnStartTimestamp != transitMetadata.txnStartTimestamp) { throwStateTransitionFailure(transitMetadata) + } else if (transitMetadata.clientTransactionVersion.supportsEpochBump()) { + producerEpoch = transitMetadata.producerEpoch + lastProducerEpoch = transitMetadata.lastProducerEpoch + nextProducerId = transitMetadata.nextProducerId } case CompleteAbort | CompleteCommit => // from write markers @@ -468,6 +493,13 @@ private[transaction] class TransactionMetadata(val transactionalId: String, } else { txnStartTimestamp = transitMetadata.txnStartTimestamp topicPartitions.clear() + if (transitMetadata.clientTransactionVersion.supportsEpochBump()) { + producerEpoch = transitMetadata.producerEpoch + lastProducerEpoch = transitMetadata.lastProducerEpoch + previousProducerId = transitMetadata.prevProducerId + producerId = transitMetadata.producerId + nextProducerId = transitMetadata.nextProducerId + } } case PrepareEpochFence => @@ -487,6 +519,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, } debug(s"TransactionalId $transactionalId complete transition from $state to $transitMetadata") + clientTransactionVersion = transitMetadata.clientTransactionVersion txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp pendingState = None state = toState @@ -494,8 +527,14 @@ private[transaction] class TransactionMetadata(val transactionalId: String, } private def validProducerEpoch(transitMetadata: TxnTransitMetadata): Boolean = { - val transitEpoch = transitMetadata.producerEpoch - val transitProducerId = transitMetadata.producerId + val isAtLeastTransactionsV2 = transitMetadata.clientTransactionVersion.supportsEpochBump() + val isOverflowComplete = isAtLeastTransactionsV2 && (transitMetadata.txnState == CompleteCommit || transitMetadata.txnState == CompleteAbort) && transitMetadata.producerEpoch == 0 + val transitEpoch = + if (isOverflowComplete || (isAtLeastTransactionsV2 && (transitMetadata.txnState == PrepareCommit || transitMetadata.txnState == PrepareAbort))) + transitMetadata.lastProducerEpoch + else + transitMetadata.producerEpoch + val transitProducerId = if (isOverflowComplete) transitMetadata.prevProducerId else transitMetadata.producerId transitEpoch == producerEpoch && transitProducerId == producerId } @@ -518,6 +557,8 @@ private[transaction] class TransactionMetadata(val transactionalId: String, "TransactionMetadata(" + s"transactionalId=$transactionalId, " + s"producerId=$producerId, " + + s"previousProducerId=$previousProducerId, " + s"nextProducerId=$nextProducerId, " s"producerEpoch=$producerEpoch, " + s"txnTimeoutMs=$txnTimeoutMs, " + s"state=$state, " + diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala index d66bab06f68..01912724587 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala @@ -101,8 +101,10 @@ class TransactionStateManager(brokerId: Int, TransactionStateManagerConfig.METRICS_GROUP, "The avg time it took to load the partitions in the last 30sec"), new Avg()) - private[transaction] def usesFlexibleRecords(): Boolean = { - metadataCache.features().finalizedFeatures().getOrDefault(TransactionVersion.FEATURE_NAME, 0.toShort) > 0 + private[transaction] def transactionVersionLevel(): TransactionVersion = { + val version = TransactionVersion.fromFeatureLevel(metadataCache.features().finalizedFeatures().getOrDefault( + TransactionVersion.FEATURE_NAME, 0.toShort)) + version } // visible for testing only @@ -624,7 +626,7 @@ class TransactionStateManager(brokerId: Int, // generate the message for this transaction metadata val keyBytes = TransactionLog.keyToBytes(transactionalId) - val valueBytes = TransactionLog.valueToBytes(newMetadata, usesFlexibleRecords()) + val valueBytes = TransactionLog.valueToBytes(newMetadata, transactionVersionLevel()) val timestamp = time.milliseconds() val records = MemoryRecords.withRecords(TransactionLog.EnforcedCompression, new SimpleRecord(timestamp, keyBytes, valueBytes)) diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index 2586c562cf5..840f163132a 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -73,7 +73,7 @@ import org.apache.kafka.coordinator.group.{Group, GroupCoordinator} import org.apache.kafka.coordinator.share.ShareCoordinator import org.apache.kafka.server.ClientMetricsManager import org.apache.kafka.server.authorizer._ -import org.apache.kafka.server.common.{GroupVersion, MetadataVersion, RequestLocal} +import org.apache.kafka.server.common.{GroupVersion, MetadataVersion, RequestLocal, TransactionVersion} import org.apache.kafka.server.common.MetadataVersion.{IBP_0_11_0_IV0, IBP_2_3_IV0} import org.apache.kafka.server.record.BrokerCompressionType import org.apache.kafka.server.share.context.ShareFetchContext @@ -2299,7 +2299,7 @@ class KafkaApis(val requestChannel: RequestChannel, val transactionalId = endTxnRequest.data.transactionalId if (authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) { - def sendResponseCallback(error: Errors): Unit = { + def sendResponseCallback(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = { def createResponse(requestThrottleMs: Int): AbstractResponse = { val finalError = if (endTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) { @@ -2311,6 +2311,8 @@ class KafkaApis(val requestChannel: RequestChannel, } val responseBody = new EndTxnResponse(new EndTxnResponseData() .setErrorCode(finalError.code) + .setProducerId(newProducerId) + .setProducerEpoch(newProducerEpoch) .setThrottleTimeMs(requestThrottleMs)) trace(s"Completed ${endTxnRequest.data.transactionalId}'s EndTxnRequest " + s"with committed: ${endTxnRequest.data.committed}, " + @@ -2320,10 +2322,14 @@ class KafkaApis(val requestChannel: RequestChannel, requestHelper.sendResponseMaybeThrottle(request, createResponse) } + // If the request is version 4, we know the client supports transaction version 2. + val clientTransactionVersion = if (endTxnRequest.version() > 4) TransactionVersion.TV_2 else TransactionVersion.TV_0 + txnCoordinator.handleEndTransaction(endTxnRequest.data.transactionalId, endTxnRequest.data.producerId, endTxnRequest.data.producerEpoch, endTxnRequest.result(), + clientTransactionVersion, sendResponseCallback, requestLocal) } else diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala index 12531d289db..323dbe958e1 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala @@ -464,10 +464,10 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren addPartitionsOp.awaitAndVerify(txn) val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn")) - txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2)) txnMetadata.state = PrepareCommit - txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2)) prepareTxnLog(partitionId) } @@ -506,13 +506,15 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren private def prepareExhaustedEpochTxnMetadata(txn: Transaction): TransactionMetadata = { new TransactionMetadata(transactionalId = txn.transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = (Short.MaxValue - 1).toShort, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 60000, state = Empty, topicPartitions = collection.mutable.Set.empty[TopicPartition], - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TransactionVersion.TV_0) } abstract class TxnOperation[R] extends Operation { @@ -562,7 +564,8 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren txnMetadata.producerId, txnMetadata.producerEpoch, transactionResult(txn), - resultCallback, + TransactionVersion.TV_2, + (r, _, _) => resultCallback(r), RequestLocal.withThreadConfinedCaching) } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala index 584a49bbc2c..f32a924fb68 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala @@ -23,9 +23,13 @@ import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult} import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch} import org.apache.kafka.coordinator.transaction.TransactionStateManagerConfig +import org.apache.kafka.server.common.TransactionVersion +import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.apache.kafka.server.util.MockScheduler import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource import org.mockito.{ArgumentCaptor, ArgumentMatchers} import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mockito.{mock, times, verify, when} @@ -51,6 +55,7 @@ class TransactionCoordinatorTest { private val producerId = 10L private val producerEpoch: Short = 1 private val txnTimeoutMs = 1 + private val producerId2 = 11L private val partitions = mutable.Set[TopicPartition](new TopicPartition("topic1", 0)) private val scheduler = new MockScheduler(time) @@ -66,6 +71,8 @@ class TransactionCoordinatorTest { val transactionStatePartitionCount = 1 var result: InitProducerIdResult = _ var error: Errors = Errors.NONE + var newProducerId: Long = RecordBatch.NO_PRODUCER_ID + var newEpoch: Short = RecordBatch.NO_PRODUCER_EPOCH private def mockPidGenerator(): Unit = { when(pidGenerator.generateProducerId()).thenAnswer(_ => { @@ -155,8 +162,8 @@ class TransactionCoordinatorTest { def shouldGenerateNewProducerIdIfEpochsExhausted(): Unit = { initPidGenericMocks(transactionalId) - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort, + (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -245,7 +252,8 @@ class TransactionCoordinatorTest { errors = AddPartitionsToTxnResponse.errorsForTransaction(result.topicResults()).asScala.toMap } // If producer ID is not the same, return INVALID_PRODUCER_ID_MAPPING - val wrongPidTxnMetadata = new TransactionMetadata(transactionalId, 1, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0) + val wrongPidTxnMetadata = new TransactionMetadata(transactionalId, 1, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, wrongPidTxnMetadata)))) @@ -253,10 +261,10 @@ class TransactionCoordinatorTest { errors.foreach { case (_, error) => assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) } - // If producer epoch is not equal, return PRODUCER_FENCED - val oldEpochTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0) + val oldEpochTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, oldEpochTxnMetadata)))) @@ -266,7 +274,8 @@ class TransactionCoordinatorTest { } // If the txn state is Prepare or AbortCommit, we return CONCURRENT_TRANSACTIONS - val emptyTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0) + val emptyTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, emptyTxnMetadata)))) @@ -276,7 +285,8 @@ class TransactionCoordinatorTest { } // Pending state does not matter, we will just check if the partitions are in the txnMetadata. - val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, mutable.Set.empty, 0, 0) + val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, mutable.Set.empty, 0, 0, TV_0) ongoingTxnMetadata.pendingState = Some(CompleteCommit) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, ongoingTxnMetadata)))) @@ -298,9 +308,11 @@ class TransactionCoordinatorTest { } def validateConcurrentTransactions(state: TransactionState): Unit = { + // Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort. when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0))))) + new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0, TV_2))))) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) @@ -308,9 +320,11 @@ class TransactionCoordinatorTest { @Test def shouldRespondWithProducerFencedOnAddPartitionsWhenEpochsAreDifferent(): Unit = { + // Since the clientTransactionVersion doesn't matter, use 2 since the state is PrepareCommit. when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 0, 0, 10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0))))) + new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0, TV_2))))) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) assertEquals(Errors.PRODUCER_FENCED, error) @@ -318,27 +332,30 @@ class TransactionCoordinatorTest { @Test def shouldAppendNewMetadataToLogOnAddPartitionsWhenPartitionsAdded(): Unit = { - validateSuccessfulAddPartitions(Empty) + validateSuccessfulAddPartitions(Empty, 0) } @Test def shouldRespondWithSuccessOnAddPartitionsWhenStateIsOngoing(): Unit = { - validateSuccessfulAddPartitions(Ongoing) + validateSuccessfulAddPartitions(Ongoing, 0) } - @Test - def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(): Unit = { - validateSuccessfulAddPartitions(CompleteCommit) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(clientTransactionVersion: Short): Unit = { + validateSuccessfulAddPartitions(CompleteCommit, clientTransactionVersion) } - @Test - def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(): Unit = { - validateSuccessfulAddPartitions(CompleteAbort) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(clientTransactionVersion: Short): Unit = { + validateSuccessfulAddPartitions(CompleteAbort, clientTransactionVersion) } - def validateSuccessfulAddPartitions(previousState: TransactionState): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, - txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds()) + def validateSuccessfulAddPartitions(previousState: TransactionState, transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -360,7 +377,8 @@ class TransactionCoordinatorTest { def shouldRespondWithErrorsNoneOnAddPartitionWhenNoErrorsAndPartitionsTheSame(): Unit = { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0))))) + new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0))))) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) assertEquals(Errors.NONE, error) @@ -376,7 +394,8 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0))))) + new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0, TV_0))))) coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, partitions, verifyPartitionsInTxnCallback) errors.foreach { case (_, error) => @@ -394,7 +413,8 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0))))) + new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0))))) val extraPartitions = partitions ++ Set(new TopicPartition("topic2", 0)) @@ -404,107 +424,227 @@ class TransactionCoordinatorTest { verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(None)) - coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, 10, 10, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + new TransactionMetadata(transactionalId, 10, 10, RecordBatch.NO_PRODUCER_ID, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) - coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, + (producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) - coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.PRODUCER_FENCED, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(): Unit ={ + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, + (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.NONE, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(): Unit ={ - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.NONE, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_TXN_STATE, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_TXN_STATE, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnConcurrentTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnConcurrentTransactionsOnEndTxnRequestWhenStatusIsPrepareCommit(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) - .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } - @Test - def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) - .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, 1, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) - coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_TXN_STATE, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } @Test - def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(): Unit = { - mockPrepare(PrepareCommit) + def shouldReturnWhenTransactionVersionDowngraded(): Unit = { + // State was written when transactions V2 + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) - coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, errorsCallback) + // Return CONCURRENT_TRANSACTIONS as the transaction is still completing + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_0, endTxnCallback) + assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) + assertEquals(RecordBatch.NO_PRODUCER_ID, newProducerId) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, newEpoch) + verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) + + // Recognize the retry and return NONE + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_0, endTxnCallback) + assertEquals(Errors.NONE, error) + assertEquals(producerId, newProducerId) + assertEquals((producerEpoch + 1).toShort, newEpoch) // epoch is bumped since we started as V2 + verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId)) + } + + @Test + def shouldReturnCorrectlyWhenTransactionVersionUpgraded(): Unit = { + // State was written when transactions V0 + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) + + // Transactions V0 throws the concurrent transactions error here. + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) + verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) + + // When the transaction is completed, return and do not throw an error. + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.NONE, error) + assertEquals(producerId, newProducerId) + assertEquals(producerEpoch, newEpoch) // epoch is not bumped since this started as V1 + verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId)) + } + + @Test + def shouldReturnInvalidTxnRequestOnEndTxnV2IfNotEndTxnV2Retry(): Unit = { + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + + // If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED. + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.PRODUCER_FENCED, error) + verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) + + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + + // If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED. + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.PRODUCER_FENCED, error) + verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId)) + } + + @Test + def shouldReturnOkOnEndTxnV2IfEndTxnV2RetryEpochOverflow(): Unit = { + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, + producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + + // Return CONCURRENT_TRANSACTIONS while transaction is still completing + coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) + verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) + + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId2, producerId, + RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + + coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback) + assertEquals(Errors.NONE, error) + assertNotEquals(RecordBatch.NO_PRODUCER_ID, newProducerId) + assertNotEquals(producerId, newProducerId) + assertEquals(0, newEpoch) + verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId)) + } + + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + mockPrepare(PrepareCommit, clientTransactionVersion) + + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).appendTransactionToLog( ArgumentMatchers.eq(transactionalId), @@ -515,11 +655,13 @@ class TransactionCoordinatorTest { any()) } - @Test - def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(): Unit = { - mockPrepare(PrepareAbort) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + mockPrepare(PrepareAbort, clientTransactionVersion) - coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).appendTransactionToLog( ArgumentMatchers.eq(transactionalId), @@ -530,90 +672,106 @@ class TransactionCoordinatorTest { any()) } - @Test - def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(): Unit = { - coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, errorsCallback) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_REQUEST, error) } - @Test - def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Left(Errors.NOT_COORDINATOR)) - coordinator.handleEndTransaction("", 0, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction("", 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_REQUEST, error) } - @Test - def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Left(Errors.NOT_COORDINATOR)) - coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.NOT_COORDINATOR, error) } - @Test - def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)) - coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, error) } - @Test - def shouldReturnInvalidEpochOnEndTxnWhenEpochIsLarger(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnInvalidEpochOnEndTxnWhenEpochIsLarger(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) val serverProducerEpoch = 1.toShort - verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch + 1).toShort) + verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch + 1).toShort, clientTransactionVersion) } - @Test - def shouldReturnInvalidEpochOnEndTxnWhenEpochIsSmaller(): Unit = { - val serverProducerEpoch = 1.toShort - verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch - 1).toShort) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldReturnInvalidEpochOnEndTxnWhenEpochIsSmaller(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + val serverProducerEpoch = 2.toShort + // Since we bump epoch in transactionV2 the request should be one producer ID older + verifyEndTxnEpoch(serverProducerEpoch, requestEpoch(clientTransactionVersion), clientTransactionVersion) } - private def verifyEndTxnEpoch(metadataEpoch: Short, requestEpoch: Short): Unit = { + private def verifyEndTxnEpoch(metadataEpoch: Short, requestEpoch: Short, clientTransactionVersion: TransactionVersion): Unit = { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, - new TransactionMetadata(transactionalId, producerId, producerId, metadataEpoch, 0, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, metadataEpoch, 1, + 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) - coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, errorsCallback) + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.PRODUCER_FENCED, error) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } @Test def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingEmptyTransaction(): Unit = { - validateIncrementEpochAndUpdateMetadata(Empty) + validateIncrementEpochAndUpdateMetadata(Empty, 0) + } + + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(clientTransactionVersion: Short): Unit = { + validateIncrementEpochAndUpdateMetadata(CompleteAbort, clientTransactionVersion) + } + + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(clientTransactionVersion: Short): Unit = { + validateIncrementEpochAndUpdateMetadata(CompleteCommit, clientTransactionVersion) } @Test - def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(): Unit = { - validateIncrementEpochAndUpdateMetadata(CompleteAbort) - } - - @Test - def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(): Unit = { - validateIncrementEpochAndUpdateMetadata(CompleteCommit) - } - - @Test - def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareCommitState(): Unit ={ + def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareCommitState(): Unit = { validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareCommit) } @Test - def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareAbortState(): Unit ={ + def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareAbortState(): Unit = { validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareAbort) } @Test def shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -621,8 +779,10 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) + + val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), @@ -640,7 +800,7 @@ class TransactionCoordinatorTest { verify(transactionManager).appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), - ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())), + ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())), any(), any(), any()) @@ -648,14 +808,14 @@ class TransactionCoordinatorTest { @Test def shouldFailToAbortTransactionOnHandleInitPidWhenProducerEpochIsSmaller(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) - val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 2).toShort, - (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + (producerEpoch + 2).toShort, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -671,8 +831,8 @@ class TransactionCoordinatorTest { @Test def shouldNotRepeatedlyBumpEpochDueToInitPidDuringOngoingTxnIfAppendToLogFails(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -683,9 +843,11 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenAnswer(_ => Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) - val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds()) + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) + + val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds()) when(transactionManager.appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), @@ -740,33 +902,38 @@ class TransactionCoordinatorTest { @Test def shouldUseLastEpochToFenceWhenEpochsAreExhausted(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) - val postFenceTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, Short.MaxValue, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds()) + val postFenceTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + Short.MaxValue, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, postFenceTxnMetadata)))) + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) + + // InitProducerId uses FenceProducerEpoch so clientTransactionVersion is 0. when(transactionManager.appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(TxnTransitMetadata( producerId = producerId, - lastProducerId = producerId, + prevProducerId = producerId, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = Short.MaxValue, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = txnTimeoutMs, txnState = PrepareAbort, topicPartitions = partitions.toSet, txnStartTimestamp = time.milliseconds(), - txnLastUpdateTimestamp = time.milliseconds())), + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0)), capturedErrorsCallback.capture(), any(), any()) @@ -783,14 +950,16 @@ class TransactionCoordinatorTest { ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(TxnTransitMetadata( producerId = producerId, - lastProducerId = producerId, + prevProducerId = producerId, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = Short.MaxValue, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = txnTimeoutMs, txnState = PrepareAbort, topicPartitions = partitions.toSet, txnStartTimestamp = time.milliseconds(), - txnLastUpdateTimestamp = time.milliseconds())), + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0)), any(), any(), any()) @@ -800,8 +969,8 @@ class TransactionCoordinatorTest { def testInitProducerIdWithNoLastProducerData(): Unit = { // If the metadata doesn't include the previous producer data (for example, if it was written to the log by a broker // on an old version), the retry case should fail - val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -817,8 +986,8 @@ class TransactionCoordinatorTest { @Test def testFenceProducerWhenMappingExistsWithDifferentProducerId(): Unit = { // Existing transaction ID maps to new producer ID - val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId, producerEpoch, - (producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId, + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -835,8 +1004,8 @@ class TransactionCoordinatorTest { def testInitProducerIdWithCurrentEpochProvided(): Unit = { mockPidGenerator() - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, 10, - 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -870,8 +1039,8 @@ class TransactionCoordinatorTest { def testInitProducerIdStaleCurrentEpochProvided(): Unit = { mockPidGenerator() - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, 10, - 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -906,8 +1075,8 @@ class TransactionCoordinatorTest { @Test def testRetryInitProducerIdAfterProducerIdRotation(): Unit = { // Existing transaction ID maps to new producer ID - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(pidGenerator.generateProducerId()) .thenReturn(Success(producerId + 1)) @@ -928,7 +1097,7 @@ class TransactionCoordinatorTest { capturedErrorsCallback.getValue.apply(Errors.NONE) txnMetadata.pendingState = None txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId - txnMetadata.lastProducerId = capturedTxnTransitMetadata.getValue.lastProducerId + txnMetadata.previousProducerId = capturedTxnTransitMetadata.getValue.prevProducerId txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch }) @@ -947,8 +1116,8 @@ class TransactionCoordinatorTest { @Test def testInitProducerIdWithInvalidEpochAfterProducerIdRotation(): Unit = { // Existing transaction ID maps to new producer ID - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) when(pidGenerator.generateProducerId()) .thenReturn(Success(producerId + 1)) @@ -969,7 +1138,7 @@ class TransactionCoordinatorTest { capturedErrorsCallback.getValue.apply(Errors.NONE) txnMetadata.pendingState = None txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId - txnMetadata.lastProducerId = capturedTxnTransitMetadata.getValue.lastProducerId + txnMetadata.previousProducerId = capturedTxnTransitMetadata.getValue.prevProducerId txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch }) @@ -995,16 +1164,20 @@ class TransactionCoordinatorTest { @Test def shouldAbortExpiredTransactionsInOngoingStateAndBumpEpoch(): Unit = { val now = time.milliseconds() - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - val expectedTransition = TxnTransitMetadata(producerId, producerId, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs, PrepareAbort, partitions.toSet, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT) + // Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0. + val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.toSet, now, + now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0) + + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) when(transactionManager.appendTransactionToLog(ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), @@ -1030,20 +1203,22 @@ class TransactionCoordinatorTest { @Test def shouldNotAcceptSmallerEpochDuringTransactionExpiration(): Unit = { val now = time.milliseconds() - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) - val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 2).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) + + val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 2).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, bumpedTxnMetadata)))) - def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = { + def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors, producerId: Long, producerEpoch: Short): Unit = { assertEquals(Errors.PRODUCER_FENCED, error) } coordinator.abortTimedOutTransactions(checkOnEndTransactionComplete) @@ -1054,9 +1229,9 @@ class TransactionCoordinatorTest { @Test def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = { - val metadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) - metadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds()) + val metadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + metadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds()) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) @@ -1073,22 +1248,26 @@ class TransactionCoordinatorTest { @Test def shouldNotBumpEpochWhenAbortingExpiredTransactionIfAppendToLogFails(): Unit = { val now = time.milliseconds() - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) - val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId, + RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadataAfterAppendFailure)))) + // Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0. val bumpedEpoch = (producerEpoch + 1).toShort - val expectedTransition = TxnTransitMetadata(producerId, producerId, bumpedEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, - PrepareAbort, partitions.toSet, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT) + val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, bumpedEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.toSet, now, + now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0) + + when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) when(transactionManager.appendTransactionToLog(ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), @@ -1117,9 +1296,9 @@ class TransactionCoordinatorTest { @Test def shouldNotBumpEpochWithPendingTransaction(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) - txnMetadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds()) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) @@ -1146,9 +1325,9 @@ class TransactionCoordinatorTest { def testDescribeTransactionsWithExpiringTransactionalId(): Unit = { coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false) - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Dead, mutable.Set.empty, time.milliseconds(), - time.milliseconds()) + time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -1172,8 +1351,8 @@ class TransactionCoordinatorTest { @Test def testDescribeTransactions(): Unit = { - val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -1199,7 +1378,9 @@ class TransactionCoordinatorTest { when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) - val metadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0) + // Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort. + val metadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0, TV_2) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))) @@ -1208,14 +1389,16 @@ class TransactionCoordinatorTest { assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result) } - private def validateIncrementEpochAndUpdateMetadata(state: TransactionState): Unit = { + private def validateIncrementEpochAndUpdateMetadata(state: TransactionState, transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(pidGenerator.generateProducerId()) .thenReturn(Success(producerId)) when(transactionManager.validateTransactionTimeoutMs(anyInt())) .thenReturn(true) - val metadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds()) + val metadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, + producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))) @@ -1242,13 +1425,13 @@ class TransactionCoordinatorTest { assertEquals(producerId, metadata.producerId) } - private def mockPrepare(transactionState: TransactionState, runCallback: Boolean = false): TransactionMetadata = { + private def mockPrepare(transactionState: TransactionState, clientTransactionVersion: TransactionVersion, runCallback: Boolean = false): TransactionMetadata = { val now = time.milliseconds() - val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs, Ongoing, partitions, now, now) + val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, + producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) - val transition = TxnTransitMetadata(producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, - transactionState, partitions.toSet, now, now) + val transition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions.toSet, now, now, clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata)))) @@ -1264,8 +1447,8 @@ class TransactionCoordinatorTest { capturedErrorsCallback.getValue.apply(Errors.NONE) }) - new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds()) + new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds(), clientTransactionVersion) } def initProducerIdMockCallback(ret: InitProducerIdResult): Unit = { @@ -1275,4 +1458,17 @@ class TransactionCoordinatorTest { def errorsCallback(ret: Errors): Unit = { error = ret } + + def endTxnCallback(ret: Errors, producerId: Long, epoch: Short): Unit = { + error = ret + newProducerId = producerId + newEpoch = epoch + } + + def requestEpoch(clientTransactionVersion: TransactionVersion): Short = { + if (clientTransactionVersion.supportsEpochBump()) + (producerEpoch - 1).toShort + else + producerEpoch + } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala index 3fdf42c6b33..fd5f1e37a65 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala @@ -23,8 +23,9 @@ import org.apache.kafka.common.compress.Compression import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type} -import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord} +import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, SimpleRecord} import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue} +import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue} import org.junit.jupiter.api.Test @@ -48,10 +49,11 @@ class TransactionLogTest { val transactionalId = "transactionalId" val producerId = 23423L - val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs, 0) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) txnMetadata.addPartitions(topicPartitions) - assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)) + assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)) } @Test @@ -72,14 +74,14 @@ class TransactionLogTest { // generate transaction log messages val txnRecords = pidMappings.map { case (transactionalId, producerId) => - val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs, - transactionStates(producerId), 0) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) if (!txnMetadata.state.equals(Empty)) txnMetadata.addPartitions(topicPartitions) val keyBytes = TransactionLog.keyToBytes(transactionalId) - val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true) + val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2) new SimpleRecord(keyBytes, valueBytes) }.toSeq @@ -114,12 +116,12 @@ class TransactionLogTest { val producerId = 1334L val topicPartition = new TopicPartition("topic", 0) - val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, - transactionTimeoutMs, Ongoing, 0) + val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) txnMetadata.addPartitions(Set(topicPartition)) val keyBytes = TransactionLog.keyToBytes(transactionalId) - val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true) + val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2) val transactionMetadataRecord = TestUtils.records(Seq( new SimpleRecord(keyBytes, valueBytes) )).records.asScala.head @@ -144,15 +146,15 @@ class TransactionLogTest { @Test def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = { - val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500) - val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, false)) + val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500, TV_0) + val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0)) assertEquals(0, txnLogValueBuffer.getShort) } @Test def testSerializeTransactionLogValueToFlexibleVersion(): Unit = { - val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500) - val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, true)) + val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500, TV_2) + val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2)) assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort) } @@ -194,8 +196,8 @@ class TransactionLogTest { new Field("topic", Type.COMPACT_STRING, ""), new Field("partition_ids", new CompactArrayOf(Type.INT32), ""), TaggedFieldsSection.of( - Int.box(0), new Field("partition_foo", Type.STRING, ""), - Int.box(1), new Field("partition_foo", Type.INT32, "") + Int.box(100), new Field("partition_foo", Type.STRING, ""), + Int.box(101), new Field("partition_foo", Type.INT32, "") ) ) @@ -204,8 +206,8 @@ class TransactionLogTest { txnPartitions.set("topic", "topic") txnPartitions.set("partition_ids", Array(Integer.valueOf(1))) val txnPartitionsTaggedFields = new java.util.TreeMap[Integer, Any]() - txnPartitionsTaggedFields.put(0, "foo") - txnPartitionsTaggedFields.put(1, 4000) + txnPartitionsTaggedFields.put(100, "foo") + txnPartitionsTaggedFields.put(101, 4000) txnPartitions.set("_tagged_fields", txnPartitionsTaggedFields) // Copy of TransactionLogValue.SCHEMA_1 with a few @@ -219,8 +221,8 @@ class TransactionLogTest { new Field("transaction_last_update_timestamp_ms", Type.INT64, ""), new Field("transaction_start_timestamp_ms", Type.INT64, ""), TaggedFieldsSection.of( - Int.box(0), new Field("txn_foo", Type.STRING, ""), - Int.box(1), new Field("txn_bar", Type.INT32, "") + Int.box(100), new Field("txn_foo", Type.STRING, ""), + Int.box(101), new Field("txn_bar", Type.INT32, "") ) ) @@ -234,8 +236,8 @@ class TransactionLogTest { transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L) transactionLogValue.set("transaction_start_timestamp_ms", 3000L) val txnLogValueTaggedFields = new java.util.TreeMap[Integer, Any]() - txnLogValueTaggedFields.put(0, "foo") - txnLogValueTaggedFields.put(1, 4000) + txnLogValueTaggedFields.put(100, "foo") + txnLogValueTaggedFields.put(101, 4000) transactionLogValue.set("_tagged_fields", txnLogValueTaggedFields) // Prepare the buffer. @@ -249,8 +251,8 @@ class TransactionLogTest { // fields were read but ignored. buffer.getShort() // Skip version. val value = new TransactionLogValue(new ByteBufferAccessor(buffer), 1.toShort) - assertEquals(Seq(0, 1), value.unknownTaggedFields().asScala.map(_.tag)) - assertEquals(Seq(0, 1), value.transactionPartitions().get(0).unknownTaggedFields().asScala.map(_.tag)) + assertEquals(Seq(100, 101), value.unknownTaggedFields().asScala.map(_.tag)) + assertEquals(Seq(100, 101), value.transactionPartitions().get(0).unknownTaggedFields().asScala.map(_.tag)) // Read the buffer with readTxnRecordValue. buffer.rewind() diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala index d40932f3226..3f1c4c67a06 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala @@ -28,7 +28,7 @@ import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} import org.apache.kafka.common.utils.MockTime import org.apache.kafka.common.{Node, TopicPartition} -import org.apache.kafka.server.common.MetadataVersion +import org.apache.kafka.server.common.{MetadataVersion, TransactionVersion} import org.apache.kafka.server.metrics.{KafkaMetricsGroup, KafkaYammerMetrics} import org.apache.kafka.server.util.RequestAndCompletionHandler import org.junit.jupiter.api.Assertions._ @@ -63,10 +63,10 @@ class TransactionMarkerChannelManagerTest { private val coordinatorEpoch2 = 1 private val txnTimeoutMs = 0 private val txnResult = TransactionResult.COMMIT - private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, producerEpoch, lastProducerEpoch, - txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L) - private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, producerEpoch, lastProducerEpoch, - txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L) + private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, RecordBatch.NO_PRODUCER_ID, + producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L, TransactionVersion.TV_2) + private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, RecordBatch.NO_PRODUCER_ID, + producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L, TransactionVersion.TV_2) private val capturedErrorsCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit]) private val time = new MockTime diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala index 1004915f46c..72ffa5629c0 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala @@ -23,6 +23,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.{ApiKeys, Errors} import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} +import org.apache.kafka.server.common.TransactionVersion import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test import org.mockito.ArgumentMatchers @@ -42,8 +43,8 @@ class TransactionMarkerRequestCompletionHandlerTest { private val coordinatorEpoch = 0 private val txnResult = TransactionResult.COMMIT private val topicPartition = new TopicPartition("topic1", 0) - private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, lastProducerEpoch, - txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L) + private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, + producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L, TransactionVersion.TV_2) private val pendingCompleteTxnAndMarkers = asList( PendingCompleteTxnAndMarkerEntry( PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, txnMetadata.prepareComplete(42)), diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala index 5d40c90a4b3..4c56e639d34 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala @@ -19,9 +19,13 @@ package kafka.coordinator.transaction import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.server.common.TransactionVersion +import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.apache.kafka.server.util.MockTime import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource import scala.collection.mutable @@ -38,13 +42,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None) txnMetadata.completeTransitionTo(transitMetadata) @@ -60,13 +66,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None) txnMetadata.completeTransitionTo(transitMetadata) @@ -82,13 +90,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareIncrementProducerEpoch(30000, @@ -101,14 +111,16 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) // let new time be smaller val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Option(producerEpoch), @@ -127,14 +139,16 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) // let new time be smaller val transitMetadata = txnMetadata.prepareProducerIdRotation(producerId + 1, 30000, time.milliseconds() - 1, recordLastEpoch = true) @@ -152,14 +166,16 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, txnStartTimestamp = time.milliseconds(), - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) // let new time be smaller; when transiting from Empty the start time would be updated to the update-time var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1) @@ -188,17 +204,19 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Ongoing, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) // let new time be smaller - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds() - 1) + val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(PrepareCommit, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) @@ -214,17 +232,19 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Ongoing, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) // let new time be smaller - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds() - 1) + val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(PrepareAbort, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) @@ -234,53 +254,65 @@ class TransactionMetadataTest { assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) } - @Test - def testTolerateTimeShiftDuringCompleteCommit(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def testTolerateTimeShiftDuringCompleteCommit(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) val producerEpoch: Short = 1 val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = PrepareCommit, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = clientTransactionVersion) // let new time be smaller val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1) txnMetadata.completeTransitionTo(transitMetadata) + + val lastEpoch = if (clientTransactionVersion.supportsEpochBump()) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH assertEquals(CompleteCommit, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) - assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + assertEquals(lastEpoch, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch) assertEquals(1L, txnMetadata.txnStartTimestamp) assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) } - @Test - def testTolerateTimeShiftDuringCompleteAbort(): Unit = { + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def testTolerateTimeShiftDuringCompleteAbort(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) val producerEpoch: Short = 1 val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = PrepareAbort, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = clientTransactionVersion) // let new time be smaller val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1) txnMetadata.completeTransitionTo(transitMetadata) + + val lastEpoch = if (clientTransactionVersion.supportsEpochBump()) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH assertEquals(CompleteAbort, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) - assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + assertEquals(lastEpoch, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch) assertEquals(1L, txnMetadata.txnStartTimestamp) assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) @@ -293,13 +325,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Ongoing, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch() @@ -310,7 +344,7 @@ class TransactionMetadataTest { // We should reset the pending state to make way for the abort transition. txnMetadata.pendingState = None - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds()) + val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds()) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, transitMetadata.producerId) } @@ -322,13 +356,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Ongoing, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch()) } @@ -340,36 +376,108 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val newProducerId = 9893L val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = true) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(newProducerId, txnMetadata.producerId) - assertEquals(producerId, txnMetadata.lastProducerId) + assertEquals(producerId, txnMetadata.previousProducerId) assertEquals(0, txnMetadata.producerEpoch) assertEquals(producerEpoch, txnMetadata.lastProducerEpoch) } + @Test + def testEpochBumpOnEndTxn(): Unit = { + time.sleep(100) + val producerEpoch = 10.toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Ongoing, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = time.milliseconds(), + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) + + var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch) + assertEquals(TV_2, txnMetadata.clientTransactionVersion) + + transitMetadata = txnMetadata.prepareComplete(time.milliseconds()) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch) + assertEquals(TV_2, txnMetadata.clientTransactionVersion) + } + + @Test + def testEpochBumpOnEndTxnOverflow(): Unit = { + time.sleep(100) + val producerEpoch = (Short.MaxValue - 1).toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Ongoing, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = time.milliseconds(), + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) + assertTrue(txnMetadata.isProducerEpochExhausted) + + val newProducerId = 9893L + var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, newProducerId, time.milliseconds() - 1) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(Short.MaxValue, txnMetadata.producerEpoch) + assertEquals(TV_2, txnMetadata.clientTransactionVersion) + + transitMetadata = txnMetadata.prepareComplete(time.milliseconds()) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(newProducerId, txnMetadata.producerId) + assertEquals(0, txnMetadata.producerEpoch) + assertEquals(TV_2, txnMetadata.clientTransactionVersion) + } + @Test def testRotateProducerIdInOngoingState(): Unit = { - assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing)) + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing, TV_0)) } - @Test - def testRotateProducerIdInPrepareAbortState(): Unit = { - assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort)) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def testRotateProducerIdInPrepareAbortState(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort, clientTransactionVersion)) } - @Test - def testRotateProducerIdInPrepareCommitState(): Unit = { - assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareCommit)) + @ParameterizedTest + @ValueSource(shorts = Array(0, 2)) + def testRotateProducerIdInPrepareCommitState(transactionVersion: Short): Unit = { + val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareCommit, clientTransactionVersion)) } @Test @@ -379,13 +487,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch)) txnMetadata.completeTransitionTo(transitMetadata) @@ -401,13 +511,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch)) txnMetadata.completeTransitionTo(transitMetadata) @@ -424,13 +536,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = RecordBatch.NO_PRODUCER_ID, + previousProducerId = RecordBatch.NO_PRODUCER_ID, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = lastProducerEpoch, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(lastProducerEpoch)) txnMetadata.completeTransitionTo(transitMetadata) @@ -447,13 +561,15 @@ class TransactionMetadataTest { val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = producerId, + previousProducerId = producerId, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = lastProducerEpoch, txnTimeoutMs = 30000, state = Empty, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = TV_0) val result = txnMetadata.prepareIncrementProducerEpoch(30000, Some((lastProducerEpoch - 1).toShort), time.milliseconds()) @@ -503,19 +619,21 @@ class TransactionMetadataTest { assertEquals(Set.empty, unmatchedStates) } - private def testRotateProducerIdInOngoingState(state: TransactionState): Unit = { + private def testRotateProducerIdInOngoingState(state: TransactionState, clientTransactionVersion: TransactionVersion): Unit = { val producerEpoch = (Short.MaxValue - 1).toShort val txnMetadata = new TransactionMetadata( transactionalId = transactionalId, producerId = producerId, - lastProducerId = producerId, + previousProducerId = producerId, + nextProducerId = RecordBatch.NO_PRODUCER_ID, producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, state = state, topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds()) + txnLastUpdateTimestamp = time.milliseconds(), + clientTransactionVersion = clientTransactionVersion) val newProducerId = 9893L txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = false) } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala index acaffea536f..a9783684c13 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala @@ -35,6 +35,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests.TransactionResult import org.apache.kafka.common.utils.MockTime import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion} +import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.apache.kafka.coordinator.transaction.generated.TransactionLogKey import org.apache.kafka.server.util.MockScheduler import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchDataInfo, FetchIsolation, LogConfig, LogOffsetMetadata} @@ -181,7 +182,7 @@ class TransactionStateManagerTest { new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) val records = MemoryRecords.withRecords(startOffset, Compression.NONE, - new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true))) + new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))) // We create a latch which is awaited while the log is loading. This ensures that the deletion // is triggered before the loading returns @@ -225,19 +226,19 @@ class TransactionStateManagerTest { txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid1's transaction adds three more partitions txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0), new TopicPartition("topic2", 1), new TopicPartition("topic2", 2))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid1's transaction is preparing to commit txnMetadata1.state = PrepareCommit - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid2's transaction started with three partitions txnMetadata2.state = Ongoing @@ -245,23 +246,23 @@ class TransactionStateManagerTest { new TopicPartition("topic3", 1), new TopicPartition("topic3", 2))) - txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's transaction is preparing to abort txnMetadata2.state = PrepareAbort - txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's transaction has aborted txnMetadata2.state = CompleteAbort - txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's epoch has advanced, with no ongoing transaction yet txnMetadata2.state = Empty txnMetadata2.topicPartitions.clear() - txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) val startOffset = 15L // it should work for any start offset val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) @@ -796,14 +797,9 @@ class TransactionStateManagerTest { // write the change. If the write fails (e.g. under min isr), the TransactionMetadata // is left at it is. If the transactional id is never reused, the TransactionMetadata // will be expired and it should succeed. - val txnMetadata = TransactionMetadata( - transactionalId = transactionalId, - producerId = 1, - producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = transactionTimeoutMs, - state = Empty, - timestamp = time.milliseconds() - ) + val timestamp = time.milliseconds() + val txnMetadata = new TransactionMetadata(transactionalId, 1, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0) transactionManager.putTransactionStateIfNotExists(txnMetadata) time.sleep(txnConfig.transactionalIdExpirationMs + 1) @@ -890,7 +886,7 @@ class TransactionStateManagerTest { txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) val startOffset = 0L val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) @@ -1053,7 +1049,7 @@ class TransactionStateManagerTest { txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) val startOffset = 0L val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) @@ -1081,7 +1077,9 @@ class TransactionStateManagerTest { producerId: Long, state: TransactionState = Empty, txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = { - TransactionMetadata(transactionalId, producerId, 0.toShort, txnTimeout, state, time.milliseconds()) + val timestamp = time.milliseconds() + new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, 0.toShort, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeout, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0) } private def prepareTxnLog(topicPartition: TopicPartition, @@ -1159,7 +1157,7 @@ class TransactionStateManagerTest { txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 1), new TopicPartition("topic1", 1))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) val startOffset = 15L val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) @@ -1178,7 +1176,7 @@ class TransactionStateManagerTest { txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) - txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) val startOffset = 0L val unknownKey = new TransactionLogKey() @@ -1199,7 +1197,7 @@ class TransactionStateManagerTest { val txnMetadata = txnMetadataPool.get(transactionalId1) assertEquals(txnMetadata1.transactionalId, txnMetadata.transactionalId) assertEquals(txnMetadata1.producerId, txnMetadata.producerId) - assertEquals(txnMetadata1.lastProducerId, txnMetadata.lastProducerId) + assertEquals(txnMetadata1.previousProducerId, txnMetadata.previousProducerId) assertEquals(txnMetadata1.producerEpoch, txnMetadata.producerEpoch) assertEquals(txnMetadata1.lastProducerEpoch, txnMetadata.lastProducerEpoch) assertEquals(txnMetadata1.txnTimeoutMs, txnMetadata.txnTimeoutMs) @@ -1210,7 +1208,7 @@ class TransactionStateManagerTest { @ParameterizedTest @EnumSource(classOf[TransactionVersion]) - def testUsesFlexibleRecords(transactionVersion: TransactionVersion): Unit = { + def testTransactionVersionInTransactionManager(transactionVersion: TransactionVersion): Unit = { val metadataCache = mock(classOf[MetadataCache]) when(metadataCache.features()).thenReturn { new FinalizedFeatures( @@ -1223,7 +1221,6 @@ class TransactionStateManagerTest { val transactionManager = new TransactionStateManager(0, scheduler, replicaManager, metadataCache, txnConfig, time, metrics) - val expectFlexibleRecords = transactionVersion.featureLevel > 0 - assertEquals(expectFlexibleRecords, transactionManager.usesFlexibleRecords()) + assertEquals(transactionVersion, transactionManager.transactionVersionLevel()) } } diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala index 7972e9c5355..371bc0384a7 100644 --- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala @@ -85,7 +85,7 @@ import org.apache.kafka.security.authorizer.AclEntry import org.apache.kafka.server.ClientMetricsManager import org.apache.kafka.server.authorizer.{Action, AuthorizationResult, Authorizer} import org.apache.kafka.server.common.MetadataVersion.{IBP_0_10_2_IV0, IBP_2_2_IV1} -import org.apache.kafka.server.common.{FeatureVersion, FinalizedFeatures, GroupVersion, KRaftVersion, MetadataVersion, RequestLocal} +import org.apache.kafka.server.common.{FeatureVersion, FinalizedFeatures, GroupVersion, KRaftVersion, MetadataVersion, RequestLocal, TransactionVersion} import org.apache.kafka.server.config.{ConfigType, KRaftConfigs, ReplicationConfigs, ServerConfigs, ServerLogConfigs, ShareGroupConfig} import org.apache.kafka.server.metrics.ClientMetricsTestUtils import org.apache.kafka.server.share.{CachedSharePartition, ErroneousAndValidPartitionData} @@ -2572,7 +2572,7 @@ class KafkaApisTest extends Logging { reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) val capturedResponse: ArgumentCaptor[EndTxnResponse] = ArgumentCaptor.forClass(classOf[EndTxnResponse]) - val responseCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit]) + val responseCallback: ArgumentCaptor[(Errors, Long, Short) => Unit] = ArgumentCaptor.forClass(classOf[(Errors, Long, Short) => Unit]) val transactionalId = "txnId" val producerId = 15L @@ -2587,15 +2587,18 @@ class KafkaApisTest extends Logging { ).build(version.toShort) val request = buildRequest(endTxnRequest) + val clientTransactionVersion = if (version > 4) TransactionVersion.TV_2 else TransactionVersion.TV_0 + val requestLocal = RequestLocal.withThreadConfinedCaching when(txnCoordinator.handleEndTransaction( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(producerId), ArgumentMatchers.eq(epoch), ArgumentMatchers.eq(TransactionResult.COMMIT), + ArgumentMatchers.eq(clientTransactionVersion), responseCallback.capture(), ArgumentMatchers.eq(requestLocal) - )).thenAnswer(_ => responseCallback.getValue.apply(Errors.PRODUCER_FENCED)) + )).thenAnswer(_ => responseCallback.getValue.apply(Errors.PRODUCER_FENCED, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH)) val kafkaApis = createKafkaApis() try { kafkaApis.handleEndTxnRequest(request, requestLocal) diff --git a/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java b/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java index 64163d5a0b4..36dadb5cf11 100644 --- a/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java +++ b/server-common/src/main/java/org/apache/kafka/server/common/TransactionVersion.java @@ -49,6 +49,10 @@ public enum TransactionVersion implements FeatureVersion { return featureLevel; } + public static TransactionVersion fromFeatureLevel(short version) { + return (TransactionVersion) Features.TRANSACTION_VERSION.fromFeatureLevel(version, true); + } + @Override public String featureName() { return FEATURE_NAME; @@ -63,4 +67,14 @@ public enum TransactionVersion implements FeatureVersion { public Map dependencies() { return dependencies; } + + // Transactions V1 enables log version 0 (flexible fields) + public short transactionLogValueVersion() { + return (short) (featureLevel >= 1 ? 1 : 0); + } + + // Transactions V2 enables epoch bump on commit/abort. + public boolean supportsEpochBump() { + return featureLevel >= 2; + } } diff --git a/transaction-coordinator/src/main/resources/common/message/TransactionLogValue.json b/transaction-coordinator/src/main/resources/common/message/TransactionLogValue.json index c6efc772d58..c9a1cbf66e3 100644 --- a/transaction-coordinator/src/main/resources/common/message/TransactionLogValue.json +++ b/transaction-coordinator/src/main/resources/common/message/TransactionLogValue.json @@ -24,6 +24,10 @@ "fields": [ { "name": "ProducerId", "type": "int64", "versions": "0+", "about": "Producer id in use by the transactional id"}, + { "name": "PreviousProducerId", "type": "int64", "taggedVersions": "1+", "tag": 0, "default": -1, + "about": "Producer id used by the last committed transaction"}, + { "name": "NextProducerId", "type": "int64", "taggedVersions": "1+", "tag": 1, "default": -1, + "about": "Latest producer ID sent to the producer for the given transactional ID"}, { "name": "ProducerEpoch", "type": "int16", "versions": "0+", "about": "Epoch associated with the producer id"}, { "name": "TransactionTimeoutMs", "type": "int32", "versions": "0+", @@ -37,6 +41,8 @@ { "name": "TransactionLastUpdateTimestampMs", "type": "int64", "versions": "0+", "about": "Time the transaction was last updated"}, { "name": "TransactionStartTimestampMs", "type": "int64", "versions": "0+", - "about": "Time the transaction was started"} + "about": "Time the transaction was started"}, + { "name": "ClientTransactionVersion", "type": "int16", "default": 0, "taggedVersions": "1+", "tag": 2, + "about": "The transaction version used by the client"} ] }