KAFKA-14562 [1/2]: Implement epoch bump after every transaction (#16719)

Implement server side changes for epoch bump but keep EndTxn as an unstable API until the client side changes are implemented. EndTxnResponse will return the producer ID and epoch for the transaction. Introduces new tagged fields to the TransactionLogValue to persist the clientTransactionVersion, previousProducerId, and nextProducerId to the log so that the state can be reloaded. See KIP-890 for more details.

Small updates to naming of lastProducerId -> PreviousProducerId. Also cleans up the many TransactionMetadata constructors.

Reviewers: Artem Livshits <alivshits@confluent.io>, David Jacot <djacot@confluent.io>
This commit is contained in:
Justine Olshan 2024-09-26 09:37:11 -07:00 committed by GitHub
parent 3a1465e14c
commit ede0c94aaa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 850 additions and 380 deletions

View File

@ -32,7 +32,11 @@ public class EndTxnRequest extends AbstractRequest {
public final EndTxnRequestData data; public final EndTxnRequestData data;
public Builder(EndTxnRequestData data) { public Builder(EndTxnRequestData data) {
super(ApiKeys.END_TXN); this(data, false);
}
public Builder(EndTxnRequestData data, boolean enableUnstableLastVersion) {
super(ApiKeys.END_TXN, enableUnstableLastVersion);
this.data = data; this.data = data;
} }

View File

@ -25,7 +25,10 @@
// Version 3 enables flexible versions. // Version 3 enables flexible versions.
// //
// Version 4 adds support for new error code TRANSACTION_ABORTABLE (KIP-890). // Version 4 adds support for new error code TRANSACTION_ABORTABLE (KIP-890).
"validVersions": "0-4", //
// Version 5 enables bumping epoch on every transaction (KIP-890 Part 2)
"latestVersionUnstable": true,
"validVersions": "0-5",
"flexibleVersions": "3+", "flexibleVersions": "3+",
"fields": [ "fields": [
{ "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId", { "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId",

View File

@ -24,12 +24,18 @@
// Version 3 enables flexible versions. // Version 3 enables flexible versions.
// //
// Version 4 adds support for new error code TRANSACTION_ABORTABLE (KIP-890). // Version 4 adds support for new error code TRANSACTION_ABORTABLE (KIP-890).
"validVersions": "0-4", //
// Version 5 enables bumping epoch on every transaction (KIP-890 Part 2), so producer ID and epoch are included in the response.
"validVersions": "0-5",
"flexibleVersions": "3+", "flexibleVersions": "3+",
"fields": [ "fields": [
{ "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+",
"about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." },
{ "name": "ErrorCode", "type": "int16", "versions": "0+", { "name": "ErrorCode", "type": "int16", "versions": "0+",
"about": "The error code, or 0 if there was no error." } "about": "The error code, or 0 if there was no error." },
{ "name": "ProducerId", "type": "int64", "versions": "5+", "entityType": "producerId", "default": "-1", "ignorable": "true",
"about": "The producer ID." },
{ "name": "ProducerEpoch", "type": "int16", "versions": "5+", "default": "-1", "ignorable": "true",
"about": "The current epoch associated with the producer." }
] ]
} }

View File

@ -29,7 +29,7 @@ import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult} import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult}
import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time} import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time}
import org.apache.kafka.server.common.RequestLocal import org.apache.kafka.server.common.{RequestLocal, TransactionVersion}
import org.apache.kafka.server.util.Scheduler import org.apache.kafka.server.util.Scheduler
import scala.jdk.CollectionConverters._ import scala.jdk.CollectionConverters._
@ -98,7 +98,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
private type InitProducerIdCallback = InitProducerIdResult => Unit private type InitProducerIdCallback = InitProducerIdResult => Unit
private type AddPartitionsCallback = Errors => Unit private type AddPartitionsCallback = Errors => Unit
private type VerifyPartitionsCallback = AddPartitionsToTxnResult => Unit private type VerifyPartitionsCallback = AddPartitionsToTxnResult => Unit
private type EndTxnCallback = Errors => Unit private type EndTxnCallback = (Errors, Long, Short) => Unit
private type ApiResult[T] = Either[Errors, T] private type ApiResult[T] = Either[Errors, T]
/* Active flag of the coordinator */ /* Active flag of the coordinator */
@ -135,13 +135,15 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case Success(producerId) => case Success(producerId) =>
val createdMetadata = new TransactionMetadata(transactionalId = transactionalId, val createdMetadata = new TransactionMetadata(transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = transactionTimeoutMs, txnTimeoutMs = transactionTimeoutMs,
state = Empty, state = Empty,
topicPartitions = collection.mutable.Set.empty[TopicPartition], topicPartitions = collection.mutable.Set.empty[TopicPartition],
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TransactionVersion.TV_0)
txnManager.putTransactionStateIfNotExists(createdMetadata) txnManager.putTransactionStateIfNotExists(createdMetadata)
case Failure(exception) => case Failure(exception) =>
@ -169,7 +171,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case Right((coordinatorEpoch, newMetadata)) => case Right((coordinatorEpoch, newMetadata)) =>
if (newMetadata.txnState == PrepareEpochFence) { if (newMetadata.txnState == PrepareEpochFence) {
// abort the ongoing transaction and then return CONCURRENT_TRANSACTIONS to let client wait and retry // abort the ongoing transaction and then return CONCURRENT_TRANSACTIONS to let client wait and retry
def sendRetriableErrorCallback(error: Errors): Unit = { def sendRetriableErrorCallback(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = {
if (error != Errors.NONE) { if (error != Errors.NONE) {
responseCallback(initTransactionError(error)) responseCallback(initTransactionError(error))
} else { } else {
@ -182,6 +184,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
newMetadata.producerEpoch, newMetadata.producerEpoch,
TransactionResult.ABORT, TransactionResult.ABORT,
isFromClient = false, isFromClient = false,
clientTransactionVersion = txnManager.transactionVersionLevel(), // Since this is not from client, use server TV
sendRetriableErrorCallback, sendRetriableErrorCallback,
requestLocal) requestLocal)
} else { } else {
@ -221,7 +224,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
// could be a retry after a valid epoch bump that the producer never received the response for // could be a retry after a valid epoch bump that the producer never received the response for
txnMetadata.producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || txnMetadata.producerEpoch == RecordBatch.NO_PRODUCER_EPOCH ||
producerIdAndEpoch.producerId == txnMetadata.producerId || producerIdAndEpoch.producerId == txnMetadata.producerId ||
(producerIdAndEpoch.producerId == txnMetadata.lastProducerId && TransactionMetadata.isEpochExhausted(producerIdAndEpoch.epoch)) (producerIdAndEpoch.producerId == txnMetadata.previousProducerId && TransactionMetadata.isEpochExhausted(producerIdAndEpoch.epoch))
} }
if (txnMetadata.pendingTransitionInProgress) { if (txnMetadata.pendingTransitionInProgress) {
@ -487,6 +490,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
producerId: Long, producerId: Long,
producerEpoch: Short, producerEpoch: Short,
txnMarkerResult: TransactionResult, txnMarkerResult: TransactionResult,
clientTransactionVersion: TransactionVersion,
responseCallback: EndTxnCallback, responseCallback: EndTxnCallback,
requestLocal: RequestLocal = RequestLocal.noCaching): Unit = { requestLocal: RequestLocal = RequestLocal.noCaching): Unit = {
endTransaction(transactionalId, endTransaction(transactionalId,
@ -494,6 +498,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
producerEpoch, producerEpoch,
txnMarkerResult, txnMarkerResult,
isFromClient = true, isFromClient = true,
clientTransactionVersion,
responseCallback, responseCallback,
requestLocal) requestLocal)
} }
@ -503,12 +508,15 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
producerEpoch: Short, producerEpoch: Short,
txnMarkerResult: TransactionResult, txnMarkerResult: TransactionResult,
isFromClient: Boolean, isFromClient: Boolean,
clientTransactionVersion: TransactionVersion,
responseCallback: EndTxnCallback, responseCallback: EndTxnCallback,
requestLocal: RequestLocal): Unit = { requestLocal: RequestLocal): Unit = {
var isEpochFence = false var isEpochFence = false
if (transactionalId == null || transactionalId.isEmpty) if (transactionalId == null || transactionalId.isEmpty)
responseCallback(Errors.INVALID_REQUEST) responseCallback(Errors.INVALID_REQUEST, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH)
else { else {
var producerIdCopy = RecordBatch.NO_PRODUCER_ID
var producerEpochCopy = RecordBatch.NO_PRODUCER_EPOCH
val preAppendResult: ApiResult[(Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId).flatMap { val preAppendResult: ApiResult[(Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId).flatMap {
case None => case None =>
Left(Errors.INVALID_PRODUCER_ID_MAPPING) Left(Errors.INVALID_PRODUCER_ID_MAPPING)
@ -518,10 +526,39 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch
txnMetadata.inLock { txnMetadata.inLock {
if (txnMetadata.producerId != producerId) producerIdCopy = txnMetadata.producerId
producerEpochCopy = txnMetadata.producerEpoch
// PrepareEpochFence has slightly different epoch bumping logic so don't include it here.
val currentTxnMetadataIsAtLeastTransactionsV2 = !txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.clientTransactionVersion.supportsEpochBump()
// True if the client used TV_2 and retried a request that had overflowed the epoch, and a new producer ID is stored in the txnMetadata
val retryOnOverflow = currentTxnMetadataIsAtLeastTransactionsV2 &&
txnMetadata.previousProducerId == producerId && producerEpoch == Short.MaxValue - 1 && txnMetadata.producerEpoch == 0
// True if the client used TV_2 and retried an endTxn request, and the bumped producer epoch is stored in the txnMetadata.
val retryOnEpochBump = endTxnEpochBumped(txnMetadata, producerEpoch)
val isValidEpoch = {
if (currentTxnMetadataIsAtLeastTransactionsV2) {
// With transactions V2, state + same epoch is not sufficient to determine if a retry transition is valid. If the epoch is the
// same it actually indicates the next endTransaction call. Instead, we want to check the epoch matches with the epoch in the retry conditions.
// Return producer fenced even in the cases where the epoch is higher and could indicate an invalid state transition.
// Use the following criteria to determine if a v2 retry is valid:
txnMetadata.state match {
case Ongoing | Empty | Dead | PrepareEpochFence =>
producerEpoch == txnMetadata.producerEpoch
case PrepareCommit | PrepareAbort =>
retryOnEpochBump
case CompleteCommit | CompleteAbort =>
retryOnEpochBump || retryOnOverflow
}
} else {
// For transactions V1 strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch without server knowledge.
(!isFromClient || producerEpoch == txnMetadata.producerEpoch) && producerEpoch >= txnMetadata.producerEpoch
}
}
if (txnMetadata.producerId != producerId && !retryOnOverflow)
Left(Errors.INVALID_PRODUCER_ID_MAPPING) Left(Errors.INVALID_PRODUCER_ID_MAPPING)
// Strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch. else if (!isValidEpoch)
else if ((isFromClient && producerEpoch != txnMetadata.producerEpoch) || producerEpoch < txnMetadata.producerEpoch)
Left(Errors.PRODUCER_FENCED) Left(Errors.PRODUCER_FENCED)
else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence) else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence)
Left(Errors.CONCURRENT_TRANSACTIONS) Left(Errors.CONCURRENT_TRANSACTIONS)
@ -532,6 +569,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
else else
PrepareAbort PrepareAbort
// Maybe allocate new producer ID if we are bumping epoch and epoch is exhausted
val nextProducerIdOrErrors =
if (clientTransactionVersion.supportsEpochBump() && !txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.isProducerEpochExhausted) {
producerIdManager.generateProducerId() match {
case Success(newProducerId) =>
Right(newProducerId)
case Failure(exception) =>
Left(Errors.forException(exception))
}
} else {
Right(RecordBatch.NO_PRODUCER_ID)
}
if (nextState == PrepareAbort && txnMetadata.pendingState.contains(PrepareEpochFence)) { if (nextState == PrepareAbort && txnMetadata.pendingState.contains(PrepareEpochFence)) {
// We should clear the pending state to make way for the transition to PrepareAbort and also bump // We should clear the pending state to make way for the transition to PrepareAbort and also bump
// the epoch in the transaction metadata we are about to append. // the epoch in the transaction metadata we are about to append.
@ -541,7 +591,10 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
txnMetadata.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH txnMetadata.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH
} }
Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, time.milliseconds())) nextProducerIdOrErrors.flatMap {
nextProducerId =>
Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, clientTransactionVersion, nextProducerId, time.milliseconds()))
}
case CompleteCommit => case CompleteCommit =>
if (txnMarkerResult == TransactionResult.COMMIT) if (txnMarkerResult == TransactionResult.COMMIT)
Left(Errors.NONE) Left(Errors.NONE)
@ -576,8 +629,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
preAppendResult match { preAppendResult match {
case Left(err) => case Left(err) =>
if (err == Errors.NONE) {
responseCallback(err, producerIdCopy, producerEpochCopy)
} else {
debug(s"Aborting append of $txnMarkerResult to transaction log with coordinator and returning $err error to client for $transactionalId's EndTransaction request") debug(s"Aborting append of $txnMarkerResult to transaction log with coordinator and returning $err error to client for $transactionalId's EndTransaction request")
responseCallback(err) responseCallback(err, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH)
}
case Right((coordinatorEpoch, newMetadata)) => case Right((coordinatorEpoch, newMetadata)) =>
def sendTxnMarkersCallback(error: Errors): Unit = { def sendTxnMarkersCallback(error: Errors): Unit = {
@ -595,7 +652,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
txnMetadata.inLock { txnMetadata.inLock {
if (txnMetadata.producerId != producerId) if (txnMetadata.producerId != producerId)
Left(Errors.INVALID_PRODUCER_ID_MAPPING) Left(Errors.INVALID_PRODUCER_ID_MAPPING)
else if (txnMetadata.producerEpoch != producerEpoch) else if (txnMetadata.producerEpoch != producerEpoch && !endTxnEpochBumped(txnMetadata, producerEpoch))
Left(Errors.PRODUCER_FENCED) Left(Errors.PRODUCER_FENCED)
else if (txnMetadata.pendingTransitionInProgress) else if (txnMetadata.pendingTransitionInProgress)
Left(Errors.CONCURRENT_TRANSACTIONS) Left(Errors.CONCURRENT_TRANSACTIONS)
@ -630,12 +687,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
preSendResult match { preSendResult match {
case Left(err) => case Left(err) =>
info(s"Aborting sending of transaction markers after appended $txnMarkerResult to transaction log and returning $err error to client for $transactionalId's EndTransaction request") info(s"Aborting sending of transaction markers after appended $txnMarkerResult to transaction log and returning $err error to client for $transactionalId's EndTransaction request")
responseCallback(err) responseCallback(err, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH)
case Right((txnMetadata, newPreSendMetadata)) => case Right((txnMetadata, newPreSendMetadata)) =>
// we can respond to the client immediately and continue to write the txn markers if // we can respond to the client immediately and continue to write the txn markers if
// the log append was successful // the log append was successful
responseCallback(Errors.NONE) responseCallback(Errors.NONE, txnMetadata.producerId, txnMetadata.producerEpoch)
txnMarkerChannelManager.addTxnMarkersToSend(coordinatorEpoch, txnMarkerResult, txnMetadata, newPreSendMetadata) txnMarkerChannelManager.addTxnMarkersToSend(coordinatorEpoch, txnMarkerResult, txnMetadata, newPreSendMetadata)
} }
@ -659,7 +716,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
} }
} }
responseCallback(error) responseCallback(error, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH)
} }
} }
@ -669,11 +726,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
} }
} }
// When a client and server support V2, every endTransaction call bumps the producer epoch. When checking epoch, we want to
// check epoch + 1. Epoch bumps from PrepareEpochFence state are handled separately, so this method should not be used to check that case.
// Returns true if the transaction state epoch is the specified producer epoch + 1 and epoch bump on every transaction is expected.
private def endTxnEpochBumped(txnMetadata: TransactionMetadata, producerEpoch: Short): Boolean = {
!txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.clientTransactionVersion.supportsEpochBump() &&
txnMetadata.producerEpoch == producerEpoch + 1
}
def transactionTopicConfigs: Properties = txnManager.transactionTopicConfigs def transactionTopicConfigs: Properties = txnManager.transactionTopicConfigs
def partitionFor(transactionalId: String): Int = txnManager.partitionFor(transactionalId) def partitionFor(transactionalId: String): Int = txnManager.partitionFor(transactionalId)
private def onEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = { private def onEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = {
error match { error match {
case Errors.NONE => case Errors.NONE =>
info("Completed rollback of ongoing transaction for transactionalId " + info("Completed rollback of ongoing transaction for transactionalId " +
@ -721,6 +786,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
txnTransitMetadata.producerEpoch, txnTransitMetadata.producerEpoch,
TransactionResult.ABORT, TransactionResult.ABORT,
isFromClient = false, isFromClient = false,
clientTransactionVersion = txnManager.transactionVersionLevel(), // Since this is not from client, use server TV
onComplete(txnIdAndPidEpoch), onComplete(txnIdAndPidEpoch),
RequestLocal.noCaching) RequestLocal.noCaching)
} }

View File

@ -25,6 +25,7 @@ import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil}
import org.apache.kafka.common.record.{Record, RecordBatch} import org.apache.kafka.common.record.{Record, RecordBatch}
import org.apache.kafka.common.{MessageFormatter, TopicPartition} import org.apache.kafka.common.{MessageFormatter, TopicPartition}
import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue} import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue}
import org.apache.kafka.server.common.TransactionVersion
import scala.collection.mutable import scala.collection.mutable
import scala.jdk.CollectionConverters._ import scala.jdk.CollectionConverters._
@ -63,7 +64,7 @@ object TransactionLog {
* @return value payload bytes * @return value payload bytes
*/ */
private[transaction] def valueToBytes(txnMetadata: TxnTransitMetadata, private[transaction] def valueToBytes(txnMetadata: TxnTransitMetadata,
usesFlexibleRecords: Boolean): Array[Byte] = { transactionVersionLevel: TransactionVersion): Array[Byte] = {
if (txnMetadata.txnState == Empty && txnMetadata.topicPartitions.nonEmpty) if (txnMetadata.txnState == Empty && txnMetadata.topicPartitions.nonEmpty)
throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata") throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata")
@ -78,9 +79,7 @@ object TransactionLog {
// Serialize with version 0 (highest non-flexible version) until transaction.version 1 is enabled // Serialize with version 0 (highest non-flexible version) until transaction.version 1 is enabled
// which enables flexible fields in records. // which enables flexible fields in records.
val version: Short = MessageUtil.toVersionPrefixedBytes(transactionVersionLevel.transactionLogValueVersion(),
if (usesFlexibleRecords) 1 else 0
MessageUtil.toVersionPrefixedBytes(version,
new TransactionLogValue() new TransactionLogValue()
.setProducerId(txnMetadata.producerId) .setProducerId(txnMetadata.producerId)
.setProducerEpoch(txnMetadata.producerEpoch) .setProducerEpoch(txnMetadata.producerEpoch)
@ -88,7 +87,8 @@ object TransactionLog {
.setTransactionStatus(txnMetadata.txnState.id) .setTransactionStatus(txnMetadata.txnState.id)
.setTransactionLastUpdateTimestampMs(txnMetadata.txnLastUpdateTimestamp) .setTransactionLastUpdateTimestampMs(txnMetadata.txnLastUpdateTimestamp)
.setTransactionStartTimestampMs(txnMetadata.txnStartTimestamp) .setTransactionStartTimestampMs(txnMetadata.txnStartTimestamp)
.setTransactionPartitions(transactionPartitions)) .setTransactionPartitions(transactionPartitions)
.setClientTransactionVersion(txnMetadata.clientTransactionVersion.featureLevel()))
} }
/** /**
@ -124,14 +124,16 @@ object TransactionLog {
val transactionMetadata = new TransactionMetadata( val transactionMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = value.producerId, producerId = value.producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = value.previousProducerId,
nextProducerId = value.nextProducerId,
producerEpoch = value.producerEpoch, producerEpoch = value.producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = value.transactionTimeoutMs, txnTimeoutMs = value.transactionTimeoutMs,
state = TransactionState.fromId(value.transactionStatus), state = TransactionState.fromId(value.transactionStatus),
topicPartitions = mutable.Set.empty[TopicPartition], topicPartitions = mutable.Set.empty[TopicPartition],
txnStartTimestamp = value.transactionStartTimestampMs, txnStartTimestamp = value.transactionStartTimestampMs,
txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs) txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs,
clientTransactionVersion = TransactionVersion.fromFeatureLevel(value.clientTransactionVersion))
if (!transactionMetadata.state.equals(Empty)) if (!transactionMetadata.state.equals(Empty))
value.transactionPartitions.forEach(partitionsSchema => value.transactionPartitions.forEach(partitionsSchema =>

View File

@ -17,11 +17,11 @@
package kafka.coordinator.transaction package kafka.coordinator.transaction
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import kafka.utils.{CoreUtils, Logging, nonthreadsafe} import kafka.utils.{CoreUtils, Logging, nonthreadsafe}
import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.server.common.TransactionVersion
import scala.collection.{immutable, mutable} import scala.collection.{immutable, mutable}
@ -163,70 +163,64 @@ private[transaction] case object PrepareEpochFence extends TransactionState {
} }
private[transaction] object TransactionMetadata { private[transaction] object TransactionMetadata {
def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, timestamp: Long) =
new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int,
state: TransactionState, timestamp: Long) =
new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
def apply(transactionalId: String, producerId: Long, lastProducerId: Long, producerEpoch: Short,
lastProducerEpoch: Short, txnTimeoutMs: Int, state: TransactionState, timestamp: Long) =
new TransactionMetadata(transactionalId, producerId, lastProducerId, producerEpoch, lastProducerEpoch,
txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1 def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1
} }
// this is a immutable object representing the target transition of the transaction metadata // this is a immutable object representing the target transition of the transaction metadata
private[transaction] case class TxnTransitMetadata(producerId: Long, private[transaction] case class TxnTransitMetadata(producerId: Long,
lastProducerId: Long, prevProducerId: Long,
nextProducerId: Long,
producerEpoch: Short, producerEpoch: Short,
lastProducerEpoch: Short, lastProducerEpoch: Short,
txnTimeoutMs: Int, txnTimeoutMs: Int,
txnState: TransactionState, txnState: TransactionState,
topicPartitions: immutable.Set[TopicPartition], topicPartitions: immutable.Set[TopicPartition],
txnStartTimestamp: Long, txnStartTimestamp: Long,
txnLastUpdateTimestamp: Long) { txnLastUpdateTimestamp: Long,
clientTransactionVersion: TransactionVersion) {
override def toString: String = { override def toString: String = {
"TxnTransitMetadata(" + "TxnTransitMetadata(" +
s"producerId=$producerId, " + s"producerId=$producerId, " +
s"lastProducerId=$lastProducerId, " + s"previousProducerId=$prevProducerId, " +
s"nextProducerId=$nextProducerId, " +
s"producerEpoch=$producerEpoch, " + s"producerEpoch=$producerEpoch, " +
s"lastProducerEpoch=$lastProducerEpoch, " + s"lastProducerEpoch=$lastProducerEpoch, " +
s"txnTimeoutMs=$txnTimeoutMs, " + s"txnTimeoutMs=$txnTimeoutMs, " +
s"txnState=$txnState, " + s"txnState=$txnState, " +
s"topicPartitions=$topicPartitions, " + s"topicPartitions=$topicPartitions, " +
s"txnStartTimestamp=$txnStartTimestamp, " + s"txnStartTimestamp=$txnStartTimestamp, " +
s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp)" s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp, " +
s"clientTransactionVersion=$clientTransactionVersion)"
} }
} }
/** /**
* *
* @param producerId producer id * @param producerId producer id
* @param lastProducerId last producer id assigned to the producer * @param previousProducerId producer id for the last committed transaction with this transactional ID
* @param producerEpoch current epoch of the producer * @param nextProducerId Latest producer ID sent to the producer for the given transactional ID
* @param lastProducerEpoch last epoch of the producer * @param producerEpoch current epoch of the producer
* @param txnTimeoutMs timeout to be used to abort long running transactions * @param lastProducerEpoch last epoch of the producer
* @param state current state of the transaction * @param txnTimeoutMs timeout to be used to abort long running transactions
* @param topicPartitions current set of partitions that are part of this transaction * @param state current state of the transaction
* @param txnStartTimestamp time the transaction was started, i.e., when first partition is added * @param topicPartitions current set of partitions that are part of this transaction
* @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration * @param txnStartTimestamp time the transaction was started, i.e., when first partition is added
* @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration
* @param clientTransactionVersion TransactionVersion used by the client when the state was transitioned
*/ */
@nonthreadsafe @nonthreadsafe
private[transaction] class TransactionMetadata(val transactionalId: String, private[transaction] class TransactionMetadata(val transactionalId: String,
var producerId: Long, var producerId: Long,
var lastProducerId: Long, var previousProducerId: Long,
var nextProducerId: Long,
var producerEpoch: Short, var producerEpoch: Short,
var lastProducerEpoch: Short, var lastProducerEpoch: Short,
var txnTimeoutMs: Int, var txnTimeoutMs: Int,
var state: TransactionState, var state: TransactionState,
val topicPartitions: mutable.Set[TopicPartition], val topicPartitions: mutable.Set[TopicPartition],
@volatile var txnStartTimestamp: Long = -1, @volatile var txnStartTimestamp: Long = -1,
@volatile var txnLastUpdateTimestamp: Long) extends Logging { @volatile var txnLastUpdateTimestamp: Long,
var clientTransactionVersion: TransactionVersion) extends Logging {
// pending state is used to indicate the state that this transaction is going to // pending state is used to indicate the state that this transaction is going to
// transit to, and for blocking future attempts to transit it again if it is not legal; // transit to, and for blocking future attempts to transit it again if it is not legal;
@ -256,8 +250,8 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
// this is visible for test only // this is visible for test only
def prepareNoTransit(): TxnTransitMetadata = { def prepareNoTransit(): TxnTransitMetadata = {
// do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state // do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state
TxnTransitMetadata(producerId, lastProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.toSet, TxnTransitMetadata(producerId, previousProducerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.toSet,
txnStartTimestamp, txnLastUpdateTimestamp) txnStartTimestamp, txnLastUpdateTimestamp, TransactionVersion.TV_0)
} }
def prepareFenceProducerEpoch(): TxnTransitMetadata = { def prepareFenceProducerEpoch(): TxnTransitMetadata = {
@ -335,9 +329,16 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
(topicPartitions ++ addedTopicPartitions).toSet, newTxnStartTimestamp, updateTimestamp) (topicPartitions ++ addedTopicPartitions).toSet, newTxnStartTimestamp, updateTimestamp)
} }
def prepareAbortOrCommit(newState: TransactionState, updateTimestamp: Long): TxnTransitMetadata = { def prepareAbortOrCommit(newState: TransactionState, clientTransactionVersion: TransactionVersion, nextProducerId: Long, updateTimestamp: Long): TxnTransitMetadata = {
prepareTransitionTo(newState, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, topicPartitions.toSet, val (updatedProducerEpoch, updatedLastProducerEpoch) = if (clientTransactionVersion.supportsEpochBump()) {
txnStartTimestamp, updateTimestamp) // We already ensured that we do not overflow here. MAX_SHORT is the highest possible value.
((producerEpoch + 1).toShort, producerEpoch)
} else {
(producerEpoch, lastProducerEpoch)
}
prepareTransitionTo(newState, producerId, nextProducerId, updatedProducerEpoch, updatedLastProducerEpoch, txnTimeoutMs, topicPartitions.toSet,
txnStartTimestamp, updateTimestamp, clientTransactionVersion)
} }
def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = { def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = {
@ -345,8 +346,15 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
// Since the state change was successfully written to the log, unset the flag for a failed epoch fence // Since the state change was successfully written to the log, unset the flag for a failed epoch fence
hasFailedEpochFence = false hasFailedEpochFence = false
prepareTransitionTo(newState, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, Set.empty[TopicPartition], val (updatedProducerId, updatedProducerEpoch) =
txnStartTimestamp, updateTimestamp) // If we overflowed on epoch bump, we have to set it as the producer ID now the marker has been written.
if (clientTransactionVersion.supportsEpochBump() && nextProducerId != RecordBatch.NO_PRODUCER_ID) {
(nextProducerId, 0.toShort)
} else {
(producerId, producerEpoch)
}
prepareTransitionTo(newState, updatedProducerId, RecordBatch.NO_PRODUCER_ID, updatedProducerEpoch, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition],
txnStartTimestamp, updateTimestamp, clientTransactionVersion)
} }
def prepareDead(): TxnTransitMetadata = { def prepareDead(): TxnTransitMetadata = {
@ -367,37 +375,50 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
} }
} }
private def prepareTransitionTo(newState: TransactionState, private def prepareTransitionTo(updatedState: TransactionState,
newProducerId: Long, updatedProducerId: Long,
newEpoch: Short, updatedEpoch: Short,
newLastEpoch: Short, updatedLastEpoch: Short,
newTxnTimeoutMs: Int, updatedTxnTimeoutMs: Int,
newTopicPartitions: immutable.Set[TopicPartition], updatedTopicPartitions: immutable.Set[TopicPartition],
newTxnStartTimestamp: Long, updatedTxnStartTimestamp: Long,
updateTimestamp: Long): TxnTransitMetadata = { updateTimestamp: Long): TxnTransitMetadata = {
prepareTransitionTo(updatedState, updatedProducerId, RecordBatch.NO_PRODUCER_ID, updatedEpoch, updatedLastEpoch, updatedTxnTimeoutMs, updatedTopicPartitions, updatedTxnStartTimestamp, updateTimestamp, TransactionVersion.TV_0)
}
private def prepareTransitionTo(updatedState: TransactionState,
updatedProducerId: Long,
nextProducerId: Long,
updatedEpoch: Short,
updatedLastEpoch: Short,
updatedTxnTimeoutMs: Int,
updatedTopicPartitions: immutable.Set[TopicPartition],
updatedTxnStartTimestamp: Long,
updateTimestamp: Long,
clientTransactionVersion: TransactionVersion): TxnTransitMetadata = {
if (pendingState.isDefined) if (pendingState.isDefined)
throw new IllegalStateException(s"Preparing transaction state transition to $newState " + throw new IllegalStateException(s"Preparing transaction state transition to $updatedState " +
s"while it already a pending state ${pendingState.get}") s"while it already a pending state ${pendingState.get}")
if (newProducerId < 0) if (updatedProducerId < 0)
throw new IllegalArgumentException(s"Illegal new producer id $newProducerId") throw new IllegalArgumentException(s"Illegal new producer id $updatedProducerId")
// The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata // The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata
// is created for the first time and it could stay like this until transitioning // is created for the first time and it could stay like this until transitioning
// to Dead. // to Dead.
if (newState != Dead && newEpoch < 0) if (updatedState != Dead && updatedEpoch < 0)
throw new IllegalArgumentException(s"Illegal new producer epoch $newEpoch") throw new IllegalArgumentException(s"Illegal new producer epoch $updatedEpoch")
// check that the new state transition is valid and update the pending state if necessary // check that the new state transition is valid and update the pending state if necessary
if (newState.validPreviousStates.contains(state)) { if (updatedState.validPreviousStates.contains(state)) {
val transitMetadata = TxnTransitMetadata(newProducerId, producerId, newEpoch, newLastEpoch, newTxnTimeoutMs, newState, val transitMetadata = TxnTransitMetadata(updatedProducerId, producerId, nextProducerId, updatedEpoch, updatedLastEpoch, updatedTxnTimeoutMs, updatedState,
newTopicPartitions, newTxnStartTimestamp, updateTimestamp) updatedTopicPartitions, updatedTxnStartTimestamp, updateTimestamp, clientTransactionVersion)
debug(s"TransactionalId $transactionalId prepare transition from $state to $transitMetadata") debug(s"TransactionalId $transactionalId prepare transition from $state to $transitMetadata")
pendingState = Some(newState) pendingState = Some(updatedState)
transitMetadata transitMetadata
} else { } else {
throw new IllegalStateException(s"Preparing transaction state transition to $newState failed since the target state" + throw new IllegalStateException(s"Preparing transaction state transition to $updatedState failed since the target state" +
s" $newState is not a valid previous state of the current state $state") s" $updatedState is not a valid previous state of the current state $state")
} }
} }
@ -436,7 +457,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
producerEpoch = transitMetadata.producerEpoch producerEpoch = transitMetadata.producerEpoch
lastProducerEpoch = transitMetadata.lastProducerEpoch lastProducerEpoch = transitMetadata.lastProducerEpoch
producerId = transitMetadata.producerId producerId = transitMetadata.producerId
lastProducerId = transitMetadata.lastProducerId previousProducerId = transitMetadata.prevProducerId
} }
case Ongoing => // from addPartitions case Ongoing => // from addPartitions
@ -457,6 +478,10 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
txnStartTimestamp != transitMetadata.txnStartTimestamp) { txnStartTimestamp != transitMetadata.txnStartTimestamp) {
throwStateTransitionFailure(transitMetadata) throwStateTransitionFailure(transitMetadata)
} else if (transitMetadata.clientTransactionVersion.supportsEpochBump()) {
producerEpoch = transitMetadata.producerEpoch
lastProducerEpoch = transitMetadata.lastProducerEpoch
nextProducerId = transitMetadata.nextProducerId
} }
case CompleteAbort | CompleteCommit => // from write markers case CompleteAbort | CompleteCommit => // from write markers
@ -468,6 +493,13 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
} else { } else {
txnStartTimestamp = transitMetadata.txnStartTimestamp txnStartTimestamp = transitMetadata.txnStartTimestamp
topicPartitions.clear() topicPartitions.clear()
if (transitMetadata.clientTransactionVersion.supportsEpochBump()) {
producerEpoch = transitMetadata.producerEpoch
lastProducerEpoch = transitMetadata.lastProducerEpoch
previousProducerId = transitMetadata.prevProducerId
producerId = transitMetadata.producerId
nextProducerId = transitMetadata.nextProducerId
}
} }
case PrepareEpochFence => case PrepareEpochFence =>
@ -487,6 +519,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
} }
debug(s"TransactionalId $transactionalId complete transition from $state to $transitMetadata") debug(s"TransactionalId $transactionalId complete transition from $state to $transitMetadata")
clientTransactionVersion = transitMetadata.clientTransactionVersion
txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp
pendingState = None pendingState = None
state = toState state = toState
@ -494,8 +527,14 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
} }
private def validProducerEpoch(transitMetadata: TxnTransitMetadata): Boolean = { private def validProducerEpoch(transitMetadata: TxnTransitMetadata): Boolean = {
val transitEpoch = transitMetadata.producerEpoch val isAtLeastTransactionsV2 = transitMetadata.clientTransactionVersion.supportsEpochBump()
val transitProducerId = transitMetadata.producerId val isOverflowComplete = isAtLeastTransactionsV2 && (transitMetadata.txnState == CompleteCommit || transitMetadata.txnState == CompleteAbort) && transitMetadata.producerEpoch == 0
val transitEpoch =
if (isOverflowComplete || (isAtLeastTransactionsV2 && (transitMetadata.txnState == PrepareCommit || transitMetadata.txnState == PrepareAbort)))
transitMetadata.lastProducerEpoch
else
transitMetadata.producerEpoch
val transitProducerId = if (isOverflowComplete) transitMetadata.prevProducerId else transitMetadata.producerId
transitEpoch == producerEpoch && transitProducerId == producerId transitEpoch == producerEpoch && transitProducerId == producerId
} }
@ -518,6 +557,8 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
"TransactionMetadata(" + "TransactionMetadata(" +
s"transactionalId=$transactionalId, " + s"transactionalId=$transactionalId, " +
s"producerId=$producerId, " + s"producerId=$producerId, " +
s"previousProducerId=$previousProducerId, "
s"nextProducerId=$nextProducerId, "
s"producerEpoch=$producerEpoch, " + s"producerEpoch=$producerEpoch, " +
s"txnTimeoutMs=$txnTimeoutMs, " + s"txnTimeoutMs=$txnTimeoutMs, " +
s"state=$state, " + s"state=$state, " +

View File

@ -101,8 +101,10 @@ class TransactionStateManager(brokerId: Int,
TransactionStateManagerConfig.METRICS_GROUP, TransactionStateManagerConfig.METRICS_GROUP,
"The avg time it took to load the partitions in the last 30sec"), new Avg()) "The avg time it took to load the partitions in the last 30sec"), new Avg())
private[transaction] def usesFlexibleRecords(): Boolean = { private[transaction] def transactionVersionLevel(): TransactionVersion = {
metadataCache.features().finalizedFeatures().getOrDefault(TransactionVersion.FEATURE_NAME, 0.toShort) > 0 val version = TransactionVersion.fromFeatureLevel(metadataCache.features().finalizedFeatures().getOrDefault(
TransactionVersion.FEATURE_NAME, 0.toShort))
version
} }
// visible for testing only // visible for testing only
@ -624,7 +626,7 @@ class TransactionStateManager(brokerId: Int,
// generate the message for this transaction metadata // generate the message for this transaction metadata
val keyBytes = TransactionLog.keyToBytes(transactionalId) val keyBytes = TransactionLog.keyToBytes(transactionalId)
val valueBytes = TransactionLog.valueToBytes(newMetadata, usesFlexibleRecords()) val valueBytes = TransactionLog.valueToBytes(newMetadata, transactionVersionLevel())
val timestamp = time.milliseconds() val timestamp = time.milliseconds()
val records = MemoryRecords.withRecords(TransactionLog.EnforcedCompression, new SimpleRecord(timestamp, keyBytes, valueBytes)) val records = MemoryRecords.withRecords(TransactionLog.EnforcedCompression, new SimpleRecord(timestamp, keyBytes, valueBytes))

View File

@ -73,7 +73,7 @@ import org.apache.kafka.coordinator.group.{Group, GroupCoordinator}
import org.apache.kafka.coordinator.share.ShareCoordinator import org.apache.kafka.coordinator.share.ShareCoordinator
import org.apache.kafka.server.ClientMetricsManager import org.apache.kafka.server.ClientMetricsManager
import org.apache.kafka.server.authorizer._ import org.apache.kafka.server.authorizer._
import org.apache.kafka.server.common.{GroupVersion, MetadataVersion, RequestLocal} import org.apache.kafka.server.common.{GroupVersion, MetadataVersion, RequestLocal, TransactionVersion}
import org.apache.kafka.server.common.MetadataVersion.{IBP_0_11_0_IV0, IBP_2_3_IV0} import org.apache.kafka.server.common.MetadataVersion.{IBP_0_11_0_IV0, IBP_2_3_IV0}
import org.apache.kafka.server.record.BrokerCompressionType import org.apache.kafka.server.record.BrokerCompressionType
import org.apache.kafka.server.share.context.ShareFetchContext import org.apache.kafka.server.share.context.ShareFetchContext
@ -2299,7 +2299,7 @@ class KafkaApis(val requestChannel: RequestChannel,
val transactionalId = endTxnRequest.data.transactionalId val transactionalId = endTxnRequest.data.transactionalId
if (authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) { if (authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) {
def sendResponseCallback(error: Errors): Unit = { def sendResponseCallback(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = {
def createResponse(requestThrottleMs: Int): AbstractResponse = { def createResponse(requestThrottleMs: Int): AbstractResponse = {
val finalError = val finalError =
if (endTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) { if (endTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) {
@ -2311,6 +2311,8 @@ class KafkaApis(val requestChannel: RequestChannel,
} }
val responseBody = new EndTxnResponse(new EndTxnResponseData() val responseBody = new EndTxnResponse(new EndTxnResponseData()
.setErrorCode(finalError.code) .setErrorCode(finalError.code)
.setProducerId(newProducerId)
.setProducerEpoch(newProducerEpoch)
.setThrottleTimeMs(requestThrottleMs)) .setThrottleTimeMs(requestThrottleMs))
trace(s"Completed ${endTxnRequest.data.transactionalId}'s EndTxnRequest " + trace(s"Completed ${endTxnRequest.data.transactionalId}'s EndTxnRequest " +
s"with committed: ${endTxnRequest.data.committed}, " + s"with committed: ${endTxnRequest.data.committed}, " +
@ -2320,10 +2322,14 @@ class KafkaApis(val requestChannel: RequestChannel,
requestHelper.sendResponseMaybeThrottle(request, createResponse) requestHelper.sendResponseMaybeThrottle(request, createResponse)
} }
// If the request is version 4, we know the client supports transaction version 2.
val clientTransactionVersion = if (endTxnRequest.version() > 4) TransactionVersion.TV_2 else TransactionVersion.TV_0
txnCoordinator.handleEndTransaction(endTxnRequest.data.transactionalId, txnCoordinator.handleEndTransaction(endTxnRequest.data.transactionalId,
endTxnRequest.data.producerId, endTxnRequest.data.producerId,
endTxnRequest.data.producerEpoch, endTxnRequest.data.producerEpoch,
endTxnRequest.result(), endTxnRequest.result(),
clientTransactionVersion,
sendResponseCallback, sendResponseCallback,
requestLocal) requestLocal)
} else } else

View File

@ -464,10 +464,10 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
addPartitionsOp.awaitAndVerify(txn) addPartitionsOp.awaitAndVerify(txn)
val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn")) val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn"))
txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2))
txnMetadata.state = PrepareCommit txnMetadata.state = PrepareCommit
txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2))
prepareTxnLog(partitionId) prepareTxnLog(partitionId)
} }
@ -506,13 +506,15 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
private def prepareExhaustedEpochTxnMetadata(txn: Transaction): TransactionMetadata = { private def prepareExhaustedEpochTxnMetadata(txn: Transaction): TransactionMetadata = {
new TransactionMetadata(transactionalId = txn.transactionalId, new TransactionMetadata(transactionalId = txn.transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = (Short.MaxValue - 1).toShort, producerEpoch = (Short.MaxValue - 1).toShort,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 60000, txnTimeoutMs = 60000,
state = Empty, state = Empty,
topicPartitions = collection.mutable.Set.empty[TopicPartition], topicPartitions = collection.mutable.Set.empty[TopicPartition],
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TransactionVersion.TV_0)
} }
abstract class TxnOperation[R] extends Operation { abstract class TxnOperation[R] extends Operation {
@ -562,7 +564,8 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
txnMetadata.producerId, txnMetadata.producerId,
txnMetadata.producerEpoch, txnMetadata.producerEpoch,
transactionResult(txn), transactionResult(txn),
resultCallback, TransactionVersion.TV_2,
(r, _, _) => resultCallback(r),
RequestLocal.withThreadConfinedCaching) RequestLocal.withThreadConfinedCaching)
} }
} }

View File

@ -23,9 +23,13 @@ import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult} import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult}
import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch} import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch}
import org.apache.kafka.coordinator.transaction.TransactionStateManagerConfig import org.apache.kafka.coordinator.transaction.TransactionStateManagerConfig
import org.apache.kafka.server.common.TransactionVersion
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.apache.kafka.server.util.MockScheduler import org.apache.kafka.server.util.MockScheduler
import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import org.mockito.{ArgumentCaptor, ArgumentMatchers} import org.mockito.{ArgumentCaptor, ArgumentMatchers}
import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.ArgumentMatchers.{any, anyInt}
import org.mockito.Mockito.{mock, times, verify, when} import org.mockito.Mockito.{mock, times, verify, when}
@ -51,6 +55,7 @@ class TransactionCoordinatorTest {
private val producerId = 10L private val producerId = 10L
private val producerEpoch: Short = 1 private val producerEpoch: Short = 1
private val txnTimeoutMs = 1 private val txnTimeoutMs = 1
private val producerId2 = 11L
private val partitions = mutable.Set[TopicPartition](new TopicPartition("topic1", 0)) private val partitions = mutable.Set[TopicPartition](new TopicPartition("topic1", 0))
private val scheduler = new MockScheduler(time) private val scheduler = new MockScheduler(time)
@ -66,6 +71,8 @@ class TransactionCoordinatorTest {
val transactionStatePartitionCount = 1 val transactionStatePartitionCount = 1
var result: InitProducerIdResult = _ var result: InitProducerIdResult = _
var error: Errors = Errors.NONE var error: Errors = Errors.NONE
var newProducerId: Long = RecordBatch.NO_PRODUCER_ID
var newEpoch: Short = RecordBatch.NO_PRODUCER_EPOCH
private def mockPidGenerator(): Unit = { private def mockPidGenerator(): Unit = {
when(pidGenerator.generateProducerId()).thenAnswer(_ => { when(pidGenerator.generateProducerId()).thenAnswer(_ => {
@ -155,8 +162,8 @@ class TransactionCoordinatorTest {
def shouldGenerateNewProducerIdIfEpochsExhausted(): Unit = { def shouldGenerateNewProducerIdIfEpochsExhausted(): Unit = {
initPidGenericMocks(transactionalId) initPidGenericMocks(transactionalId)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds()) (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -245,7 +252,8 @@ class TransactionCoordinatorTest {
errors = AddPartitionsToTxnResponse.errorsForTransaction(result.topicResults()).asScala.toMap errors = AddPartitionsToTxnResponse.errorsForTransaction(result.topicResults()).asScala.toMap
} }
// If producer ID is not the same, return INVALID_PRODUCER_ID_MAPPING // If producer ID is not the same, return INVALID_PRODUCER_ID_MAPPING
val wrongPidTxnMetadata = new TransactionMetadata(transactionalId, 1, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0) val wrongPidTxnMetadata = new TransactionMetadata(transactionalId, 1, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, wrongPidTxnMetadata)))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, wrongPidTxnMetadata))))
@ -253,10 +261,10 @@ class TransactionCoordinatorTest {
errors.foreach { case (_, error) => errors.foreach { case (_, error) =>
assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error)
} }
// If producer epoch is not equal, return PRODUCER_FENCED // If producer epoch is not equal, return PRODUCER_FENCED
val oldEpochTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0) val oldEpochTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, oldEpochTxnMetadata)))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, oldEpochTxnMetadata))))
@ -266,7 +274,8 @@ class TransactionCoordinatorTest {
} }
// If the txn state is Prepare or AbortCommit, we return CONCURRENT_TRANSACTIONS // If the txn state is Prepare or AbortCommit, we return CONCURRENT_TRANSACTIONS
val emptyTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0) val emptyTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, emptyTxnMetadata)))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, emptyTxnMetadata))))
@ -276,7 +285,8 @@ class TransactionCoordinatorTest {
} }
// Pending state does not matter, we will just check if the partitions are in the txnMetadata. // Pending state does not matter, we will just check if the partitions are in the txnMetadata.
val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, mutable.Set.empty, 0, 0) val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, mutable.Set.empty, 0, 0, TV_0)
ongoingTxnMetadata.pendingState = Some(CompleteCommit) ongoingTxnMetadata.pendingState = Some(CompleteCommit)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, ongoingTxnMetadata)))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, ongoingTxnMetadata))))
@ -298,9 +308,11 @@ class TransactionCoordinatorTest {
} }
def validateConcurrentTransactions(state: TransactionState): Unit = { def validateConcurrentTransactions(state: TransactionState): Unit = {
// Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort.
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0))))) new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0, TV_2)))))
coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
@ -308,9 +320,11 @@ class TransactionCoordinatorTest {
@Test @Test
def shouldRespondWithProducerFencedOnAddPartitionsWhenEpochsAreDifferent(): Unit = { def shouldRespondWithProducerFencedOnAddPartitionsWhenEpochsAreDifferent(): Unit = {
// Since the clientTransactionVersion doesn't matter, use 2 since the state is PrepareCommit.
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, 10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0))))) new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0, TV_2)))))
coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
assertEquals(Errors.PRODUCER_FENCED, error) assertEquals(Errors.PRODUCER_FENCED, error)
@ -318,27 +332,30 @@ class TransactionCoordinatorTest {
@Test @Test
def shouldAppendNewMetadataToLogOnAddPartitionsWhenPartitionsAdded(): Unit = { def shouldAppendNewMetadataToLogOnAddPartitionsWhenPartitionsAdded(): Unit = {
validateSuccessfulAddPartitions(Empty) validateSuccessfulAddPartitions(Empty, 0)
} }
@Test @Test
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsOngoing(): Unit = { def shouldRespondWithSuccessOnAddPartitionsWhenStateIsOngoing(): Unit = {
validateSuccessfulAddPartitions(Ongoing) validateSuccessfulAddPartitions(Ongoing, 0)
} }
@Test @ParameterizedTest
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(): Unit = { @ValueSource(shorts = Array(0, 2))
validateSuccessfulAddPartitions(CompleteCommit) def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(clientTransactionVersion: Short): Unit = {
validateSuccessfulAddPartitions(CompleteCommit, clientTransactionVersion)
} }
@Test @ParameterizedTest
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(): Unit = { @ValueSource(shorts = Array(0, 2))
validateSuccessfulAddPartitions(CompleteAbort) def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(clientTransactionVersion: Short): Unit = {
validateSuccessfulAddPartitions(CompleteAbort, clientTransactionVersion)
} }
def validateSuccessfulAddPartitions(previousState: TransactionState): Unit = { def validateSuccessfulAddPartitions(previousState: TransactionState, transactionVersion: Short): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds()) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -360,7 +377,8 @@ class TransactionCoordinatorTest {
def shouldRespondWithErrorsNoneOnAddPartitionWhenNoErrorsAndPartitionsTheSame(): Unit = { def shouldRespondWithErrorsNoneOnAddPartitionWhenNoErrorsAndPartitionsTheSame(): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0))))) new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0)))))
coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
assertEquals(Errors.NONE, error) assertEquals(Errors.NONE, error)
@ -376,7 +394,8 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0))))) new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0, TV_0)))))
coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, partitions, verifyPartitionsInTxnCallback) coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, partitions, verifyPartitionsInTxnCallback)
errors.foreach { case (_, error) => errors.foreach { case (_, error) =>
@ -394,7 +413,8 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0))))) new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0)))))
val extraPartitions = partitions ++ Set(new TopicPartition("topic2", 0)) val extraPartitions = partitions ++ Set(new TopicPartition("topic2", 0))
@ -404,107 +424,227 @@ class TransactionCoordinatorTest {
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @ParameterizedTest
def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(): Unit = { @ValueSource(shorts = Array(0, 2))
def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(None)) .thenReturn(Right(None))
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @ParameterizedTest
def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(): Unit = { @ValueSource(shorts = Array(0, 2))
def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 10, 10, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) new TransactionMetadata(transactionalId, 10, 10, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @ParameterizedTest
def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(): Unit = { @ValueSource(shorts = Array(0, 2))
def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
(producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error) assertEquals(Errors.PRODUCER_FENCED, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @ParameterizedTest
def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(): Unit ={ @ValueSource(shorts = Array(0, 2))
def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
(producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.NONE, error) assertEquals(Errors.NONE, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @ParameterizedTest
def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(): Unit ={ @ValueSource(shorts = Array(0, 2))
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.NONE, error) assertEquals(Errors.NONE, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @ParameterizedTest
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(): Unit = { @ValueSource(shorts = Array(0, 2))
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_TXN_STATE, error) assertEquals(Errors.INVALID_TXN_STATE, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @ParameterizedTest
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = { @ValueSource(shorts = Array(0, 2))
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_TXN_STATE, error) assertEquals(Errors.INVALID_TXN_STATE, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @ParameterizedTest
def shouldReturnConcurrentTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = { @ValueSource(shorts = Array(0, 2))
def shouldReturnConcurrentTransactionsOnEndTxnRequestWhenStatusIsPrepareCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @ParameterizedTest
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(): Unit = { @ValueSource(shorts = Array(0, 2))
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, 1, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_TXN_STATE, error) assertEquals(Errors.INVALID_TXN_STATE, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @Test
def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(): Unit = { def shouldReturnWhenTransactionVersionDowngraded(): Unit = {
mockPrepare(PrepareCommit) // State was written when transactions V2
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, errorsCallback) // Return CONCURRENT_TRANSACTIONS as the transaction is still completing
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_0, endTxnCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
assertEquals(RecordBatch.NO_PRODUCER_ID, newProducerId)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, newEpoch)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
// Recognize the retry and return NONE
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_0, endTxnCallback)
assertEquals(Errors.NONE, error)
assertEquals(producerId, newProducerId)
assertEquals((producerEpoch + 1).toShort, newEpoch) // epoch is bumped since we started as V2
verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnCorrectlyWhenTransactionVersionUpgraded(): Unit = {
// State was written when transactions V0
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
// Transactions V0 throws the concurrent transactions error here.
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
// When the transaction is completed, return and do not throw an error.
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.NONE, error)
assertEquals(producerId, newProducerId)
assertEquals(producerEpoch, newEpoch) // epoch is not bumped since this started as V1
verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnInvalidTxnRequestOnEndTxnV2IfNotEndTxnV2Retry(): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
// If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED.
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
// If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED.
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnOkOnEndTxnV2IfEndTxnV2RetryEpochOverflow(): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
// Return CONCURRENT_TRANSACTIONS while transaction is still completing
coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId2, producerId,
RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.NONE, error)
assertNotEquals(RecordBatch.NO_PRODUCER_ID, newProducerId)
assertNotEquals(producerId, newProducerId)
assertEquals(0, newEpoch)
verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
mockPrepare(PrepareCommit, clientTransactionVersion)
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
verify(transactionManager).appendTransactionToLog( verify(transactionManager).appendTransactionToLog(
ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(transactionalId),
@ -515,11 +655,13 @@ class TransactionCoordinatorTest {
any()) any())
} }
@Test @ParameterizedTest
def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(): Unit = { @ValueSource(shorts = Array(0, 2))
mockPrepare(PrepareAbort) def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
mockPrepare(PrepareAbort, clientTransactionVersion)
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, errorsCallback) coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
verify(transactionManager).appendTransactionToLog( verify(transactionManager).appendTransactionToLog(
ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(transactionalId),
@ -530,90 +672,106 @@ class TransactionCoordinatorTest {
any()) any())
} }
@Test @ParameterizedTest
def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(): Unit = { @ValueSource(shorts = Array(0, 2))
coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, errorsCallback) def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_REQUEST, error) assertEquals(Errors.INVALID_REQUEST, error)
} }
@Test @ParameterizedTest
def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(): Unit = { @ValueSource(shorts = Array(0, 2))
def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Left(Errors.NOT_COORDINATOR)) .thenReturn(Left(Errors.NOT_COORDINATOR))
coordinator.handleEndTransaction("", 0, 0, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction("", 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_REQUEST, error) assertEquals(Errors.INVALID_REQUEST, error)
} }
@Test @ParameterizedTest
def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(): Unit = { @ValueSource(shorts = Array(0, 2))
def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Left(Errors.NOT_COORDINATOR)) .thenReturn(Left(Errors.NOT_COORDINATOR))
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.NOT_COORDINATOR, error) assertEquals(Errors.NOT_COORDINATOR, error)
} }
@Test @ParameterizedTest
def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(): Unit = { @ValueSource(shorts = Array(0, 2))
def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)) .thenReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, error) assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, error)
} }
@Test @ParameterizedTest
def shouldReturnInvalidEpochOnEndTxnWhenEpochIsLarger(): Unit = { @ValueSource(shorts = Array(0, 2))
def shouldReturnInvalidEpochOnEndTxnWhenEpochIsLarger(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val serverProducerEpoch = 1.toShort val serverProducerEpoch = 1.toShort
verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch + 1).toShort) verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch + 1).toShort, clientTransactionVersion)
} }
@Test @ParameterizedTest
def shouldReturnInvalidEpochOnEndTxnWhenEpochIsSmaller(): Unit = { @ValueSource(shorts = Array(0, 2))
val serverProducerEpoch = 1.toShort def shouldReturnInvalidEpochOnEndTxnWhenEpochIsSmaller(transactionVersion: Short): Unit = {
verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch - 1).toShort) val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val serverProducerEpoch = 2.toShort
// Since we bump epoch in transactionV2 the request should be one producer ID older
verifyEndTxnEpoch(serverProducerEpoch, requestEpoch(clientTransactionVersion), clientTransactionVersion)
} }
private def verifyEndTxnEpoch(metadataEpoch: Short, requestEpoch: Short): Unit = { private def verifyEndTxnEpoch(metadataEpoch: Short, requestEpoch: Short, clientTransactionVersion: TransactionVersion): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, metadataEpoch, 0, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, metadataEpoch, 1,
1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, errorsCallback) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error) assertEquals(Errors.PRODUCER_FENCED, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
} }
@Test @Test
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingEmptyTransaction(): Unit = { def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingEmptyTransaction(): Unit = {
validateIncrementEpochAndUpdateMetadata(Empty) validateIncrementEpochAndUpdateMetadata(Empty, 0)
}
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(clientTransactionVersion: Short): Unit = {
validateIncrementEpochAndUpdateMetadata(CompleteAbort, clientTransactionVersion)
}
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(clientTransactionVersion: Short): Unit = {
validateIncrementEpochAndUpdateMetadata(CompleteCommit, clientTransactionVersion)
} }
@Test @Test
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(): Unit = { def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareCommitState(): Unit = {
validateIncrementEpochAndUpdateMetadata(CompleteAbort)
}
@Test
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(): Unit = {
validateIncrementEpochAndUpdateMetadata(CompleteCommit)
}
@Test
def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareCommitState(): Unit ={
validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareCommit) validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareCommit)
} }
@Test @Test
def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareAbortState(): Unit ={ def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareAbortState(): Unit = {
validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareAbort) validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareAbort)
} }
@Test @Test
def shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): Unit = { def shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
@ -621,8 +779,10 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.appendTransactionToLog( when(transactionManager.appendTransactionToLog(
ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(coordinatorEpoch),
@ -640,7 +800,7 @@ class TransactionCoordinatorTest {
verify(transactionManager).appendTransactionToLog( verify(transactionManager).appendTransactionToLog(
ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(coordinatorEpoch),
ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())), ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())),
any(), any(),
any(), any(),
any()) any())
@ -648,14 +808,14 @@ class TransactionCoordinatorTest {
@Test @Test
def shouldFailToAbortTransactionOnHandleInitPidWhenProducerEpochIsSmaller(): Unit = { def shouldFailToAbortTransactionOnHandleInitPidWhenProducerEpochIsSmaller(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 2).toShort, val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) (producerEpoch + 2).toShort, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -671,8 +831,8 @@ class TransactionCoordinatorTest {
@Test @Test
def shouldNotRepeatedlyBumpEpochDueToInitPidDuringOngoingTxnIfAppendToLogFails(): Unit = { def shouldNotRepeatedlyBumpEpochDueToInitPidDuringOngoingTxnIfAppendToLogFails(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
@ -683,9 +843,11 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenAnswer(_ => Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenAnswer(_ => Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds()) val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())
when(transactionManager.appendTransactionToLog( when(transactionManager.appendTransactionToLog(
ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(coordinatorEpoch),
@ -740,33 +902,38 @@ class TransactionCoordinatorTest {
@Test @Test
def shouldUseLastEpochToFenceWhenEpochsAreExhausted(): Unit = { def shouldUseLastEpochToFenceWhenEpochsAreExhausted(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted) assertTrue(txnMetadata.isProducerEpochExhausted)
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
val postFenceTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, Short.MaxValue, val postFenceTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds()) Short.MaxValue, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, postFenceTxnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, postFenceTxnMetadata))))
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
// InitProducerId uses FenceProducerEpoch so clientTransactionVersion is 0.
when(transactionManager.appendTransactionToLog( when(transactionManager.appendTransactionToLog(
ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(coordinatorEpoch),
ArgumentMatchers.eq(TxnTransitMetadata( ArgumentMatchers.eq(TxnTransitMetadata(
producerId = producerId, producerId = producerId,
lastProducerId = producerId, prevProducerId = producerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = Short.MaxValue, producerEpoch = Short.MaxValue,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = txnTimeoutMs, txnTimeoutMs = txnTimeoutMs,
txnState = PrepareAbort, txnState = PrepareAbort,
topicPartitions = partitions.toSet, topicPartitions = partitions.toSet,
txnStartTimestamp = time.milliseconds(), txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds())), txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)),
capturedErrorsCallback.capture(), capturedErrorsCallback.capture(),
any(), any(),
any()) any())
@ -783,14 +950,16 @@ class TransactionCoordinatorTest {
ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(coordinatorEpoch),
ArgumentMatchers.eq(TxnTransitMetadata( ArgumentMatchers.eq(TxnTransitMetadata(
producerId = producerId, producerId = producerId,
lastProducerId = producerId, prevProducerId = producerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = Short.MaxValue, producerEpoch = Short.MaxValue,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = txnTimeoutMs, txnTimeoutMs = txnTimeoutMs,
txnState = PrepareAbort, txnState = PrepareAbort,
topicPartitions = partitions.toSet, topicPartitions = partitions.toSet,
txnStartTimestamp = time.milliseconds(), txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds())), txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)),
any(), any(),
any(), any(),
any()) any())
@ -800,8 +969,8 @@ class TransactionCoordinatorTest {
def testInitProducerIdWithNoLastProducerData(): Unit = { def testInitProducerIdWithNoLastProducerData(): Unit = {
// If the metadata doesn't include the previous producer data (for example, if it was written to the log by a broker // If the metadata doesn't include the previous producer data (for example, if it was written to the log by a broker
// on an old version), the retry case should fail // on an old version), the retry case should fail
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
@ -817,8 +986,8 @@ class TransactionCoordinatorTest {
@Test @Test
def testFenceProducerWhenMappingExistsWithDifferentProducerId(): Unit = { def testFenceProducerWhenMappingExistsWithDifferentProducerId(): Unit = {
// Existing transaction ID maps to new producer ID // Existing transaction ID maps to new producer ID
val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId,
(producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
@ -835,8 +1004,8 @@ class TransactionCoordinatorTest {
def testInitProducerIdWithCurrentEpochProvided(): Unit = { def testInitProducerIdWithCurrentEpochProvided(): Unit = {
mockPidGenerator() mockPidGenerator()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, 10, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
@ -870,8 +1039,8 @@ class TransactionCoordinatorTest {
def testInitProducerIdStaleCurrentEpochProvided(): Unit = { def testInitProducerIdStaleCurrentEpochProvided(): Unit = {
mockPidGenerator() mockPidGenerator()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, 10, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
@ -906,8 +1075,8 @@ class TransactionCoordinatorTest {
@Test @Test
def testRetryInitProducerIdAfterProducerIdRotation(): Unit = { def testRetryInitProducerIdAfterProducerIdRotation(): Unit = {
// Existing transaction ID maps to new producer ID // Existing transaction ID maps to new producer ID
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(pidGenerator.generateProducerId()) when(pidGenerator.generateProducerId())
.thenReturn(Success(producerId + 1)) .thenReturn(Success(producerId + 1))
@ -928,7 +1097,7 @@ class TransactionCoordinatorTest {
capturedErrorsCallback.getValue.apply(Errors.NONE) capturedErrorsCallback.getValue.apply(Errors.NONE)
txnMetadata.pendingState = None txnMetadata.pendingState = None
txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId
txnMetadata.lastProducerId = capturedTxnTransitMetadata.getValue.lastProducerId txnMetadata.previousProducerId = capturedTxnTransitMetadata.getValue.prevProducerId
txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch
txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch
}) })
@ -947,8 +1116,8 @@ class TransactionCoordinatorTest {
@Test @Test
def testInitProducerIdWithInvalidEpochAfterProducerIdRotation(): Unit = { def testInitProducerIdWithInvalidEpochAfterProducerIdRotation(): Unit = {
// Existing transaction ID maps to new producer ID // Existing transaction ID maps to new producer ID
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(pidGenerator.generateProducerId()) when(pidGenerator.generateProducerId())
.thenReturn(Success(producerId + 1)) .thenReturn(Success(producerId + 1))
@ -969,7 +1138,7 @@ class TransactionCoordinatorTest {
capturedErrorsCallback.getValue.apply(Errors.NONE) capturedErrorsCallback.getValue.apply(Errors.NONE)
txnMetadata.pendingState = None txnMetadata.pendingState = None
txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId
txnMetadata.lastProducerId = capturedTxnTransitMetadata.getValue.lastProducerId txnMetadata.previousProducerId = capturedTxnTransitMetadata.getValue.prevProducerId
txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch
txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch
}) })
@ -995,16 +1164,20 @@ class TransactionCoordinatorTest {
@Test @Test
def shouldAbortExpiredTransactionsInOngoingStateAndBumpEpoch(): Unit = { def shouldAbortExpiredTransactionsInOngoingStateAndBumpEpoch(): Unit = {
val now = time.milliseconds() val now = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
when(transactionManager.timedOutTransactions()) when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
val expectedTransition = TxnTransitMetadata(producerId, producerId, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, // Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0.
txnTimeoutMs, PrepareAbort, partitions.toSet, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT) val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.toSet, now,
now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0)
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
when(transactionManager.appendTransactionToLog(ArgumentMatchers.eq(transactionalId), when(transactionManager.appendTransactionToLog(ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(coordinatorEpoch),
@ -1030,20 +1203,22 @@ class TransactionCoordinatorTest {
@Test @Test
def shouldNotAcceptSmallerEpochDuringTransactionExpiration(): Unit = { def shouldNotAcceptSmallerEpochDuringTransactionExpiration(): Unit = {
val now = time.milliseconds() val now = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
when(transactionManager.timedOutTransactions()) when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 2).toShort, when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now)
val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 2).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, bumpedTxnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, bumpedTxnMetadata))))
def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = { def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors, producerId: Long, producerEpoch: Short): Unit = {
assertEquals(Errors.PRODUCER_FENCED, error) assertEquals(Errors.PRODUCER_FENCED, error)
} }
coordinator.abortTimedOutTransactions(checkOnEndTransactionComplete) coordinator.abortTimedOutTransactions(checkOnEndTransactionComplete)
@ -1054,9 +1229,9 @@ class TransactionCoordinatorTest {
@Test @Test
def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = { def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = {
val metadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, val metadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
metadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds()) metadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())
when(transactionManager.timedOutTransactions()) when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
@ -1073,22 +1248,26 @@ class TransactionCoordinatorTest {
@Test @Test
def shouldNotBumpEpochWhenAbortingExpiredTransactionIfAppendToLogFails(): Unit = { def shouldNotBumpEpochWhenAbortingExpiredTransactionIfAppendToLogFails(): Unit = {
val now = time.milliseconds() val now = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
when(transactionManager.timedOutTransactions()) when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadataAfterAppendFailure)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadataAfterAppendFailure))))
// Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0.
val bumpedEpoch = (producerEpoch + 1).toShort val bumpedEpoch = (producerEpoch + 1).toShort
val expectedTransition = TxnTransitMetadata(producerId, producerId, bumpedEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, bumpedEpoch,
PrepareAbort, partitions.toSet, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT) RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.toSet, now,
now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0)
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
when(transactionManager.appendTransactionToLog(ArgumentMatchers.eq(transactionalId), when(transactionManager.appendTransactionToLog(ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch), ArgumentMatchers.eq(coordinatorEpoch),
@ -1117,9 +1296,9 @@ class TransactionCoordinatorTest {
@Test @Test
def shouldNotBumpEpochWithPendingTransaction(): Unit = { def shouldNotBumpEpochWithPendingTransaction(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
txnMetadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds()) txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
@ -1146,9 +1325,9 @@ class TransactionCoordinatorTest {
def testDescribeTransactionsWithExpiringTransactionalId(): Unit = { def testDescribeTransactionsWithExpiringTransactionalId(): Unit = {
coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false) coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Dead, mutable.Set.empty, time.milliseconds(), RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Dead, mutable.Set.empty, time.milliseconds(),
time.milliseconds()) time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -1172,8 +1351,8 @@ class TransactionCoordinatorTest {
@Test @Test
def testDescribeTransactions(): Unit = { def testDescribeTransactions(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -1199,7 +1378,9 @@ class TransactionCoordinatorTest {
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
val metadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0) // Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort.
val metadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_EPOCH,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0, TV_2)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
@ -1208,14 +1389,16 @@ class TransactionCoordinatorTest {
assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result) assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result)
} }
private def validateIncrementEpochAndUpdateMetadata(state: TransactionState): Unit = { private def validateIncrementEpochAndUpdateMetadata(state: TransactionState, transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(pidGenerator.generateProducerId()) when(pidGenerator.generateProducerId())
.thenReturn(Success(producerId)) .thenReturn(Success(producerId))
when(transactionManager.validateTransactionTimeoutMs(anyInt())) when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true) .thenReturn(true)
val metadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds()) val metadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH,
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
@ -1242,13 +1425,13 @@ class TransactionCoordinatorTest {
assertEquals(producerId, metadata.producerId) assertEquals(producerId, metadata.producerId)
} }
private def mockPrepare(transactionState: TransactionState, runCallback: Boolean = false): TransactionMetadata = { private def mockPrepare(transactionState: TransactionState, clientTransactionVersion: TransactionVersion, runCallback: Boolean = false): TransactionMetadata = {
val now = time.milliseconds() val now = time.milliseconds()
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs, Ongoing, partitions, now, now) producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
val transition = TxnTransitMetadata(producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, val transition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
transactionState, partitions.toSet, now, now) RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions.toSet, now, now, clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata))))
@ -1264,8 +1447,8 @@ class TransactionCoordinatorTest {
capturedErrorsCallback.getValue.apply(Errors.NONE) capturedErrorsCallback.getValue.apply(Errors.NONE)
}) })
new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds()) RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds(), clientTransactionVersion)
} }
def initProducerIdMockCallback(ret: InitProducerIdResult): Unit = { def initProducerIdMockCallback(ret: InitProducerIdResult): Unit = {
@ -1275,4 +1458,17 @@ class TransactionCoordinatorTest {
def errorsCallback(ret: Errors): Unit = { def errorsCallback(ret: Errors): Unit = {
error = ret error = ret
} }
def endTxnCallback(ret: Errors, producerId: Long, epoch: Short): Unit = {
error = ret
newProducerId = producerId
newEpoch = epoch
}
def requestEpoch(clientTransactionVersion: TransactionVersion): Short = {
if (clientTransactionVersion.supportsEpochBump())
(producerEpoch - 1).toShort
else
producerEpoch
}
} }

View File

@ -23,8 +23,9 @@ import org.apache.kafka.common.compress.Compression
import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil}
import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection
import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type} import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type}
import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord} import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, SimpleRecord}
import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue} import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue}
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue} import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue}
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
@ -48,10 +49,11 @@ class TransactionLogTest {
val transactionalId = "transactionalId" val transactionalId = "transactionalId"
val producerId = 23423L val producerId = 23423L
val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs, 0) val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
txnMetadata.addPartitions(topicPartitions) txnMetadata.addPartitions(topicPartitions)
assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)) assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2))
} }
@Test @Test
@ -72,14 +74,14 @@ class TransactionLogTest {
// generate transaction log messages // generate transaction log messages
val txnRecords = pidMappings.map { case (transactionalId, producerId) => val txnRecords = pidMappings.map { case (transactionalId, producerId) =>
val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs, val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
transactionStates(producerId), 0) RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
if (!txnMetadata.state.equals(Empty)) if (!txnMetadata.state.equals(Empty))
txnMetadata.addPartitions(topicPartitions) txnMetadata.addPartitions(topicPartitions)
val keyBytes = TransactionLog.keyToBytes(transactionalId) val keyBytes = TransactionLog.keyToBytes(transactionalId)
val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true) val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)
new SimpleRecord(keyBytes, valueBytes) new SimpleRecord(keyBytes, valueBytes)
}.toSeq }.toSeq
@ -114,12 +116,12 @@ class TransactionLogTest {
val producerId = 1334L val producerId = 1334L
val topicPartition = new TopicPartition("topic", 0) val topicPartition = new TopicPartition("topic", 0)
val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
transactionTimeoutMs, Ongoing, 0) RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
txnMetadata.addPartitions(Set(topicPartition)) txnMetadata.addPartitions(Set(topicPartition))
val keyBytes = TransactionLog.keyToBytes(transactionalId) val keyBytes = TransactionLog.keyToBytes(transactionalId)
val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true) val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)
val transactionMetadataRecord = TestUtils.records(Seq( val transactionMetadataRecord = TestUtils.records(Seq(
new SimpleRecord(keyBytes, valueBytes) new SimpleRecord(keyBytes, valueBytes)
)).records.asScala.head )).records.asScala.head
@ -144,15 +146,15 @@ class TransactionLogTest {
@Test @Test
def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = { def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = {
val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500) val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500, TV_0)
val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, false)) val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0))
assertEquals(0, txnLogValueBuffer.getShort) assertEquals(0, txnLogValueBuffer.getShort)
} }
@Test @Test
def testSerializeTransactionLogValueToFlexibleVersion(): Unit = { def testSerializeTransactionLogValueToFlexibleVersion(): Unit = {
val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500) val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500, TV_2)
val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, true)) val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2))
assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort) assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort)
} }
@ -194,8 +196,8 @@ class TransactionLogTest {
new Field("topic", Type.COMPACT_STRING, ""), new Field("topic", Type.COMPACT_STRING, ""),
new Field("partition_ids", new CompactArrayOf(Type.INT32), ""), new Field("partition_ids", new CompactArrayOf(Type.INT32), ""),
TaggedFieldsSection.of( TaggedFieldsSection.of(
Int.box(0), new Field("partition_foo", Type.STRING, ""), Int.box(100), new Field("partition_foo", Type.STRING, ""),
Int.box(1), new Field("partition_foo", Type.INT32, "") Int.box(101), new Field("partition_foo", Type.INT32, "")
) )
) )
@ -204,8 +206,8 @@ class TransactionLogTest {
txnPartitions.set("topic", "topic") txnPartitions.set("topic", "topic")
txnPartitions.set("partition_ids", Array(Integer.valueOf(1))) txnPartitions.set("partition_ids", Array(Integer.valueOf(1)))
val txnPartitionsTaggedFields = new java.util.TreeMap[Integer, Any]() val txnPartitionsTaggedFields = new java.util.TreeMap[Integer, Any]()
txnPartitionsTaggedFields.put(0, "foo") txnPartitionsTaggedFields.put(100, "foo")
txnPartitionsTaggedFields.put(1, 4000) txnPartitionsTaggedFields.put(101, 4000)
txnPartitions.set("_tagged_fields", txnPartitionsTaggedFields) txnPartitions.set("_tagged_fields", txnPartitionsTaggedFields)
// Copy of TransactionLogValue.SCHEMA_1 with a few // Copy of TransactionLogValue.SCHEMA_1 with a few
@ -219,8 +221,8 @@ class TransactionLogTest {
new Field("transaction_last_update_timestamp_ms", Type.INT64, ""), new Field("transaction_last_update_timestamp_ms", Type.INT64, ""),
new Field("transaction_start_timestamp_ms", Type.INT64, ""), new Field("transaction_start_timestamp_ms", Type.INT64, ""),
TaggedFieldsSection.of( TaggedFieldsSection.of(
Int.box(0), new Field("txn_foo", Type.STRING, ""), Int.box(100), new Field("txn_foo", Type.STRING, ""),
Int.box(1), new Field("txn_bar", Type.INT32, "") Int.box(101), new Field("txn_bar", Type.INT32, "")
) )
) )
@ -234,8 +236,8 @@ class TransactionLogTest {
transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L) transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L)
transactionLogValue.set("transaction_start_timestamp_ms", 3000L) transactionLogValue.set("transaction_start_timestamp_ms", 3000L)
val txnLogValueTaggedFields = new java.util.TreeMap[Integer, Any]() val txnLogValueTaggedFields = new java.util.TreeMap[Integer, Any]()
txnLogValueTaggedFields.put(0, "foo") txnLogValueTaggedFields.put(100, "foo")
txnLogValueTaggedFields.put(1, 4000) txnLogValueTaggedFields.put(101, 4000)
transactionLogValue.set("_tagged_fields", txnLogValueTaggedFields) transactionLogValue.set("_tagged_fields", txnLogValueTaggedFields)
// Prepare the buffer. // Prepare the buffer.
@ -249,8 +251,8 @@ class TransactionLogTest {
// fields were read but ignored. // fields were read but ignored.
buffer.getShort() // Skip version. buffer.getShort() // Skip version.
val value = new TransactionLogValue(new ByteBufferAccessor(buffer), 1.toShort) val value = new TransactionLogValue(new ByteBufferAccessor(buffer), 1.toShort)
assertEquals(Seq(0, 1), value.unknownTaggedFields().asScala.map(_.tag)) assertEquals(Seq(100, 101), value.unknownTaggedFields().asScala.map(_.tag))
assertEquals(Seq(0, 1), value.transactionPartitions().get(0).unknownTaggedFields().asScala.map(_.tag)) assertEquals(Seq(100, 101), value.transactionPartitions().get(0).unknownTaggedFields().asScala.map(_.tag))
// Read the buffer with readTxnRecordValue. // Read the buffer with readTxnRecordValue.
buffer.rewind() buffer.rewind()

View File

@ -28,7 +28,7 @@ import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse}
import org.apache.kafka.common.utils.MockTime import org.apache.kafka.common.utils.MockTime
import org.apache.kafka.common.{Node, TopicPartition} import org.apache.kafka.common.{Node, TopicPartition}
import org.apache.kafka.server.common.MetadataVersion import org.apache.kafka.server.common.{MetadataVersion, TransactionVersion}
import org.apache.kafka.server.metrics.{KafkaMetricsGroup, KafkaYammerMetrics} import org.apache.kafka.server.metrics.{KafkaMetricsGroup, KafkaYammerMetrics}
import org.apache.kafka.server.util.RequestAndCompletionHandler import org.apache.kafka.server.util.RequestAndCompletionHandler
import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Assertions._
@ -63,10 +63,10 @@ class TransactionMarkerChannelManagerTest {
private val coordinatorEpoch2 = 1 private val coordinatorEpoch2 = 1
private val txnTimeoutMs = 0 private val txnTimeoutMs = 0
private val txnResult = TransactionResult.COMMIT private val txnResult = TransactionResult.COMMIT
private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, producerEpoch, lastProducerEpoch, private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, RecordBatch.NO_PRODUCER_ID,
txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L) producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L, TransactionVersion.TV_2)
private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, producerEpoch, lastProducerEpoch, private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, RecordBatch.NO_PRODUCER_ID,
txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L) producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L, TransactionVersion.TV_2)
private val capturedErrorsCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit]) private val capturedErrorsCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit])
private val time = new MockTime private val time = new MockTime

View File

@ -23,6 +23,7 @@ import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.{ApiKeys, Errors} import org.apache.kafka.common.protocol.{ApiKeys, Errors}
import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse}
import org.apache.kafka.server.common.TransactionVersion
import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.mockito.ArgumentMatchers import org.mockito.ArgumentMatchers
@ -42,8 +43,8 @@ class TransactionMarkerRequestCompletionHandlerTest {
private val coordinatorEpoch = 0 private val coordinatorEpoch = 0
private val txnResult = TransactionResult.COMMIT private val txnResult = TransactionResult.COMMIT
private val topicPartition = new TopicPartition("topic1", 0) private val topicPartition = new TopicPartition("topic1", 0)
private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, lastProducerEpoch, private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L) producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L, TransactionVersion.TV_2)
private val pendingCompleteTxnAndMarkers = asList( private val pendingCompleteTxnAndMarkers = asList(
PendingCompleteTxnAndMarkerEntry( PendingCompleteTxnAndMarkerEntry(
PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, txnMetadata.prepareComplete(42)), PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, txnMetadata.prepareComplete(42)),

View File

@ -19,9 +19,13 @@ package kafka.coordinator.transaction
import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.server.common.TransactionVersion
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.apache.kafka.server.util.MockTime import org.apache.kafka.server.util.MockTime
import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import scala.collection.mutable import scala.collection.mutable
@ -38,13 +42,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None)
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
@ -60,13 +66,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None)
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
@ -82,13 +90,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted) assertTrue(txnMetadata.isProducerEpochExhausted)
assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareIncrementProducerEpoch(30000, assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareIncrementProducerEpoch(30000,
@ -101,14 +111,16 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L, txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller // let new time be smaller
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Option(producerEpoch), val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Option(producerEpoch),
@ -127,14 +139,16 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L, txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller // let new time be smaller
val transitMetadata = txnMetadata.prepareProducerIdRotation(producerId + 1, 30000, time.milliseconds() - 1, recordLastEpoch = true) val transitMetadata = txnMetadata.prepareProducerIdRotation(producerId + 1, 30000, time.milliseconds() - 1, recordLastEpoch = true)
@ -152,14 +166,16 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(), txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller; when transiting from Empty the start time would be updated to the update-time // let new time be smaller; when transiting from Empty the start time would be updated to the update-time
var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1) var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1)
@ -188,17 +204,19 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Ongoing, state = Ongoing,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L, txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller // let new time be smaller
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds() - 1) val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(PrepareCommit, txnMetadata.state) assertEquals(PrepareCommit, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId) assertEquals(producerId, txnMetadata.producerId)
@ -214,17 +232,19 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Ongoing, state = Ongoing,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L, txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller // let new time be smaller
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds() - 1) val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(PrepareAbort, txnMetadata.state) assertEquals(PrepareAbort, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId) assertEquals(producerId, txnMetadata.producerId)
@ -234,53 +254,65 @@ class TransactionMetadataTest {
assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp)
} }
@Test @ParameterizedTest
def testTolerateTimeShiftDuringCompleteCommit(): Unit = { @ValueSource(shorts = Array(0, 2))
def testTolerateTimeShiftDuringCompleteCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val producerEpoch: Short = 1 val producerEpoch: Short = 1
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = PrepareCommit, state = PrepareCommit,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L, txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = clientTransactionVersion)
// let new time be smaller // let new time be smaller
val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1) val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
val lastEpoch = if (clientTransactionVersion.supportsEpochBump()) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH
assertEquals(CompleteCommit, txnMetadata.state) assertEquals(CompleteCommit, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId) assertEquals(producerId, txnMetadata.producerId)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) assertEquals(lastEpoch, txnMetadata.lastProducerEpoch)
assertEquals(producerEpoch, txnMetadata.producerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch)
assertEquals(1L, txnMetadata.txnStartTimestamp) assertEquals(1L, txnMetadata.txnStartTimestamp)
assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp)
} }
@Test @ParameterizedTest
def testTolerateTimeShiftDuringCompleteAbort(): Unit = { @ValueSource(shorts = Array(0, 2))
def testTolerateTimeShiftDuringCompleteAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val producerEpoch: Short = 1 val producerEpoch: Short = 1
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = PrepareAbort, state = PrepareAbort,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L, txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = clientTransactionVersion)
// let new time be smaller // let new time be smaller
val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1) val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
val lastEpoch = if (clientTransactionVersion.supportsEpochBump()) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH
assertEquals(CompleteAbort, txnMetadata.state) assertEquals(CompleteAbort, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId) assertEquals(producerId, txnMetadata.producerId)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) assertEquals(lastEpoch, txnMetadata.lastProducerEpoch)
assertEquals(producerEpoch, txnMetadata.producerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch)
assertEquals(1L, txnMetadata.txnStartTimestamp) assertEquals(1L, txnMetadata.txnStartTimestamp)
assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp)
@ -293,13 +325,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Ongoing, state = Ongoing,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted) assertTrue(txnMetadata.isProducerEpochExhausted)
val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch() val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch()
@ -310,7 +344,7 @@ class TransactionMetadataTest {
// We should reset the pending state to make way for the abort transition. // We should reset the pending state to make way for the abort transition.
txnMetadata.pendingState = None txnMetadata.pendingState = None
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds()) val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, transitMetadata.producerId) assertEquals(producerId, transitMetadata.producerId)
} }
@ -322,13 +356,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Ongoing, state = Ongoing,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted) assertTrue(txnMetadata.isProducerEpochExhausted)
assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch()) assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch())
} }
@ -340,36 +376,108 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val newProducerId = 9893L val newProducerId = 9893L
val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = true) val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = true)
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(newProducerId, txnMetadata.producerId) assertEquals(newProducerId, txnMetadata.producerId)
assertEquals(producerId, txnMetadata.lastProducerId) assertEquals(producerId, txnMetadata.previousProducerId)
assertEquals(0, txnMetadata.producerEpoch) assertEquals(0, txnMetadata.producerEpoch)
assertEquals(producerEpoch, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.lastProducerEpoch)
} }
@Test
def testEpochBumpOnEndTxn(): Unit = {
time.sleep(100)
val producerEpoch = 10.toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch)
assertEquals(TV_2, txnMetadata.clientTransactionVersion)
transitMetadata = txnMetadata.prepareComplete(time.milliseconds())
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch)
assertEquals(TV_2, txnMetadata.clientTransactionVersion)
}
@Test
def testEpochBumpOnEndTxnOverflow(): Unit = {
time.sleep(100)
val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
val newProducerId = 9893L
var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, newProducerId, time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(Short.MaxValue, txnMetadata.producerEpoch)
assertEquals(TV_2, txnMetadata.clientTransactionVersion)
transitMetadata = txnMetadata.prepareComplete(time.milliseconds())
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(newProducerId, txnMetadata.producerId)
assertEquals(0, txnMetadata.producerEpoch)
assertEquals(TV_2, txnMetadata.clientTransactionVersion)
}
@Test @Test
def testRotateProducerIdInOngoingState(): Unit = { def testRotateProducerIdInOngoingState(): Unit = {
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing)) assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing, TV_0))
} }
@Test @ParameterizedTest
def testRotateProducerIdInPrepareAbortState(): Unit = { @ValueSource(shorts = Array(0, 2))
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort)) def testRotateProducerIdInPrepareAbortState(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort, clientTransactionVersion))
} }
@Test @ParameterizedTest
def testRotateProducerIdInPrepareCommitState(): Unit = { @ValueSource(shorts = Array(0, 2))
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareCommit)) def testRotateProducerIdInPrepareCommitState(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareCommit, clientTransactionVersion))
} }
@Test @Test
@ -379,13 +487,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch)) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch))
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
@ -401,13 +511,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch)) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch))
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
@ -424,13 +536,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID, previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch, lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(lastProducerEpoch)) val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(lastProducerEpoch))
txnMetadata.completeTransitionTo(transitMetadata) txnMetadata.completeTransitionTo(transitMetadata)
@ -447,13 +561,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = producerId, previousProducerId = producerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch, lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = Empty, state = Empty,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val result = txnMetadata.prepareIncrementProducerEpoch(30000, Some((lastProducerEpoch - 1).toShort), val result = txnMetadata.prepareIncrementProducerEpoch(30000, Some((lastProducerEpoch - 1).toShort),
time.milliseconds()) time.milliseconds())
@ -503,19 +619,21 @@ class TransactionMetadataTest {
assertEquals(Set.empty, unmatchedStates) assertEquals(Set.empty, unmatchedStates)
} }
private def testRotateProducerIdInOngoingState(state: TransactionState): Unit = { private def testRotateProducerIdInOngoingState(state: TransactionState, clientTransactionVersion: TransactionVersion): Unit = {
val producerEpoch = (Short.MaxValue - 1).toShort val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata( val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId, transactionalId = transactionalId,
producerId = producerId, producerId = producerId,
lastProducerId = producerId, previousProducerId = producerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch, producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000, txnTimeoutMs = 30000,
state = state, state = state,
topicPartitions = mutable.Set.empty, topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds()) txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = clientTransactionVersion)
val newProducerId = 9893L val newProducerId = 9893L
txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = false) txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = false)
} }

View File

@ -35,6 +35,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
import org.apache.kafka.common.requests.TransactionResult import org.apache.kafka.common.requests.TransactionResult
import org.apache.kafka.common.utils.MockTime import org.apache.kafka.common.utils.MockTime
import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion} import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion}
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.apache.kafka.coordinator.transaction.generated.TransactionLogKey import org.apache.kafka.coordinator.transaction.generated.TransactionLogKey
import org.apache.kafka.server.util.MockScheduler import org.apache.kafka.server.util.MockScheduler
import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchDataInfo, FetchIsolation, LogConfig, LogOffsetMetadata} import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchDataInfo, FetchIsolation, LogConfig, LogOffsetMetadata}
@ -181,7 +182,7 @@ class TransactionStateManagerTest {
new TopicPartition("topic1", 0), new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1))) new TopicPartition("topic1", 1)))
val records = MemoryRecords.withRecords(startOffset, Compression.NONE, val records = MemoryRecords.withRecords(startOffset, Compression.NONE,
new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true))) new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)))
// We create a latch which is awaited while the log is loading. This ensures that the deletion // We create a latch which is awaited while the log is loading. This ensures that the deletion
// is triggered before the loading returns // is triggered before the loading returns
@ -225,19 +226,19 @@ class TransactionStateManagerTest {
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1))) new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid1's transaction adds three more partitions // pid1's transaction adds three more partitions
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0), txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0),
new TopicPartition("topic2", 1), new TopicPartition("topic2", 1),
new TopicPartition("topic2", 2))) new TopicPartition("topic2", 2)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid1's transaction is preparing to commit // pid1's transaction is preparing to commit
txnMetadata1.state = PrepareCommit txnMetadata1.state = PrepareCommit
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid2's transaction started with three partitions // pid2's transaction started with three partitions
txnMetadata2.state = Ongoing txnMetadata2.state = Ongoing
@ -245,23 +246,23 @@ class TransactionStateManagerTest {
new TopicPartition("topic3", 1), new TopicPartition("topic3", 1),
new TopicPartition("topic3", 2))) new TopicPartition("topic3", 2)))
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's transaction is preparing to abort // pid2's transaction is preparing to abort
txnMetadata2.state = PrepareAbort txnMetadata2.state = PrepareAbort
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's transaction has aborted // pid2's transaction has aborted
txnMetadata2.state = CompleteAbort txnMetadata2.state = CompleteAbort
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's epoch has advanced, with no ongoing transaction yet // pid2's epoch has advanced, with no ongoing transaction yet
txnMetadata2.state = Empty txnMetadata2.state = Empty
txnMetadata2.topicPartitions.clear() txnMetadata2.topicPartitions.clear()
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
val startOffset = 15L // it should work for any start offset val startOffset = 15L // it should work for any start offset
val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*)
@ -796,14 +797,9 @@ class TransactionStateManagerTest {
// write the change. If the write fails (e.g. under min isr), the TransactionMetadata // write the change. If the write fails (e.g. under min isr), the TransactionMetadata
// is left at it is. If the transactional id is never reused, the TransactionMetadata // is left at it is. If the transactional id is never reused, the TransactionMetadata
// will be expired and it should succeed. // will be expired and it should succeed.
val txnMetadata = TransactionMetadata( val timestamp = time.milliseconds()
transactionalId = transactionalId, val txnMetadata = new TransactionMetadata(transactionalId, 1, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH,
producerId = 1, RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0)
producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = transactionTimeoutMs,
state = Empty,
timestamp = time.milliseconds()
)
transactionManager.putTransactionStateIfNotExists(txnMetadata) transactionManager.putTransactionStateIfNotExists(txnMetadata)
time.sleep(txnConfig.transactionalIdExpirationMs + 1) time.sleep(txnConfig.transactionalIdExpirationMs + 1)
@ -890,7 +886,7 @@ class TransactionStateManagerTest {
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1))) new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
val startOffset = 0L val startOffset = 0L
val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*)
@ -1053,7 +1049,7 @@ class TransactionStateManagerTest {
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1))) new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
val startOffset = 0L val startOffset = 0L
val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*)
@ -1081,7 +1077,9 @@ class TransactionStateManagerTest {
producerId: Long, producerId: Long,
state: TransactionState = Empty, state: TransactionState = Empty,
txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = { txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = {
TransactionMetadata(transactionalId, producerId, 0.toShort, txnTimeout, state, time.milliseconds()) val timestamp = time.milliseconds()
new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, 0.toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeout, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0)
} }
private def prepareTxnLog(topicPartition: TopicPartition, private def prepareTxnLog(topicPartition: TopicPartition,
@ -1159,7 +1157,7 @@ class TransactionStateManagerTest {
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 1), txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 1),
new TopicPartition("topic1", 1))) new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
val startOffset = 15L val startOffset = 15L
val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*) val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*)
@ -1178,7 +1176,7 @@ class TransactionStateManagerTest {
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1))) new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
val startOffset = 0L val startOffset = 0L
val unknownKey = new TransactionLogKey() val unknownKey = new TransactionLogKey()
@ -1199,7 +1197,7 @@ class TransactionStateManagerTest {
val txnMetadata = txnMetadataPool.get(transactionalId1) val txnMetadata = txnMetadataPool.get(transactionalId1)
assertEquals(txnMetadata1.transactionalId, txnMetadata.transactionalId) assertEquals(txnMetadata1.transactionalId, txnMetadata.transactionalId)
assertEquals(txnMetadata1.producerId, txnMetadata.producerId) assertEquals(txnMetadata1.producerId, txnMetadata.producerId)
assertEquals(txnMetadata1.lastProducerId, txnMetadata.lastProducerId) assertEquals(txnMetadata1.previousProducerId, txnMetadata.previousProducerId)
assertEquals(txnMetadata1.producerEpoch, txnMetadata.producerEpoch) assertEquals(txnMetadata1.producerEpoch, txnMetadata.producerEpoch)
assertEquals(txnMetadata1.lastProducerEpoch, txnMetadata.lastProducerEpoch) assertEquals(txnMetadata1.lastProducerEpoch, txnMetadata.lastProducerEpoch)
assertEquals(txnMetadata1.txnTimeoutMs, txnMetadata.txnTimeoutMs) assertEquals(txnMetadata1.txnTimeoutMs, txnMetadata.txnTimeoutMs)
@ -1210,7 +1208,7 @@ class TransactionStateManagerTest {
@ParameterizedTest @ParameterizedTest
@EnumSource(classOf[TransactionVersion]) @EnumSource(classOf[TransactionVersion])
def testUsesFlexibleRecords(transactionVersion: TransactionVersion): Unit = { def testTransactionVersionInTransactionManager(transactionVersion: TransactionVersion): Unit = {
val metadataCache = mock(classOf[MetadataCache]) val metadataCache = mock(classOf[MetadataCache])
when(metadataCache.features()).thenReturn { when(metadataCache.features()).thenReturn {
new FinalizedFeatures( new FinalizedFeatures(
@ -1223,7 +1221,6 @@ class TransactionStateManagerTest {
val transactionManager = new TransactionStateManager(0, scheduler, val transactionManager = new TransactionStateManager(0, scheduler,
replicaManager, metadataCache, txnConfig, time, metrics) replicaManager, metadataCache, txnConfig, time, metrics)
val expectFlexibleRecords = transactionVersion.featureLevel > 0 assertEquals(transactionVersion, transactionManager.transactionVersionLevel())
assertEquals(expectFlexibleRecords, transactionManager.usesFlexibleRecords())
} }
} }

View File

@ -85,7 +85,7 @@ import org.apache.kafka.security.authorizer.AclEntry
import org.apache.kafka.server.ClientMetricsManager import org.apache.kafka.server.ClientMetricsManager
import org.apache.kafka.server.authorizer.{Action, AuthorizationResult, Authorizer} import org.apache.kafka.server.authorizer.{Action, AuthorizationResult, Authorizer}
import org.apache.kafka.server.common.MetadataVersion.{IBP_0_10_2_IV0, IBP_2_2_IV1} import org.apache.kafka.server.common.MetadataVersion.{IBP_0_10_2_IV0, IBP_2_2_IV1}
import org.apache.kafka.server.common.{FeatureVersion, FinalizedFeatures, GroupVersion, KRaftVersion, MetadataVersion, RequestLocal} import org.apache.kafka.server.common.{FeatureVersion, FinalizedFeatures, GroupVersion, KRaftVersion, MetadataVersion, RequestLocal, TransactionVersion}
import org.apache.kafka.server.config.{ConfigType, KRaftConfigs, ReplicationConfigs, ServerConfigs, ServerLogConfigs, ShareGroupConfig} import org.apache.kafka.server.config.{ConfigType, KRaftConfigs, ReplicationConfigs, ServerConfigs, ServerLogConfigs, ShareGroupConfig}
import org.apache.kafka.server.metrics.ClientMetricsTestUtils import org.apache.kafka.server.metrics.ClientMetricsTestUtils
import org.apache.kafka.server.share.{CachedSharePartition, ErroneousAndValidPartitionData} import org.apache.kafka.server.share.{CachedSharePartition, ErroneousAndValidPartitionData}
@ -2572,7 +2572,7 @@ class KafkaApisTest extends Logging {
reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
val capturedResponse: ArgumentCaptor[EndTxnResponse] = ArgumentCaptor.forClass(classOf[EndTxnResponse]) val capturedResponse: ArgumentCaptor[EndTxnResponse] = ArgumentCaptor.forClass(classOf[EndTxnResponse])
val responseCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit]) val responseCallback: ArgumentCaptor[(Errors, Long, Short) => Unit] = ArgumentCaptor.forClass(classOf[(Errors, Long, Short) => Unit])
val transactionalId = "txnId" val transactionalId = "txnId"
val producerId = 15L val producerId = 15L
@ -2587,15 +2587,18 @@ class KafkaApisTest extends Logging {
).build(version.toShort) ).build(version.toShort)
val request = buildRequest(endTxnRequest) val request = buildRequest(endTxnRequest)
val clientTransactionVersion = if (version > 4) TransactionVersion.TV_2 else TransactionVersion.TV_0
val requestLocal = RequestLocal.withThreadConfinedCaching val requestLocal = RequestLocal.withThreadConfinedCaching
when(txnCoordinator.handleEndTransaction( when(txnCoordinator.handleEndTransaction(
ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(producerId), ArgumentMatchers.eq(producerId),
ArgumentMatchers.eq(epoch), ArgumentMatchers.eq(epoch),
ArgumentMatchers.eq(TransactionResult.COMMIT), ArgumentMatchers.eq(TransactionResult.COMMIT),
ArgumentMatchers.eq(clientTransactionVersion),
responseCallback.capture(), responseCallback.capture(),
ArgumentMatchers.eq(requestLocal) ArgumentMatchers.eq(requestLocal)
)).thenAnswer(_ => responseCallback.getValue.apply(Errors.PRODUCER_FENCED)) )).thenAnswer(_ => responseCallback.getValue.apply(Errors.PRODUCER_FENCED, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH))
val kafkaApis = createKafkaApis() val kafkaApis = createKafkaApis()
try { try {
kafkaApis.handleEndTxnRequest(request, requestLocal) kafkaApis.handleEndTxnRequest(request, requestLocal)

View File

@ -49,6 +49,10 @@ public enum TransactionVersion implements FeatureVersion {
return featureLevel; return featureLevel;
} }
public static TransactionVersion fromFeatureLevel(short version) {
return (TransactionVersion) Features.TRANSACTION_VERSION.fromFeatureLevel(version, true);
}
@Override @Override
public String featureName() { public String featureName() {
return FEATURE_NAME; return FEATURE_NAME;
@ -63,4 +67,14 @@ public enum TransactionVersion implements FeatureVersion {
public Map<String, Short> dependencies() { public Map<String, Short> dependencies() {
return dependencies; return dependencies;
} }
// Transactions V1 enables log version 0 (flexible fields)
public short transactionLogValueVersion() {
return (short) (featureLevel >= 1 ? 1 : 0);
}
// Transactions V2 enables epoch bump on commit/abort.
public boolean supportsEpochBump() {
return featureLevel >= 2;
}
} }

View File

@ -24,6 +24,10 @@
"fields": [ "fields": [
{ "name": "ProducerId", "type": "int64", "versions": "0+", { "name": "ProducerId", "type": "int64", "versions": "0+",
"about": "Producer id in use by the transactional id"}, "about": "Producer id in use by the transactional id"},
{ "name": "PreviousProducerId", "type": "int64", "taggedVersions": "1+", "tag": 0, "default": -1,
"about": "Producer id used by the last committed transaction"},
{ "name": "NextProducerId", "type": "int64", "taggedVersions": "1+", "tag": 1, "default": -1,
"about": "Latest producer ID sent to the producer for the given transactional ID"},
{ "name": "ProducerEpoch", "type": "int16", "versions": "0+", { "name": "ProducerEpoch", "type": "int16", "versions": "0+",
"about": "Epoch associated with the producer id"}, "about": "Epoch associated with the producer id"},
{ "name": "TransactionTimeoutMs", "type": "int32", "versions": "0+", { "name": "TransactionTimeoutMs", "type": "int32", "versions": "0+",
@ -37,6 +41,8 @@
{ "name": "TransactionLastUpdateTimestampMs", "type": "int64", "versions": "0+", { "name": "TransactionLastUpdateTimestampMs", "type": "int64", "versions": "0+",
"about": "Time the transaction was last updated"}, "about": "Time the transaction was last updated"},
{ "name": "TransactionStartTimestampMs", "type": "int64", "versions": "0+", { "name": "TransactionStartTimestampMs", "type": "int64", "versions": "0+",
"about": "Time the transaction was started"} "about": "Time the transaction was started"},
{ "name": "ClientTransactionVersion", "type": "int16", "default": 0, "taggedVersions": "1+", "tag": 2,
"about": "The transaction version used by the client"}
] ]
} }