KAFKA-19087 Move TransactionState to transaction-coordinator module (#19568)

Move TransactionState to transaction-coordinator module and rewrite it
as Java.

Reviewers: Chia-Ping Tsai <chia7712@gmail.com>
This commit is contained in:
PoAn Yang 2025-05-08 10:51:51 -05:00 committed by GitHub
parent 98e535b524
commit 9e785cee8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 480 additions and 447 deletions

View File

@ -27,7 +27,7 @@ import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult}
import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time}
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionLogConfig, TransactionStateManagerConfig}
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionLogConfig, TransactionState, TransactionStateManagerConfig}
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.{RequestLocal, TransactionVersion}
import org.apache.kafka.server.util.Scheduler
@ -154,7 +154,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = resolvedTxnTimeoutMs,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = collection.mutable.Set.empty[TopicPartition],
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TransactionVersion.TV_0)
@ -182,7 +182,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
responseCallback(initTransactionError(error))
case Right((coordinatorEpoch, newMetadata)) =>
if (newMetadata.txnState == PrepareEpochFence) {
if (newMetadata.txnState == TransactionState.PREPARE_EPOCH_FENCE) {
// abort the ongoing transaction and then return CONCURRENT_TRANSACTIONS to let client wait and retry
def sendRetriableErrorCallback(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = {
if (error != Errors.NONE) {
@ -249,11 +249,11 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
} else {
// caller should have synchronized on txnMetadata already
txnMetadata.state match {
case PrepareAbort | PrepareCommit =>
case TransactionState.PREPARE_ABORT | TransactionState.PREPARE_COMMIT =>
// reply to client and let it backoff and retry
Left(Errors.CONCURRENT_TRANSACTIONS)
case CompleteAbort | CompleteCommit | Empty =>
case TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT | TransactionState.EMPTY =>
val transitMetadataResult =
// If the epoch is exhausted and the expected epoch (if provided) matches it, generate a new producer ID
if (txnMetadata.isProducerEpochExhausted &&
@ -274,7 +274,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case Left(err) => Left(err)
}
case Ongoing =>
case TransactionState.ONGOING =>
// indicate to abort the current ongoing txn first. Note that this epoch is never returned to the
// user. We will abort the ongoing transaction and return CONCURRENT_TRANSACTIONS to the client.
// This forces the client to retry, which will ensure that the epoch is bumped a second time. In
@ -282,7 +282,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
// then when the client retries, we will generate a new producerId.
Right(coordinatorEpoch, txnMetadata.prepareFenceProducerEpoch())
case Dead | PrepareEpochFence =>
case TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE =>
val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " +
s"This is illegal as we should never have transitioned to this state."
fatal(errorMsg)
@ -327,7 +327,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case Right(Some(coordinatorEpochAndMetadata)) =>
val txnMetadata = coordinatorEpochAndMetadata.transactionMetadata
txnMetadata.inLock {
if (txnMetadata.state == Dead) {
if (txnMetadata.state == TransactionState.DEAD) {
// The transaction state is being expired, so ignore it
transactionState.setErrorCode(Errors.TRANSACTIONAL_ID_NOT_FOUND.code)
} else {
@ -345,7 +345,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
.setErrorCode(Errors.NONE.code)
.setProducerId(txnMetadata.producerId)
.setProducerEpoch(txnMetadata.producerEpoch)
.setTransactionState(txnMetadata.state.name)
.setTransactionState(txnMetadata.state.stateName)
.setTransactionTimeoutMs(txnMetadata.txnTimeoutMs)
.setTransactionStartTimeMs(txnMetadata.txnStartTimestamp)
}
@ -378,7 +378,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
Left(Errors.INVALID_PRODUCER_ID_MAPPING)
} else if (txnMetadata.producerEpoch != producerEpoch) {
Left(Errors.PRODUCER_FENCED)
} else if (txnMetadata.state == PrepareCommit || txnMetadata.state == PrepareAbort) {
} else if (txnMetadata.state == TransactionState.PREPARE_COMMIT || txnMetadata.state == TransactionState.PREPARE_ABORT) {
Left(Errors.CONCURRENT_TRANSACTIONS)
} else {
Right(partitions.map { part =>
@ -435,9 +435,9 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
Left(Errors.INVALID_PRODUCER_ID_MAPPING)
} else if (txnMetadata.producerEpoch != producerEpoch) {
Left(Errors.PRODUCER_FENCED)
} else if (txnMetadata.state == PrepareCommit || txnMetadata.state == PrepareAbort) {
} else if (txnMetadata.state == TransactionState.PREPARE_COMMIT || txnMetadata.state == TransactionState.PREPARE_ABORT) {
Left(Errors.CONCURRENT_TRANSACTIONS)
} else if (txnMetadata.state == Ongoing && partitions.subsetOf(txnMetadata.topicPartitions)) {
} else if (txnMetadata.state == TransactionState.ONGOING && partitions.subsetOf(txnMetadata.topicPartitions)) {
// this is an optimization: if the partitions are already in the metadata reply OK immediately
Left(Errors.NONE)
} else {
@ -555,16 +555,16 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
// Strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch.
else if ((isFromClient && producerEpoch != txnMetadata.producerEpoch) || producerEpoch < txnMetadata.producerEpoch)
Left(Errors.PRODUCER_FENCED)
else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence)
else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != TransactionState.PREPARE_EPOCH_FENCE)
Left(Errors.CONCURRENT_TRANSACTIONS)
else txnMetadata.state match {
case Ongoing =>
case TransactionState.ONGOING =>
val nextState = if (txnMarkerResult == TransactionResult.COMMIT)
PrepareCommit
TransactionState.PREPARE_COMMIT
else
PrepareAbort
TransactionState.PREPARE_ABORT
if (nextState == PrepareAbort && txnMetadata.pendingState.contains(PrepareEpochFence)) {
if (nextState == TransactionState.PREPARE_ABORT && txnMetadata.pendingState.contains(TransactionState.PREPARE_EPOCH_FENCE)) {
// We should clear the pending state to make way for the transition to PrepareAbort and also bump
// the epoch in the transaction metadata we are about to append.
isEpochFence = true
@ -574,29 +574,29 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
}
Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, TransactionVersion.fromFeatureLevel(0), RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false))
case CompleteCommit =>
case TransactionState.COMPLETE_COMMIT =>
if (txnMarkerResult == TransactionResult.COMMIT)
Left(Errors.NONE)
else
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
case CompleteAbort =>
case TransactionState.COMPLETE_ABORT =>
if (txnMarkerResult == TransactionResult.ABORT)
Left(Errors.NONE)
else
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
case PrepareCommit =>
case TransactionState.PREPARE_COMMIT =>
if (txnMarkerResult == TransactionResult.COMMIT)
Left(Errors.CONCURRENT_TRANSACTIONS)
else
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
case PrepareAbort =>
case TransactionState.PREPARE_ABORT =>
if (txnMarkerResult == TransactionResult.ABORT)
Left(Errors.CONCURRENT_TRANSACTIONS)
else
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
case Empty =>
case TransactionState.EMPTY =>
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
case Dead | PrepareEpochFence =>
case TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE =>
val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " +
s"This is illegal as we should never have transitioned to this state."
fatal(errorMsg)
@ -631,19 +631,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
else if (txnMetadata.pendingTransitionInProgress)
Left(Errors.CONCURRENT_TRANSACTIONS)
else txnMetadata.state match {
case Empty| Ongoing | CompleteCommit | CompleteAbort =>
case TransactionState.EMPTY| TransactionState.ONGOING | TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT =>
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
case PrepareCommit =>
case TransactionState.PREPARE_COMMIT =>
if (txnMarkerResult != TransactionResult.COMMIT)
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
else
Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
case PrepareAbort =>
case TransactionState.PREPARE_ABORT =>
if (txnMarkerResult != TransactionResult.ABORT)
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
else
Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
case Dead | PrepareEpochFence =>
case TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE =>
val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " +
s"This is illegal as we should never have transitioned to this state."
fatal(errorMsg)
@ -776,7 +776,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
producerEpochCopy = txnMetadata.producerEpoch
// PrepareEpochFence has slightly different epoch bumping logic so don't include it here.
// Note that, it can only happen when the current state is Ongoing.
isEpochFence = txnMetadata.pendingState.contains(PrepareEpochFence)
isEpochFence = txnMetadata.pendingState.contains(TransactionState.PREPARE_EPOCH_FENCE)
// True if the client retried a request that had overflowed the epoch, and a new producer ID is stored in the txnMetadata
val retryOnOverflow = !isEpochFence && txnMetadata.prevProducerId == producerId &&
producerEpoch == Short.MaxValue - 1 && txnMetadata.producerEpoch == 0
@ -790,11 +790,11 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
// Return producer fenced even in the cases where the epoch is higher and could indicate an invalid state transition.
// Use the following criteria to determine if a v2 retry is valid:
txnMetadata.state match {
case Ongoing | Empty | Dead | PrepareEpochFence =>
case TransactionState.ONGOING | TransactionState.EMPTY | TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE =>
producerEpoch == txnMetadata.producerEpoch
case PrepareCommit | PrepareAbort =>
case TransactionState.PREPARE_COMMIT | TransactionState.PREPARE_ABORT =>
retryOnEpochBump
case CompleteCommit | CompleteAbort =>
case TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT =>
retryOnEpochBump || retryOnOverflow || producerEpoch == txnMetadata.producerEpoch
}
} else {
@ -818,7 +818,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
Right(RecordBatch.NO_PRODUCER_ID)
}
if (nextState == PrepareAbort && isEpochFence) {
if (nextState == TransactionState.PREPARE_ABORT && isEpochFence) {
// We should clear the pending state to make way for the transition to PrepareAbort and also bump
// the epoch in the transaction metadata we are about to append.
txnMetadata.pendingState = None
@ -832,7 +832,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
}
}
if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence) {
if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != TransactionState.PREPARE_EPOCH_FENCE) {
// This check is performed first so that the pending transition can complete before the next checks.
// With TV2, we may be transitioning over a producer epoch overflow, and the producer may be using the
// new producer ID that is still only in pending state.
@ -842,14 +842,14 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
else if (!isValidEpoch)
Left(Errors.PRODUCER_FENCED)
else txnMetadata.state match {
case Ongoing =>
case TransactionState.ONGOING =>
val nextState = if (txnMarkerResult == TransactionResult.COMMIT)
PrepareCommit
TransactionState.PREPARE_COMMIT
else
PrepareAbort
TransactionState.PREPARE_ABORT
generateTxnTransitMetadataForTxnCompletion(nextState, false)
case CompleteCommit =>
case TransactionState.COMPLETE_COMMIT =>
if (txnMarkerResult == TransactionResult.COMMIT) {
if (isRetry)
Left(Errors.NONE)
@ -860,35 +860,35 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
if (isRetry)
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
else
generateTxnTransitMetadataForTxnCompletion(PrepareAbort, true)
generateTxnTransitMetadataForTxnCompletion(TransactionState.PREPARE_ABORT, true)
}
case CompleteAbort =>
case TransactionState.COMPLETE_ABORT =>
if (txnMarkerResult == TransactionResult.ABORT) {
if (isRetry)
Left(Errors.NONE)
else
generateTxnTransitMetadataForTxnCompletion(PrepareAbort, true)
generateTxnTransitMetadataForTxnCompletion(TransactionState.PREPARE_ABORT, true)
} else {
// Commit.
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
}
case PrepareCommit =>
case TransactionState.PREPARE_COMMIT =>
if (txnMarkerResult == TransactionResult.COMMIT)
Left(Errors.CONCURRENT_TRANSACTIONS)
else
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
case PrepareAbort =>
case TransactionState.PREPARE_ABORT =>
if (txnMarkerResult == TransactionResult.ABORT)
Left(Errors.CONCURRENT_TRANSACTIONS)
else
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
case Empty =>
case TransactionState.EMPTY =>
if (txnMarkerResult == TransactionResult.ABORT) {
generateTxnTransitMetadataForTxnCompletion(PrepareAbort, true)
generateTxnTransitMetadataForTxnCompletion(TransactionState.PREPARE_ABORT, true)
} else {
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
}
case Dead | PrepareEpochFence =>
case TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE =>
val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " +
s"This is illegal as we should never have transitioned to this state."
fatal(errorMsg)
@ -928,19 +928,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
else if (txnMetadata.pendingTransitionInProgress)
Left(Errors.CONCURRENT_TRANSACTIONS)
else txnMetadata.state match {
case Empty| Ongoing | CompleteCommit | CompleteAbort =>
case TransactionState.EMPTY | TransactionState.ONGOING | TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT =>
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
case PrepareCommit =>
case TransactionState.PREPARE_COMMIT =>
if (txnMarkerResult != TransactionResult.COMMIT)
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
else
Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
case PrepareAbort =>
case TransactionState.PREPARE_ABORT =>
if (txnMarkerResult != TransactionResult.ABORT)
logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult)
else
Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
case Dead | PrepareEpochFence =>
case TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE =>
val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " +
s"This is illegal as we should never have transitioned to this state."
fatal(errorMsg)

View File

@ -21,6 +21,7 @@ import org.apache.kafka.common.compress.Compression
import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil}
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.coordinator.transaction.TransactionState
import org.apache.kafka.coordinator.transaction.generated.{CoordinatorRecordType, TransactionLogKey, TransactionLogValue}
import org.apache.kafka.server.common.TransactionVersion
@ -61,10 +62,10 @@ object TransactionLog {
*/
private[transaction] def valueToBytes(txnMetadata: TxnTransitMetadata,
transactionVersionLevel: TransactionVersion): Array[Byte] = {
if (txnMetadata.txnState == Empty && txnMetadata.topicPartitions.nonEmpty)
if (txnMetadata.txnState == TransactionState.EMPTY && txnMetadata.topicPartitions.nonEmpty)
throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata")
val transactionPartitions = if (txnMetadata.txnState == Empty) null
val transactionPartitions = if (txnMetadata.txnState == TransactionState.EMPTY) null
else txnMetadata.topicPartitions
.groupBy(_.topic)
.map { case (topic, partitions) =>
@ -127,7 +128,7 @@ object TransactionLog {
txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs,
clientTransactionVersion = TransactionVersion.fromFeatureLevel(value.clientTransactionVersion))
if (!transactionMetadata.state.equals(Empty))
if (!transactionMetadata.state.equals(TransactionState.EMPTY))
value.transactionPartitions.forEach(partitionsSchema =>
transactionMetadata.addPartitions(partitionsSchema.partitionIds
.asScala

View File

@ -21,151 +21,11 @@ import kafka.utils.{CoreUtils, Logging, nonthreadsafe}
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.coordinator.transaction.TransactionState
import org.apache.kafka.server.common.TransactionVersion
import scala.collection.{immutable, mutable}
object TransactionState {
val AllStates: Set[TransactionState] = Set(
Empty,
Ongoing,
PrepareCommit,
PrepareAbort,
CompleteCommit,
CompleteAbort,
Dead,
PrepareEpochFence
)
def fromName(name: String): Option[TransactionState] = {
AllStates.find(_.name == name)
}
def fromId(id: Byte): TransactionState = {
id match {
case 0 => Empty
case 1 => Ongoing
case 2 => PrepareCommit
case 3 => PrepareAbort
case 4 => CompleteCommit
case 5 => CompleteAbort
case 6 => Dead
case 7 => PrepareEpochFence
case _ => throw new IllegalStateException(s"Unknown transaction state id $id from the transaction status message")
}
}
}
private[transaction] sealed trait TransactionState {
def id: Byte
/**
* Get the name of this state. This is exposed through the `DescribeTransactions` API.
*/
def name: String
def validPreviousStates: Set[TransactionState]
def isExpirationAllowed: Boolean = false
}
/**
* Transaction has not existed yet
*
* transition: received AddPartitionsToTxnRequest => Ongoing
* received AddOffsetsToTxnRequest => Ongoing
* received EndTxnRequest with abort and TransactionV2 enabled => PrepareAbort
*/
private[transaction] case object Empty extends TransactionState {
val id: Byte = 0
val name: String = "Empty"
val validPreviousStates: Set[TransactionState] = Set(Empty, CompleteCommit, CompleteAbort)
override def isExpirationAllowed: Boolean = true
}
/**
* Transaction has started and ongoing
*
* transition: received EndTxnRequest with commit => PrepareCommit
* received EndTxnRequest with abort => PrepareAbort
* received AddPartitionsToTxnRequest => Ongoing
* received AddOffsetsToTxnRequest => Ongoing
*/
private[transaction] case object Ongoing extends TransactionState {
val id: Byte = 1
val name: String = "Ongoing"
val validPreviousStates: Set[TransactionState] = Set(Ongoing, Empty, CompleteCommit, CompleteAbort)
}
/**
* Group is preparing to commit
*
* transition: received acks from all partitions => CompleteCommit
*/
private[transaction] case object PrepareCommit extends TransactionState {
val id: Byte = 2
val name: String = "PrepareCommit"
val validPreviousStates: Set[TransactionState] = Set(Ongoing)
}
/**
* Group is preparing to abort
*
* transition: received acks from all partitions => CompleteAbort
*
* Note, In transaction v2, we allow Empty, CompleteCommit, CompleteAbort to transition to PrepareAbort. because the
* client may not know the txn state on the server side, it needs to send endTxn request when uncertain.
*/
private[transaction] case object PrepareAbort extends TransactionState {
val id: Byte = 3
val name: String = "PrepareAbort"
val validPreviousStates: Set[TransactionState] = Set(Ongoing, PrepareEpochFence, Empty, CompleteCommit, CompleteAbort)
}
/**
* Group has completed commit
*
* Will soon be removed from the ongoing transaction cache
*/
private[transaction] case object CompleteCommit extends TransactionState {
val id: Byte = 4
val name: String = "CompleteCommit"
val validPreviousStates: Set[TransactionState] = Set(PrepareCommit)
override def isExpirationAllowed: Boolean = true
}
/**
* Group has completed abort
*
* Will soon be removed from the ongoing transaction cache
*/
private[transaction] case object CompleteAbort extends TransactionState {
val id: Byte = 5
val name: String = "CompleteAbort"
val validPreviousStates: Set[TransactionState] = Set(PrepareAbort)
override def isExpirationAllowed: Boolean = true
}
/**
* TransactionalId has expired and is about to be removed from the transaction cache
*/
private[transaction] case object Dead extends TransactionState {
val id: Byte = 6
val name: String = "Dead"
val validPreviousStates: Set[TransactionState] = Set(Empty, CompleteAbort, CompleteCommit)
}
/**
* We are in the middle of bumping the epoch and fencing out older producers.
*/
private[transaction] case object PrepareEpochFence extends TransactionState {
val id: Byte = 7
val name: String = "PrepareEpochFence"
val validPreviousStates: Set[TransactionState] = Set(Ongoing)
}
private[transaction] object TransactionMetadata {
def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1
}
@ -244,7 +104,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
}
def removePartition(topicPartition: TopicPartition): Unit = {
if (state != PrepareCommit && state != PrepareAbort)
if (state != TransactionState.PREPARE_COMMIT && state != TransactionState.PREPARE_ABORT)
throw new IllegalStateException(s"Transaction metadata's current state is $state, and its pending state is $pendingState " +
s"while trying to remove partitions whose txn marker has been sent, this is not expected")
@ -267,7 +127,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
val bumpedEpoch = if (hasFailedEpochFence) producerEpoch else (producerEpoch + 1).toShort
prepareTransitionTo(
state = PrepareEpochFence,
state = TransactionState.PREPARE_EPOCH_FENCE,
producerEpoch = bumpedEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH
)
@ -309,7 +169,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
epochBumpResult match {
case Right((nextEpoch, lastEpoch)) => Right(prepareTransitionTo(
state = Empty,
state = TransactionState.EMPTY,
producerEpoch = nextEpoch,
lastProducerEpoch = lastEpoch,
txnTimeoutMs = newTxnTimeoutMs,
@ -330,7 +190,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
throw new IllegalStateException("Cannot rotate producer ids while a transaction is still pending")
prepareTransitionTo(
state = Empty,
state = TransactionState.EMPTY,
producerId = newProducerId,
producerEpoch = 0,
lastProducerEpoch = if (recordLastEpoch) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH,
@ -343,12 +203,12 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition], updateTimestamp: Long, clientTransactionVersion: TransactionVersion): TxnTransitMetadata = {
val newTxnStartTimestamp = state match {
case Empty | CompleteAbort | CompleteCommit => updateTimestamp
case TransactionState.EMPTY | TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT => updateTimestamp
case _ => txnStartTimestamp
}
prepareTransitionTo(
state = Ongoing,
state = TransactionState.ONGOING,
topicPartitions = (topicPartitions ++ addedTopicPartitions),
txnStartTimestamp = newTxnStartTimestamp,
txnLastUpdateTimestamp = updateTimestamp,
@ -379,7 +239,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
}
def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = {
val newState = if (state == PrepareCommit) CompleteCommit else CompleteAbort
val newState = if (state == TransactionState.PREPARE_COMMIT) TransactionState.COMPLETE_COMMIT else TransactionState.COMPLETE_ABORT
// Since the state change was successfully written to the log, unset the flag for a failed epoch fence
hasFailedEpochFence = false
@ -408,7 +268,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
def prepareDead(): TxnTransitMetadata = {
prepareTransitionTo(
state = Dead,
state = TransactionState.DEAD,
topicPartitions = mutable.Set.empty[TopicPartition]
)
}
@ -427,7 +287,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
private def hasPendingTransaction: Boolean = {
state match {
case Ongoing | PrepareAbort | PrepareCommit => true
case TransactionState.ONGOING | TransactionState.PREPARE_ABORT | TransactionState.PREPARE_COMMIT => true
case _ => false
}
}
@ -452,7 +312,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
// The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata
// is created for the first time and it could stay like this until transitioning
// to Dead.
if (state != Dead && producerEpoch < 0)
if (state != TransactionState.DEAD && producerEpoch < 0)
throw new IllegalArgumentException(s"Illegal new producer epoch $producerEpoch")
// check that the new state transition is valid and update the pending state if necessary
@ -492,7 +352,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
throwStateTransitionFailure(transitMetadata)
} else {
toState match {
case Empty => // from initPid
case TransactionState.EMPTY => // from initPid
if ((producerEpoch != transitMetadata.producerEpoch && !validProducerEpochBump(transitMetadata)) ||
transitMetadata.topicPartitions.nonEmpty ||
transitMetadata.txnStartTimestamp != -1) {
@ -500,7 +360,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
throwStateTransitionFailure(transitMetadata)
}
case Ongoing => // from addPartitions
case TransactionState.ONGOING => // from addPartitions
if (!validProducerEpoch(transitMetadata) ||
!topicPartitions.subsetOf(transitMetadata.topicPartitions) ||
txnTimeoutMs != transitMetadata.txnTimeoutMs) {
@ -508,11 +368,11 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
throwStateTransitionFailure(transitMetadata)
}
case PrepareAbort | PrepareCommit => // from endTxn
case TransactionState.PREPARE_ABORT | TransactionState.PREPARE_COMMIT => // from endTxn
// In V2, we allow state transits from Empty, CompleteCommit and CompleteAbort to PrepareAbort. It is possible
// their updated start time is not equal to the current start time.
val allowedEmptyAbort = toState == PrepareAbort && transitMetadata.clientTransactionVersion.supportsEpochBump() &&
(state == Empty || state == CompleteCommit || state == CompleteAbort)
val allowedEmptyAbort = toState == TransactionState.PREPARE_ABORT && transitMetadata.clientTransactionVersion.supportsEpochBump() &&
(state == TransactionState.EMPTY || state == TransactionState.COMPLETE_COMMIT || state == TransactionState.COMPLETE_ABORT)
val validTimestamp = txnStartTimestamp == transitMetadata.txnStartTimestamp || allowedEmptyAbort
if (!validProducerEpoch(transitMetadata) ||
!topicPartitions.equals(transitMetadata.topicPartitions) ||
@ -521,14 +381,14 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
throwStateTransitionFailure(transitMetadata)
}
case CompleteAbort | CompleteCommit => // from write markers
case TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT => // from write markers
if (!validProducerEpoch(transitMetadata) ||
txnTimeoutMs != transitMetadata.txnTimeoutMs ||
transitMetadata.txnStartTimestamp == -1) {
throwStateTransitionFailure(transitMetadata)
}
case PrepareEpochFence =>
case TransactionState.PREPARE_EPOCH_FENCE =>
// We should never get here, since once we prepare to fence the epoch, we immediately set the pending state
// to PrepareAbort, and then consequently to CompleteAbort after the markers are written.. So we should never
// ever try to complete a transition to PrepareEpochFence, as it is not a valid previous state for any other state, and hence
@ -536,7 +396,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
throwStateTransitionFailure(transitMetadata)
case Dead =>
case TransactionState.DEAD =>
// The transactionalId was being expired. The completion of the operation should result in removal of the
// the metadata from the cache, so we should never realistically transition to the dead state.
throw new IllegalStateException(s"TransactionalId $transactionalId is trying to complete a transition to " +
@ -590,11 +450,11 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
val transitLastProducerEpoch = transitMetadata.lastProducerEpoch
(isAtLeastTransactionsV2, txnState, transitProducerEpoch) match {
case (true, CompleteCommit | CompleteAbort, epoch) if epoch == 0.toShort =>
case (true, TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT, epoch) if epoch == 0.toShort =>
transitLastProducerEpoch == lastProducerEpoch &&
transitMetadata.prevProducerId == producerId
case (true, PrepareCommit | PrepareAbort, _) =>
case (true, TransactionState.PREPARE_COMMIT | TransactionState.PREPARE_ABORT, _) =>
transitLastProducerEpoch == producerEpoch &&
transitProducerId == producerId

View File

@ -35,7 +35,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
import org.apache.kafka.common.requests.TransactionResult
import org.apache.kafka.common.utils.{Time, Utils}
import org.apache.kafka.common.{KafkaException, TopicIdPartition, TopicPartition}
import org.apache.kafka.coordinator.transaction.{TransactionLogConfig, TransactionStateManagerConfig}
import org.apache.kafka.coordinator.transaction.{TransactionLogConfig, TransactionState, TransactionStateManagerConfig}
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.{RequestLocal, TransactionVersion}
import org.apache.kafka.server.config.ServerConfigs
@ -134,7 +134,7 @@ class TransactionStateManager(brokerId: Int,
false
} else {
txnMetadata.state match {
case Ongoing =>
case TransactionState.ONGOING =>
// Do not apply timeout to distributed two phase commit transactions.
(!txnMetadata.isDistributedTwoPhaseCommitTxn) &&
(txnMetadata.txnStartTimestamp + txnMetadata.txnTimeoutMs < now)
@ -265,7 +265,7 @@ class TransactionStateManager(brokerId: Int,
val txnMetadata = txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId)
txnMetadata.inLock {
if (txnMetadataCacheEntry.coordinatorEpoch == idCoordinatorEpochAndMetadata.coordinatorEpoch
&& txnMetadata.pendingState.contains(Dead)
&& txnMetadata.pendingState.contains(TransactionState.DEAD)
&& txnMetadata.producerEpoch == idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
&& response.error == Errors.NONE) {
txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId)
@ -328,15 +328,15 @@ class TransactionStateManager(brokerId: Int,
} else {
val filterStates = mutable.Set.empty[TransactionState]
filterStateNames.foreach { stateName =>
TransactionState.fromName(stateName) match {
case Some(state) => filterStates += state
case None => response.unknownStateFilters.add(stateName)
}
TransactionState.fromName(stateName).ifPresentOrElse(
state => filterStates += state,
() => response.unknownStateFilters.add(stateName)
)
}
val now : Long = time.milliseconds()
def shouldInclude(txnMetadata: TransactionMetadata, pattern: Pattern): Boolean = {
if (txnMetadata.state == Dead) {
if (txnMetadata.state == TransactionState.DEAD) {
// We filter the `Dead` state since it is a transient state which
// indicates that the transactionalId and its metadata are in the
// process of expiration and removal.
@ -371,7 +371,7 @@ class TransactionStateManager(brokerId: Int,
states.add(new ListTransactionsResponseData.TransactionState()
.setTransactionalId(txnMetadata.transactionalId)
.setProducerId(txnMetadata.producerId)
.setTransactionState(txnMetadata.state.name)
.setTransactionState(txnMetadata.state.stateName)
)
}
}
@ -568,10 +568,10 @@ class TransactionStateManager(brokerId: Int,
txnMetadata.inLock {
// if state is PrepareCommit or PrepareAbort we need to complete the transaction
txnMetadata.state match {
case PrepareAbort =>
case TransactionState.PREPARE_ABORT =>
transactionsPendingForCompletion +=
TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.ABORT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
case PrepareCommit =>
case TransactionState.PREPARE_COMMIT =>
transactionsPendingForCompletion +=
TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.COMMIT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
case _ =>

View File

@ -28,8 +28,8 @@ import kafka.server.KafkaConfig
import kafka.utils.TestUtils
import org.apache.kafka.clients.{ClientResponse, NetworkClient}
import org.apache.kafka.common.internals.Topic
import org.apache.kafka.common.compress.Compression
import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME
import org.apache.kafka.common.compress.Compression
import org.apache.kafka.common.metrics.Metrics
import org.apache.kafka.common.network.ListenerName
import org.apache.kafka.common.protocol.{ApiKeys, Errors}
@ -37,7 +37,7 @@ import org.apache.kafka.common.record.{FileRecords, MemoryRecords, RecordBatch,
import org.apache.kafka.common.requests._
import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch}
import org.apache.kafka.common.{Node, TopicPartition, Uuid}
import org.apache.kafka.coordinator.transaction.ProducerIdManager
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionState}
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion}
import org.apache.kafka.server.storage.log.FetchIsolation
@ -468,7 +468,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn"))
txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2))
txnMetadata.state = PrepareCommit
txnMetadata.state = TransactionState.PREPARE_COMMIT
txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2))
prepareTxnLog(partitionId)
@ -513,7 +513,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
producerEpoch = (Short.MaxValue - 1).toShort,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 60000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = collection.mutable.Set.empty[TopicPartition],
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TransactionVersion.TV_0)
@ -544,7 +544,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
override def awaitAndVerify(txn: Transaction): Unit = {
val initPidResult = result.getOrElse(throw new IllegalStateException("InitProducerId has not completed"))
assertEquals(Errors.NONE, initPidResult.error)
verifyTransaction(txn, Empty)
verifyTransaction(txn, TransactionState.EMPTY)
}
}
@ -564,7 +564,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
override def awaitAndVerify(txn: Transaction): Unit = {
val error = result.getOrElse(throw new IllegalStateException("AddPartitionsToTransaction has not completed"))
assertEquals(Errors.NONE, error)
verifyTransaction(txn, Ongoing)
verifyTransaction(txn, TransactionState.ONGOING)
}
}
@ -585,7 +585,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
if (!txn.ended) {
txn.ended = true
assertEquals(Errors.NONE, error)
val expectedState = if (transactionResult(txn) == TransactionResult.COMMIT) CompleteCommit else CompleteAbort
val expectedState = if (transactionResult(txn) == TransactionResult.COMMIT) TransactionState.COMPLETE_COMMIT else TransactionState.COMPLETE_ABORT
verifyTransaction(txn, expectedState)
} else
assertEquals(Errors.INVALID_TXN_STATE, error)
@ -606,7 +606,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
override def await(): Unit = {
allTransactions.foreach { txn =>
if (txnStateManager.partitionFor(txn.transactionalId) == txnTopicPartitionId) {
verifyTransaction(txn, CompleteCommit)
verifyTransaction(txn, TransactionState.COMPLETE_COMMIT)
}
}
}

View File

@ -22,7 +22,7 @@ import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult}
import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch}
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionStateManagerConfig}
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionState, TransactionStateManagerConfig}
import org.apache.kafka.server.common.TransactionVersion
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.apache.kafka.server.util.MockScheduler
@ -197,7 +197,7 @@ class TransactionCoordinatorTest {
initPidGenericMocks(transactionalId)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0)
(Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.EMPTY, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -229,11 +229,11 @@ class TransactionCoordinatorTest {
initPidGenericMocks(transactionalId)
val txnMetadata1 = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2)
// We start with txnMetadata1 so we can transform the metadata to PrepareCommit.
(Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2)
// We start with txnMetadata1 so we can transform the metadata to TransactionState.PREPARE_COMMIT.
val txnMetadata2 = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2)
val transitMetadata = txnMetadata2.prepareAbortOrCommit(PrepareCommit, TV_2, producerId2, time.milliseconds(), false)
(Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2)
val transitMetadata = txnMetadata2.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, producerId2, time.milliseconds(), false)
txnMetadata2.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata2.producerId)
@ -342,7 +342,7 @@ class TransactionCoordinatorTest {
}
// If producer ID is not the same, return INVALID_PRODUCER_ID_MAPPING
val wrongPidTxnMetadata = new TransactionMetadata(transactionalId, 1, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0)
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.PREPARE_COMMIT, partitions, 0, 0, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, wrongPidTxnMetadata))))
@ -353,7 +353,7 @@ class TransactionCoordinatorTest {
// If producer epoch is not equal, return PRODUCER_FENCED
val oldEpochTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0)
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.PREPARE_COMMIT, partitions, 0, 0, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, oldEpochTxnMetadata))))
@ -364,7 +364,7 @@ class TransactionCoordinatorTest {
// If the txn state is Prepare or AbortCommit, we return CONCURRENT_TRANSACTIONS
val emptyTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0)
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.PREPARE_COMMIT, partitions, 0, 0, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, emptyTxnMetadata))))
@ -375,8 +375,8 @@ class TransactionCoordinatorTest {
// Pending state does not matter, we will just check if the partitions are in the txnMetadata.
val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, mutable.Set.empty, 0, 0, TV_0)
ongoingTxnMetadata.pendingState = Some(CompleteCommit)
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, mutable.Set.empty, 0, 0, TV_0)
ongoingTxnMetadata.pendingState = Some(TransactionState.COMPLETE_COMMIT)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, ongoingTxnMetadata))))
@ -388,16 +388,16 @@ class TransactionCoordinatorTest {
@Test
def shouldRespondWithConcurrentTransactionsOnAddPartitionsWhenStateIsPrepareCommit(): Unit = {
validateConcurrentTransactions(PrepareCommit)
validateConcurrentTransactions(TransactionState.PREPARE_COMMIT)
}
@Test
def shouldRespondWithConcurrentTransactionOnAddPartitionsWhenStateIsPrepareAbort(): Unit = {
validateConcurrentTransactions(PrepareAbort)
validateConcurrentTransactions(TransactionState.PREPARE_ABORT)
}
def validateConcurrentTransactions(state: TransactionState): Unit = {
// Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort.
// Since the clientTransactionVersion doesn't matter, use 2 since the states are TransactionState.PREPARE_COMMIT and TransactionState.PREPARE_ABORT.
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
@ -409,11 +409,11 @@ class TransactionCoordinatorTest {
@Test
def shouldRespondWithProducerFencedOnAddPartitionsWhenEpochsAreDifferent(): Unit = {
// Since the clientTransactionVersion doesn't matter, use 2 since the state is PrepareCommit.
// Since the clientTransactionVersion doesn't matter, use 2 since the state is TransactionState.PREPARE_COMMIT.
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0, TV_2)))))
10, 9, 0, TransactionState.PREPARE_COMMIT, mutable.Set.empty, 0, 0, TV_2)))))
coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_2)
assertEquals(Errors.PRODUCER_FENCED, error)
@ -421,24 +421,24 @@ class TransactionCoordinatorTest {
@Test
def shouldAppendNewMetadataToLogOnAddPartitionsWhenPartitionsAdded(): Unit = {
validateSuccessfulAddPartitions(Empty, 0)
validateSuccessfulAddPartitions(TransactionState.EMPTY, 0)
}
@Test
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsOngoing(): Unit = {
validateSuccessfulAddPartitions(Ongoing, 0)
validateSuccessfulAddPartitions(TransactionState.ONGOING, 0)
}
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(clientTransactionVersion: Short): Unit = {
validateSuccessfulAddPartitions(CompleteCommit, clientTransactionVersion)
validateSuccessfulAddPartitions(TransactionState.COMPLETE_COMMIT, clientTransactionVersion)
}
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(clientTransactionVersion: Short): Unit = {
validateSuccessfulAddPartitions(CompleteAbort, clientTransactionVersion)
validateSuccessfulAddPartitions(TransactionState.COMPLETE_ABORT, clientTransactionVersion)
}
def validateSuccessfulAddPartitions(previousState: TransactionState, transactionVersion: Short): Unit = {
@ -467,7 +467,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0)))))
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.EMPTY, partitions, 0, 0, TV_0)))))
coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_0)
assertEquals(Errors.NONE, error)
@ -484,7 +484,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0, TV_0)))))
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, partitions, 0, 0, TV_0)))))
coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, partitions, verifyPartitionsInTxnCallback)
errors.foreach { case (_, error) =>
@ -503,7 +503,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0)))))
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.EMPTY, partitions, 0, 0, TV_0)))))
val extraPartitions = partitions ++ Set(new TopicPartition("topic2", 0))
@ -532,7 +532,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 10, 10, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error)
@ -546,7 +546,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
(producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
(producerEpoch - 1).toShort, 1, TransactionState.ONGOING, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
@ -560,7 +560,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
(producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
(producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
val epoch = if (isRetry) producerEpoch - 1 else producerEpoch
coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
@ -587,7 +587,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
(producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
(producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
val epoch = if (isRetry) producerEpoch - 1 else producerEpoch
coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
@ -604,7 +604,7 @@ class TransactionCoordinatorTest {
def testEndTxnWhenStatusIsCompleteAbortAndResultIsAbortInV1(isRetry: Boolean): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -623,7 +623,7 @@ class TransactionCoordinatorTest {
def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbortInV2(isRetry: Boolean): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -660,7 +660,7 @@ class TransactionCoordinatorTest {
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -673,7 +673,7 @@ class TransactionCoordinatorTest {
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort,1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -687,7 +687,7 @@ class TransactionCoordinatorTest {
def testEndTxnRequestWhenStatusIsCompleteCommitAndResultIsAbortInV1(isRetry: Boolean): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -706,7 +706,7 @@ class TransactionCoordinatorTest {
def testEndTxnRequestWhenStatusIsCompleteCommitAndResultIsAbortInV2(isRetry: Boolean): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -737,7 +737,7 @@ class TransactionCoordinatorTest {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
@ -750,7 +750,7 @@ class TransactionCoordinatorTest {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_TXN_STATE, error)
@ -762,7 +762,7 @@ class TransactionCoordinatorTest {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, Empty, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_TXN_STATE, error)
@ -775,7 +775,7 @@ class TransactionCoordinatorTest {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, Empty, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
val epoch = if (isRetry) producerEpoch - 1 else producerEpoch
coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
@ -804,7 +804,7 @@ class TransactionCoordinatorTest {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, Empty, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
val epoch = if (isRetry) producerEpoch - 1 else producerEpoch
coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
@ -820,7 +820,7 @@ class TransactionCoordinatorTest {
def shouldReturnInvalidTxnRequestOnEndTxnV2IfNotEndTxnV2Retry(): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
// If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED.
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
@ -829,7 +829,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
// If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return INVALID_TXN_STATE.
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
@ -841,7 +841,7 @@ class TransactionCoordinatorTest {
def shouldReturnOkOnEndTxnV2IfEndTxnV2RetryEpochOverflow(): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
// Return CONCURRENT_TRANSACTIONS while transaction is still completing
coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback)
@ -850,7 +850,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId2, producerId,
RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.NONE, error)
@ -863,7 +863,7 @@ class TransactionCoordinatorTest {
@Test
def shouldReturnConcurrentTxnOnAddPartitionsIfEndTxnV2EpochOverflowAndNotComplete(): Unit = {
val prepareWithPending = new TransactionMetadata(transactionalId, producerId, producerId,
producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)
producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)
val txnTransitMetadata = prepareWithPending.prepareComplete(time.milliseconds())
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
@ -875,7 +875,7 @@ class TransactionCoordinatorTest {
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
prepareWithPending.completeTransitionTo(txnTransitMetadata)
assertEquals(CompleteCommit, prepareWithPending.state)
assertEquals(TransactionState.COMPLETE_COMMIT, prepareWithPending.state)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, prepareWithPending))))
when(transactionManager.appendTransactionToLog(
@ -897,7 +897,7 @@ class TransactionCoordinatorTest {
@ValueSource(shorts = Array(0, 2))
def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
mockPrepare(PrepareCommit, clientTransactionVersion)
mockPrepare(TransactionState.PREPARE_COMMIT, clientTransactionVersion)
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
@ -914,7 +914,7 @@ class TransactionCoordinatorTest {
@ValueSource(shorts = Array(0, 2))
def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
mockPrepare(PrepareAbort, clientTransactionVersion)
mockPrepare(TransactionState.PREPARE_ABORT, clientTransactionVersion)
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
@ -989,7 +989,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, metadataEpoch, 1,
1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
@ -998,29 +998,29 @@ class TransactionCoordinatorTest {
@Test
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingEmptyTransaction(): Unit = {
validateIncrementEpochAndUpdateMetadata(Empty, 0)
validateIncrementEpochAndUpdateMetadata(TransactionState.EMPTY, 0)
}
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(clientTransactionVersion: Short): Unit = {
validateIncrementEpochAndUpdateMetadata(CompleteAbort, clientTransactionVersion)
validateIncrementEpochAndUpdateMetadata(TransactionState.COMPLETE_ABORT, clientTransactionVersion)
}
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(clientTransactionVersion: Short): Unit = {
validateIncrementEpochAndUpdateMetadata(CompleteCommit, clientTransactionVersion)
validateIncrementEpochAndUpdateMetadata(TransactionState.COMPLETE_COMMIT, clientTransactionVersion)
}
@Test
def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareCommitState(): Unit = {
validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareCommit)
validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(TransactionState.PREPARE_COMMIT)
}
@Test
def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareAbortState(): Unit = {
validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareAbort)
validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(TransactionState.PREPARE_ABORT)
}
@ParameterizedTest(name = "enableTwoPCFlag={0}, keepPreparedTxn={1}")
@ -1030,7 +1030,7 @@ class TransactionCoordinatorTest {
keepPreparedTxn: Boolean
): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
.thenReturn(true)
@ -1041,7 +1041,7 @@ class TransactionCoordinatorTest {
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
(producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.appendTransactionToLog(
ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch),
@ -1066,7 +1066,7 @@ class TransactionCoordinatorTest {
verify(transactionManager).appendTransactionToLog(
ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch),
ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)),
ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)),
any(),
any(),
any())
@ -1075,13 +1075,13 @@ class TransactionCoordinatorTest {
@Test
def shouldFailToAbortTransactionOnHandleInitPidWhenProducerEpochIsSmaller(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
.thenReturn(true)
val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch + 2).toShort, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
(producerEpoch + 2).toShort, (producerEpoch - 1).toShort, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -1106,7 +1106,7 @@ class TransactionCoordinatorTest {
@Test
def shouldNotRepeatedlyBumpEpochDueToInitPidDuringOngoingTxnIfAppendToLogFails(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
.thenReturn(true)
@ -1120,8 +1120,8 @@ class TransactionCoordinatorTest {
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)
(producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0)
val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)
when(transactionManager.appendTransactionToLog(
ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch),
@ -1198,14 +1198,14 @@ class TransactionCoordinatorTest {
@Test
def shouldUseLastEpochToFenceWhenEpochsAreExhausted(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
(Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
.thenReturn(true)
val postFenceTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
Short.MaxValue, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds(), TV_0)
Short.MaxValue, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -1224,7 +1224,7 @@ class TransactionCoordinatorTest {
producerEpoch = Short.MaxValue,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = txnTimeoutMs,
txnState = PrepareAbort,
txnState = TransactionState.PREPARE_ABORT,
topicPartitions = partitions.clone,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
@ -1257,7 +1257,7 @@ class TransactionCoordinatorTest {
producerEpoch = Short.MaxValue,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = txnTimeoutMs,
txnState = PrepareAbort,
txnState = TransactionState.PREPARE_ABORT,
topicPartitions = partitions.clone,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
@ -1272,7 +1272,7 @@ class TransactionCoordinatorTest {
// If the metadata doesn't include the previous producer data (for example, if it was written to the log by a broker
// on an old version), the retry case should fail
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
.thenReturn(true)
@ -1295,7 +1295,7 @@ class TransactionCoordinatorTest {
def testFenceProducerWhenMappingExistsWithDifferentProducerId(): Unit = {
// Existing transaction ID maps to new producer ID
val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId,
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
.thenReturn(true)
@ -1319,7 +1319,7 @@ class TransactionCoordinatorTest {
mockPidGenerator()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
.thenReturn(true)
@ -1366,7 +1366,7 @@ class TransactionCoordinatorTest {
mockPidGenerator()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
.thenReturn(true)
@ -1415,7 +1415,7 @@ class TransactionCoordinatorTest {
def testRetryInitProducerIdAfterProducerIdRotation(): Unit = {
// Existing transaction ID maps to new producer ID
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0)
when(pidGenerator.generateProducerId())
.thenReturn(producerId + 1)
@ -1468,7 +1468,7 @@ class TransactionCoordinatorTest {
def testInitProducerIdWithInvalidEpochAfterProducerIdRotation(): Unit = {
// Existing transaction ID maps to new producer ID
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0)
when(pidGenerator.generateProducerId())
.thenReturn(producerId + 1)
@ -1528,7 +1528,7 @@ class TransactionCoordinatorTest {
def shouldAbortExpiredTransactionsInOngoingStateAndBumpEpoch(): Unit = {
val now = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0)
when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
@ -1537,7 +1537,7 @@ class TransactionCoordinatorTest {
// Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0.
val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.clone, now,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions.clone, now,
now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0)
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
@ -1567,7 +1567,7 @@ class TransactionCoordinatorTest {
def shouldNotAcceptSmallerEpochDuringTransactionExpiration(): Unit = {
val now = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0)
when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
@ -1577,7 +1577,7 @@ class TransactionCoordinatorTest {
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 2).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 2).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, bumpedTxnMetadata))))
@ -1593,8 +1593,8 @@ class TransactionCoordinatorTest {
@Test
def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = {
val metadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
metadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0)
metadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)
when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
@ -1612,13 +1612,13 @@ class TransactionCoordinatorTest {
def shouldNotBumpEpochWhenAbortingExpiredTransactionIfAppendToLogFails(): Unit = {
val now = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0)
when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -1627,7 +1627,7 @@ class TransactionCoordinatorTest {
// Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0.
val bumpedEpoch = (producerEpoch + 1).toShort
val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, bumpedEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.clone, now,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions.clone, now,
now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0)
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
@ -1660,8 +1660,8 @@ class TransactionCoordinatorTest {
@Test
def shouldNotBumpEpochWithPendingTransaction(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0)
txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)
when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
.thenReturn(true)
@ -1695,7 +1695,7 @@ class TransactionCoordinatorTest {
coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Dead, mutable.Set.empty, time.milliseconds(),
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.DEAD, mutable.Set.empty, time.milliseconds(),
time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -1721,7 +1721,7 @@ class TransactionCoordinatorTest {
@Test
def testDescribeTransactions(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -1747,7 +1747,7 @@ class TransactionCoordinatorTest {
when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt()))
.thenReturn(true)
// Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort.
// Since the clientTransactionVersion doesn't matter, use 2 since the states are TransactionState.PREPARE_COMMIT and TransactionState.PREPARE_ABORT.
val metadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_EPOCH,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0, TV_2)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
@ -1799,7 +1799,7 @@ class TransactionCoordinatorTest {
private def mockPrepare(transactionState: TransactionState, clientTransactionVersion: TransactionVersion, runCallback: Boolean = false): TransactionMetadata = {
val now = time.milliseconds()
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH,
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0)
val transition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions.clone, now, now, clientTransactionVersion)

View File

@ -23,6 +23,7 @@ import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil}
import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection
import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type}
import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, SimpleRecord}
import org.apache.kafka.coordinator.transaction.TransactionState
import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue}
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue, fail}
@ -49,7 +50,7 @@ class TransactionLogTest {
val producerId = 23423L
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
txnMetadata.addPartitions(topicPartitions)
assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2))
@ -64,19 +65,19 @@ class TransactionLogTest {
"four" -> 4L,
"five" -> 5L)
val transactionStates = Map[Long, TransactionState](0L -> Empty,
1L -> Ongoing,
2L -> PrepareCommit,
3L -> CompleteCommit,
4L -> PrepareAbort,
5L -> CompleteAbort)
val transactionStates = Map[Long, TransactionState](0L -> TransactionState.EMPTY,
1L -> TransactionState.ONGOING,
2L -> TransactionState.PREPARE_COMMIT,
3L -> TransactionState.COMPLETE_COMMIT,
4L -> TransactionState.PREPARE_ABORT,
5L -> TransactionState.COMPLETE_ABORT)
// generate transaction log messages
val txnRecords = pidMappings.map { case (transactionalId, producerId) =>
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
if (!txnMetadata.state.equals(Empty))
if (!txnMetadata.state.equals(TransactionState.EMPTY))
txnMetadata.addPartitions(topicPartitions)
val keyBytes = TransactionLog.keyToBytes(transactionalId)
@ -99,7 +100,7 @@ class TransactionLogTest {
assertEquals(transactionTimeoutMs, txnMetadata.txnTimeoutMs)
assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state)
if (txnMetadata.state.equals(Empty))
if (txnMetadata.state.equals(TransactionState.EMPTY))
assertEquals(Set.empty[TopicPartition], txnMetadata.topicPartitions)
else
assertEquals(topicPartitions, txnMetadata.topicPartitions)
@ -113,14 +114,14 @@ class TransactionLogTest {
@Test
def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = {
val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, mutable.Set.empty, 500, 500, TV_0)
val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, mutable.Set.empty, 500, 500, TV_0)
val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0))
assertEquals(0, txnLogValueBuffer.getShort)
}
@Test
def testSerializeTransactionLogValueToFlexibleVersion(): Unit = {
val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, mutable.Set.empty, 500, 500, TV_2)
val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, mutable.Set.empty, 500, 500, TV_2)
val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2))
assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort)
}
@ -134,7 +135,7 @@ class TransactionLogTest {
val txnLogValue = new TransactionLogValue()
.setProducerId(100)
.setProducerEpoch(50.toShort)
.setTransactionStatus(CompleteCommit.id)
.setTransactionStatus(TransactionState.COMPLETE_COMMIT.id)
.setTransactionStartTimestampMs(750L)
.setTransactionLastUpdateTimestampMs(1000L)
.setTransactionTimeoutMs(500)
@ -145,7 +146,7 @@ class TransactionLogTest {
assertEquals(100, deserialized.producerId)
assertEquals(50, deserialized.producerEpoch)
assertEquals(CompleteCommit, deserialized.state)
assertEquals(TransactionState.COMPLETE_COMMIT, deserialized.state)
assertEquals(750L, deserialized.txnStartTimestamp)
assertEquals(1000L, deserialized.txnLastUpdateTimestamp)
assertEquals(500, deserialized.txnTimeoutMs)
@ -198,7 +199,7 @@ class TransactionLogTest {
transactionLogValue.set("producer_id", 1000L)
transactionLogValue.set("producer_epoch", 100.toShort)
transactionLogValue.set("transaction_timeout_ms", 1000)
transactionLogValue.set("transaction_status", CompleteCommit.id)
transactionLogValue.set("transaction_status", TransactionState.COMPLETE_COMMIT.id)
transactionLogValue.set("transaction_partitions", Array(txnPartitions))
transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L)
transactionLogValue.set("transaction_start_timestamp_ms", 3000L)
@ -227,7 +228,7 @@ class TransactionLogTest {
assertEquals(1000L, txnMetadata.producerId)
assertEquals(100, txnMetadata.producerEpoch)
assertEquals(1000L, txnMetadata.txnTimeoutMs)
assertEquals(CompleteCommit, txnMetadata.state)
assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata.state)
assertEquals(Set(new TopicPartition("topic", 1)), txnMetadata.topicPartitions)
assertEquals(2000L, txnMetadata.txnLastUpdateTimestamp)
assertEquals(3000L, txnMetadata.txnStartTimestamp)

View File

@ -29,6 +29,7 @@ import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse}
import org.apache.kafka.common.utils.MockTime
import org.apache.kafka.common.{Node, TopicPartition}
import org.apache.kafka.coordinator.transaction.TransactionState
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.{MetadataVersion, TransactionVersion}
import org.apache.kafka.server.metrics.{KafkaMetricsGroup, KafkaYammerMetrics}
@ -68,9 +69,9 @@ class TransactionMarkerChannelManagerTest {
private val txnTimeoutMs = 0
private val txnResult = TransactionResult.COMMIT
private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, RecordBatch.NO_PRODUCER_ID,
producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L, TransactionVersion.TV_2)
producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L, TransactionVersion.TV_2)
private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, RecordBatch.NO_PRODUCER_ID,
producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L, TransactionVersion.TV_2)
producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](partition1), 0L, 0L, TransactionVersion.TV_2)
private val capturedErrorsCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit])
private val time = new MockTime
@ -480,7 +481,7 @@ class TransactionMarkerChannelManagerTest {
assertEquals(0, channelManager.numTxnsWithPendingMarkers)
assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers)
assertEquals(None, txnMetadata2.pendingState)
assertEquals(CompleteCommit, txnMetadata2.state)
assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata2.state)
}
@Test
@ -533,7 +534,7 @@ class TransactionMarkerChannelManagerTest {
assertEquals(0, channelManager.numTxnsWithPendingMarkers)
assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers)
assertEquals(None, txnMetadata2.pendingState)
assertEquals(PrepareCommit, txnMetadata2.state)
assertEquals(TransactionState.PREPARE_COMMIT, txnMetadata2.state)
}
@ParameterizedTest
@ -594,7 +595,7 @@ class TransactionMarkerChannelManagerTest {
assertEquals(0, channelManager.numTxnsWithPendingMarkers)
assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers)
assertEquals(None, txnMetadata2.pendingState)
assertEquals(CompleteCommit, txnMetadata2.state)
assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata2.state)
}
private def createPidErrorMap(errors: Errors): util.HashMap[java.lang.Long, util.Map[TopicPartition, Errors]] = {

View File

@ -23,6 +23,7 @@ import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.{ApiKeys, Errors}
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse}
import org.apache.kafka.coordinator.transaction.TransactionState
import org.apache.kafka.server.common.TransactionVersion
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test
@ -44,7 +45,7 @@ class TransactionMarkerRequestCompletionHandlerTest {
private val txnResult = TransactionResult.COMMIT
private val topicPartition = new TopicPartition("topic1", 0)
private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L, TransactionVersion.TV_2)
producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](topicPartition), 0L, 0L, TransactionVersion.TV_2)
private val pendingCompleteTxnAndMarkers = asList(
PendingCompleteTxnAndMarkerEntry(
PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, txnMetadata.prepareComplete(42)),

View File

@ -19,6 +19,7 @@ package kafka.coordinator.transaction
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.coordinator.transaction.TransactionState
import org.apache.kafka.server.common.TransactionVersion
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.apache.kafka.server.util.MockTime
@ -27,7 +28,10 @@ import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import java.util.Optional
import scala.collection.mutable
import scala.jdk.CollectionConverters._
class TransactionMetadataTest {
@ -47,7 +51,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -71,7 +75,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -95,7 +99,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -117,13 +121,13 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = -1,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_2)
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(producerEpoch + 1, txnMetadata.producerEpoch)
@ -142,13 +146,13 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = CompleteAbort,
state = TransactionState.COMPLETE_ABORT,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds() - 1,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_2)
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(producerEpoch + 1, txnMetadata.producerEpoch)
@ -167,13 +171,13 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = CompleteCommit,
state = TransactionState.COMPLETE_COMMIT,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds() - 1,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_2)
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(producerEpoch + 1, txnMetadata.producerEpoch)
@ -191,7 +195,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
@ -219,7 +223,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
@ -246,13 +250,13 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller; when transiting from Empty the start time would be updated to the update-time
// let new time be smaller; when transiting from TransactionState.EMPTY the start time would be updated to the update-time
var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1, TV_0)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0)), txnMetadata.topicPartitions)
@ -262,7 +266,7 @@ class TransactionMetadataTest {
assertEquals(time.milliseconds() - 1, txnMetadata.txnStartTimestamp)
assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp)
// add another partition, check that in Ongoing state the start timestamp would not change to update time
// add another partition, check that in TransactionState.ONGOING state the start timestamp would not change to update time
transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds() - 2, TV_0)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)), txnMetadata.topicPartitions)
@ -284,16 +288,16 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(PrepareCommit, txnMetadata.state)
assertEquals(TransactionState.PREPARE_COMMIT, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch)
assertEquals(producerEpoch, txnMetadata.producerEpoch)
@ -312,16 +316,16 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(PrepareAbort, txnMetadata.state)
assertEquals(TransactionState.PREPARE_ABORT, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch)
assertEquals(producerEpoch, txnMetadata.producerEpoch)
@ -343,7 +347,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000,
state = PrepareCommit,
state = TransactionState.PREPARE_COMMIT,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
@ -354,7 +358,7 @@ class TransactionMetadataTest {
val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(CompleteCommit, txnMetadata.state)
assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(lastProducerEpoch, txnMetadata.lastProducerEpoch)
assertEquals(producerEpoch, txnMetadata.producerEpoch)
@ -376,7 +380,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000,
state = PrepareAbort,
state = TransactionState.PREPARE_ABORT,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
@ -387,7 +391,7 @@ class TransactionMetadataTest {
val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(CompleteAbort, txnMetadata.state)
assertEquals(TransactionState.COMPLETE_ABORT, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(lastProducerEpoch, txnMetadata.lastProducerEpoch)
assertEquals(producerEpoch, txnMetadata.producerEpoch)
@ -407,7 +411,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -416,12 +420,12 @@ class TransactionMetadataTest {
val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch()
assertEquals(Short.MaxValue, fencingTransitMetadata.producerEpoch)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, fencingTransitMetadata.lastProducerEpoch)
assertEquals(Some(PrepareEpochFence), txnMetadata.pendingState)
assertEquals(Some(TransactionState.PREPARE_EPOCH_FENCE), txnMetadata.pendingState)
// We should reset the pending state to make way for the abort transition.
txnMetadata.pendingState = None
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, transitMetadata.producerId)
}
@ -438,7 +442,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = CompleteCommit,
state = TransactionState.COMPLETE_COMMIT,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -459,7 +463,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = CompleteAbort,
state = TransactionState.COMPLETE_ABORT,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -480,7 +484,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -500,7 +504,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -527,13 +531,13 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_2)
var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false)
var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch)
@ -559,7 +563,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
@ -567,7 +571,7 @@ class TransactionMetadataTest {
assertTrue(txnMetadata.isProducerEpochExhausted)
val newProducerId = 9893L
var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, newProducerId, time.milliseconds() - 1, false)
var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, newProducerId, time.milliseconds() - 1, false)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(Short.MaxValue, txnMetadata.producerEpoch)
@ -584,21 +588,21 @@ class TransactionMetadataTest {
@Test
def testRotateProducerIdInOngoingState(): Unit = {
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing, TV_0))
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(TransactionState.ONGOING, TV_0))
}
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def testRotateProducerIdInPrepareAbortState(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort, clientTransactionVersion))
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(TransactionState.PREPARE_ABORT, clientTransactionVersion))
}
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def testRotateProducerIdInPrepareCommitState(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareCommit, clientTransactionVersion))
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(TransactionState.PREPARE_COMMIT, clientTransactionVersion))
}
@Test
@ -613,7 +617,7 @@ class TransactionMetadataTest {
producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -637,7 +641,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -662,7 +666,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -687,7 +691,7 @@ class TransactionMetadataTest {
producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000,
state = Empty,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
@ -699,13 +703,13 @@ class TransactionMetadataTest {
@Test
def testTransactionStateIdAndNameMapping(): Unit = {
for (state <- TransactionState.AllStates) {
for (state <- TransactionState.ALL_STATES.asScala) {
assertEquals(state, TransactionState.fromId(state.id))
assertEquals(Some(state), TransactionState.fromName(state.name))
assertEquals(Optional.of(state), TransactionState.fromName(state.stateName))
if (state != Dead) {
val clientTransactionState = org.apache.kafka.clients.admin.TransactionState.parse(state.name)
assertEquals(state.name, clientTransactionState.toString)
if (state != TransactionState.DEAD) {
val clientTransactionState = org.apache.kafka.clients.admin.TransactionState.parse(state.stateName)
assertEquals(state.stateName, clientTransactionState.toString)
assertNotEquals(org.apache.kafka.clients.admin.TransactionState.UNKNOWN, clientTransactionState)
}
}
@ -714,27 +718,27 @@ class TransactionMetadataTest {
@Test
def testAllTransactionStatesAreMapped(): Unit = {
val unmatchedStates = mutable.Set(
Empty,
Ongoing,
PrepareCommit,
PrepareAbort,
CompleteCommit,
CompleteAbort,
PrepareEpochFence,
Dead
TransactionState.EMPTY,
TransactionState.ONGOING,
TransactionState.PREPARE_COMMIT,
TransactionState.PREPARE_ABORT,
TransactionState.COMPLETE_COMMIT,
TransactionState.COMPLETE_ABORT,
TransactionState.PREPARE_EPOCH_FENCE,
TransactionState.DEAD
)
// The exhaustive match is intentional here to ensure that we are
// forced to update the test case if a new state is added.
TransactionState.AllStates.foreach {
case Empty => assertTrue(unmatchedStates.remove(Empty))
case Ongoing => assertTrue(unmatchedStates.remove(Ongoing))
case PrepareCommit => assertTrue(unmatchedStates.remove(PrepareCommit))
case PrepareAbort => assertTrue(unmatchedStates.remove(PrepareAbort))
case CompleteCommit => assertTrue(unmatchedStates.remove(CompleteCommit))
case CompleteAbort => assertTrue(unmatchedStates.remove(CompleteAbort))
case PrepareEpochFence => assertTrue(unmatchedStates.remove(PrepareEpochFence))
case Dead => assertTrue(unmatchedStates.remove(Dead))
TransactionState.ALL_STATES.asScala.foreach {
case TransactionState.EMPTY => assertTrue(unmatchedStates.remove(TransactionState.EMPTY))
case TransactionState.ONGOING => assertTrue(unmatchedStates.remove(TransactionState.ONGOING))
case TransactionState.PREPARE_COMMIT => assertTrue(unmatchedStates.remove(TransactionState.PREPARE_COMMIT))
case TransactionState.PREPARE_ABORT => assertTrue(unmatchedStates.remove(TransactionState.PREPARE_ABORT))
case TransactionState.COMPLETE_COMMIT => assertTrue(unmatchedStates.remove(TransactionState.COMPLETE_COMMIT))
case TransactionState.COMPLETE_ABORT => assertTrue(unmatchedStates.remove(TransactionState.COMPLETE_ABORT))
case TransactionState.PREPARE_EPOCH_FENCE => assertTrue(unmatchedStates.remove(TransactionState.PREPARE_EPOCH_FENCE))
case TransactionState.DEAD => assertTrue(unmatchedStates.remove(TransactionState.DEAD))
}
assertEquals(Set.empty, unmatchedStates)

View File

@ -33,6 +33,7 @@ import org.apache.kafka.common.record._
import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
import org.apache.kafka.common.requests.TransactionResult
import org.apache.kafka.common.utils.MockTime
import org.apache.kafka.coordinator.transaction.TransactionState
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion}
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
@ -181,7 +182,7 @@ class TransactionStateManagerTest {
).thenReturn(new FetchDataInfo(new LogOffsetMetadata(startOffset), fileRecordsMock))
when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some(endOffset))
txnMetadata1.state = PrepareCommit
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.addPartitions(Set[TopicPartition](
new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
@ -240,7 +241,7 @@ class TransactionStateManagerTest {
).thenReturn(new FetchDataInfo(new LogOffsetMetadata(startOffset), fileRecordsMock))
when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some(endOffset))
txnMetadata1.state = PrepareCommit
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.addPartitions(Set[TopicPartition](
new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
@ -285,7 +286,7 @@ class TransactionStateManagerTest {
// generate transaction log messages for two pids traces:
// pid1's transaction started with two partitions
txnMetadata1.state = Ongoing
txnMetadata1.state = TransactionState.ONGOING
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
@ -299,12 +300,12 @@ class TransactionStateManagerTest {
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid1's transaction is preparing to commit
txnMetadata1.state = PrepareCommit
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid2's transaction started with three partitions
txnMetadata2.state = Ongoing
txnMetadata2.state = TransactionState.ONGOING
txnMetadata2.addPartitions(Set[TopicPartition](new TopicPartition("topic3", 0),
new TopicPartition("topic3", 1),
new TopicPartition("topic3", 2)))
@ -312,17 +313,17 @@ class TransactionStateManagerTest {
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's transaction is preparing to abort
txnMetadata2.state = PrepareAbort
txnMetadata2.state = TransactionState.PREPARE_ABORT
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's transaction has aborted
txnMetadata2.state = CompleteAbort
txnMetadata2.state = TransactionState.COMPLETE_ABORT
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's epoch has advanced, with no ongoing transaction yet
txnMetadata2.state = Empty
txnMetadata2.state = TransactionState.EMPTY
txnMetadata2.topicPartitions.clear()
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
@ -511,7 +512,7 @@ class TransactionStateManagerTest {
prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION)
transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, _ => true, RequestLocal.withThreadConfinedCaching)
assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
assertEquals(Some(Ongoing), txnMetadata1.pendingState)
assertEquals(Some(TransactionState.ONGOING), txnMetadata1.pendingState)
}
@Test
@ -591,25 +592,25 @@ class TransactionStateManagerTest {
}
}
putTransaction(transactionalId = "t0", producerId = 0, state = Ongoing)
putTransaction(transactionalId = "t1", producerId = 1, state = Ongoing)
putTransaction(transactionalId = "my-special-0", producerId = 0, state = Ongoing)
putTransaction(transactionalId = "t0", producerId = 0, state = TransactionState.ONGOING)
putTransaction(transactionalId = "t1", producerId = 1, state = TransactionState.ONGOING)
putTransaction(transactionalId = "my-special-0", producerId = 0, state = TransactionState.ONGOING)
// update time to create transactions with various durations
time.sleep(1000)
putTransaction(transactionalId = "t2", producerId = 2, state = PrepareCommit)
putTransaction(transactionalId = "t3", producerId = 3, state = PrepareAbort)
putTransaction(transactionalId = "your-special-1", producerId = 0, state = PrepareAbort)
putTransaction(transactionalId = "t2", producerId = 2, state = TransactionState.PREPARE_COMMIT)
putTransaction(transactionalId = "t3", producerId = 3, state = TransactionState.PREPARE_ABORT)
putTransaction(transactionalId = "your-special-1", producerId = 0, state = TransactionState.PREPARE_ABORT)
time.sleep(1000)
putTransaction(transactionalId = "t4", producerId = 4, state = CompleteCommit)
putTransaction(transactionalId = "t5", producerId = 5, state = CompleteAbort)
putTransaction(transactionalId = "t6", producerId = 6, state = CompleteAbort)
putTransaction(transactionalId = "t7", producerId = 7, state = PrepareEpochFence)
putTransaction(transactionalId = "their-special-2", producerId = 7, state = CompleteAbort)
putTransaction(transactionalId = "t4", producerId = 4, state = TransactionState.COMPLETE_COMMIT)
putTransaction(transactionalId = "t5", producerId = 5, state = TransactionState.COMPLETE_ABORT)
putTransaction(transactionalId = "t6", producerId = 6, state = TransactionState.COMPLETE_ABORT)
putTransaction(transactionalId = "t7", producerId = 7, state = TransactionState.PREPARE_EPOCH_FENCE)
putTransaction(transactionalId = "their-special-2", producerId = 7, state = TransactionState.COMPLETE_ABORT)
time.sleep(1000)
// Note that `Dead` transactions are never returned. This is a transient state
// Note that `TransactionState.DEAD` transactions are never returned. This is a transient state
// which is used when the transaction state is in the process of being deleted
// (whether though expiration or coordinator unloading).
putTransaction(transactionalId = "t8", producerId = 8, state = Dead)
putTransaction(transactionalId = "t8", producerId = 8, state = TransactionState.DEAD)
def assertListTransactions(
expectedTransactionalIds: Set[String],
@ -657,12 +658,12 @@ class TransactionStateManagerTest {
transactionManager.addLoadedTransactionsToCache(partitionId, 0, new ConcurrentHashMap[String, TransactionMetadata]())
}
transactionManager.putTransactionStateIfNotExists(transactionMetadata("ongoing", producerId = 0, state = Ongoing))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("not-expiring", producerId = 1, state = Ongoing, txnTimeout = 10000))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-commit", producerId = 2, state = PrepareCommit))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-abort", producerId = 3, state = PrepareAbort))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-commit", producerId = 4, state = CompleteCommit))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-abort", producerId = 5, state = CompleteAbort))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("ongoing", producerId = 0, state = TransactionState.ONGOING))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("not-expiring", producerId = 1, state = TransactionState.ONGOING, txnTimeout = 10000))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-commit", producerId = 2, state = TransactionState.PREPARE_COMMIT))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-abort", producerId = 3, state = TransactionState.PREPARE_ABORT))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-commit", producerId = 4, state = TransactionState.COMPLETE_COMMIT))
transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-abort", producerId = 5, state = TransactionState.COMPLETE_ABORT))
time.sleep(2000)
val expiring = transactionManager.timedOutTransactions()
@ -671,59 +672,59 @@ class TransactionStateManagerTest {
@Test
def shouldWriteTxnMarkersForTransactionInPreparedCommitState(): Unit = {
verifyWritesTxnMarkersInPrepareState(PrepareCommit)
verifyWritesTxnMarkersInPrepareState(TransactionState.PREPARE_COMMIT)
}
@Test
def shouldWriteTxnMarkersForTransactionInPreparedAbortState(): Unit = {
verifyWritesTxnMarkersInPrepareState(PrepareAbort)
verifyWritesTxnMarkersInPrepareState(TransactionState.PREPARE_ABORT)
}
@Test
def shouldRemoveCompleteCommitExpiredTransactionalIds(): Unit = {
setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteCommit)
setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.COMPLETE_COMMIT)
verifyMetadataDoesntExist(transactionalId1)
verifyMetadataDoesExistAndIsUsable(transactionalId2)
}
@Test
def shouldRemoveCompleteAbortExpiredTransactionalIds(): Unit = {
setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteAbort)
setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.COMPLETE_ABORT)
verifyMetadataDoesntExist(transactionalId1)
verifyMetadataDoesExistAndIsUsable(transactionalId2)
}
@Test
def shouldRemoveEmptyExpiredTransactionalIds(): Unit = {
setupAndRunTransactionalIdExpiration(Errors.NONE, Empty)
setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.EMPTY)
verifyMetadataDoesntExist(transactionalId1)
verifyMetadataDoesExistAndIsUsable(transactionalId2)
}
@Test
def shouldNotRemoveExpiredTransactionalIdsIfLogAppendFails(): Unit = {
setupAndRunTransactionalIdExpiration(Errors.NOT_ENOUGH_REPLICAS, CompleteAbort)
setupAndRunTransactionalIdExpiration(Errors.NOT_ENOUGH_REPLICAS, TransactionState.COMPLETE_ABORT)
verifyMetadataDoesExistAndIsUsable(transactionalId1)
verifyMetadataDoesExistAndIsUsable(transactionalId2)
}
@Test
def shouldNotRemoveOngoingTransactionalIds(): Unit = {
setupAndRunTransactionalIdExpiration(Errors.NONE, Ongoing)
setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.ONGOING)
verifyMetadataDoesExistAndIsUsable(transactionalId1)
verifyMetadataDoesExistAndIsUsable(transactionalId2)
}
@Test
def shouldNotRemovePrepareAbortTransactionalIds(): Unit = {
setupAndRunTransactionalIdExpiration(Errors.NONE, PrepareAbort)
setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.PREPARE_ABORT)
verifyMetadataDoesExistAndIsUsable(transactionalId1)
verifyMetadataDoesExistAndIsUsable(transactionalId2)
}
@Test
def shouldNotRemovePrepareCommitTransactionalIds(): Unit = {
setupAndRunTransactionalIdExpiration(Errors.NONE, PrepareCommit)
setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.PREPARE_COMMIT)
verifyMetadataDoesExistAndIsUsable(transactionalId1)
verifyMetadataDoesExistAndIsUsable(transactionalId2)
}
@ -879,7 +880,7 @@ class TransactionStateManagerTest {
// will be expired and it should succeed.
val timestamp = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, 1, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0)
transactionManager.putTransactionStateIfNotExists(txnMetadata)
time.sleep(txnConfig.transactionalIdExpirationMs + 1)
@ -966,7 +967,7 @@ class TransactionStateManagerTest {
@Test
def testSuccessfulReimmigration(): Unit = {
txnMetadata1.state = PrepareCommit
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
@ -1033,9 +1034,9 @@ class TransactionStateManagerTest {
@Test
def testLoadTransactionMetadataContainingSegmentEndingWithEmptyBatch(): Unit = {
// Simulate a case where a log contains two segments and the first segment ending with an empty batch.
txnMetadata1.state = PrepareCommit
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)))
txnMetadata2.state = Ongoing
txnMetadata2.state = TransactionState.ONGOING
txnMetadata2.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)))
// Create the first segment which contains two batches.
@ -1176,7 +1177,7 @@ class TransactionStateManagerTest {
transactionManager.removeExpiredTransactionalIds()
val stateAllowsExpiration = txnState match {
case Empty | CompleteCommit | CompleteAbort => true
case TransactionState.EMPTY | TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT => true
case _ => false
}
@ -1223,7 +1224,7 @@ class TransactionStateManagerTest {
private def transactionMetadata(transactionalId: String,
producerId: Long,
state: TransactionState = Empty,
state: TransactionState = TransactionState.EMPTY,
txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = {
val timestamp = time.milliseconds()
new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, 0.toShort,
@ -1300,7 +1301,7 @@ class TransactionStateManagerTest {
assertEquals(Double.NaN, partitionLoadTime("partition-load-time-avg"), 0)
assertTrue(reporter.containsMbean(mBeanName))
txnMetadata1.state = Ongoing
txnMetadata1.state = TransactionState.ONGOING
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 1),
new TopicPartition("topic1", 1)))
@ -1319,7 +1320,7 @@ class TransactionStateManagerTest {
@Test
def testIgnoreUnknownRecordType(): Unit = {
txnMetadata1.state = PrepareCommit
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))

View File

@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.coordinator.transaction;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* Represents the states of a transaction in the transaction coordinator.
* This enum corresponds to the Scala sealed trait TransactionState in kafka.coordinator.transaction.
*/
public enum TransactionState {
/**
* Transaction has not existed yet
* <p>
* transition: received AddPartitionsToTxnRequest => Ongoing
* received AddOffsetsToTxnRequest => Ongoing
* received EndTxnRequest with abort and TransactionV2 enabled => PrepareAbort
*/
EMPTY((byte) 0, org.apache.kafka.clients.admin.TransactionState.EMPTY.toString(), true),
/**
* Transaction has started and ongoing
* <p>
* transition: received EndTxnRequest with commit => PrepareCommit
* received EndTxnRequest with abort => PrepareAbort
* received AddPartitionsToTxnRequest => Ongoing
* received AddOffsetsToTxnRequest => Ongoing
*/
ONGOING((byte) 1, org.apache.kafka.clients.admin.TransactionState.ONGOING.toString(), false),
/**
* Group is preparing to commit
* transition: received acks from all partitions => CompleteCommit
*/
PREPARE_COMMIT((byte) 2, org.apache.kafka.clients.admin.TransactionState.PREPARE_COMMIT.toString(), false),
/**
* Group is preparing to abort
* <p>
* transition: received acks from all partitions => CompleteAbort
* <p>
* Note, In transaction v2, we allow Empty, CompleteCommit, CompleteAbort to transition to PrepareAbort. because the
* client may not know the txn state on the server side, it needs to send endTxn request when uncertain.
*/
PREPARE_ABORT((byte) 3, org.apache.kafka.clients.admin.TransactionState.PREPARE_ABORT.toString(), false),
/**
* Group has completed commit
* <p>
* Will soon be removed from the ongoing transaction cache
*/
COMPLETE_COMMIT((byte) 4, org.apache.kafka.clients.admin.TransactionState.COMPLETE_COMMIT.toString(), true),
/**
* Group has completed abort
* <p>
* Will soon be removed from the ongoing transaction cache
*/
COMPLETE_ABORT((byte) 5, org.apache.kafka.clients.admin.TransactionState.COMPLETE_ABORT.toString(), true),
/**
* TransactionalId has expired and is about to be removed from the transaction cache
*/
DEAD((byte) 6, "Dead", false),
/**
* We are in the middle of bumping the epoch and fencing out older producers.
*/
PREPARE_EPOCH_FENCE((byte) 7, org.apache.kafka.clients.admin.TransactionState.PREPARE_EPOCH_FENCE.toString(), false);
private static final Map<String, TransactionState> NAME_TO_ENUM = Arrays.stream(values())
.collect(Collectors.toUnmodifiableMap(TransactionState::stateName, Function.identity()));
private static final Map<Byte, TransactionState> ID_TO_ENUM = Arrays.stream(values())
.collect(Collectors.toUnmodifiableMap(TransactionState::id, Function.identity()));
public static final Set<TransactionState> ALL_STATES = Set.copyOf(EnumSet.allOf(TransactionState.class));
private final byte id;
private final String stateName;
public static final Map<TransactionState, Set<TransactionState>> VALID_PREVIOUS_STATES = Map.of(
EMPTY, Set.of(EMPTY, COMPLETE_COMMIT, COMPLETE_ABORT),
ONGOING, Set.of(ONGOING, EMPTY, COMPLETE_COMMIT, COMPLETE_ABORT),
PREPARE_COMMIT, Set.of(ONGOING),
PREPARE_ABORT, Set.of(ONGOING, PREPARE_EPOCH_FENCE, EMPTY, COMPLETE_COMMIT, COMPLETE_ABORT),
COMPLETE_COMMIT, Set.of(PREPARE_COMMIT),
COMPLETE_ABORT, Set.of(PREPARE_ABORT),
DEAD, Set.of(EMPTY, COMPLETE_ABORT, COMPLETE_COMMIT),
PREPARE_EPOCH_FENCE, Set.of(ONGOING)
);
private final boolean expirationAllowed;
TransactionState(byte id, String name, boolean expirationAllowed) {
this.id = id;
this.stateName = name;
this.expirationAllowed = expirationAllowed;
}
/**
* @return The state id byte.
*/
public byte id() {
return id;
}
/**
* Get the name of this state. This is exposed through the `DescribeTransactions` API.
* @return The state name string.
*/
public String stateName() {
return stateName;
}
/**
* @return The set of states from which it is valid to transition into this state.
*/
public Set<TransactionState> validPreviousStates() {
return VALID_PREVIOUS_STATES.getOrDefault(this, Set.of());
}
/**
* @return True if expiration is allowed in this state, false otherwise.
*/
public boolean isExpirationAllowed() {
return expirationAllowed;
}
/**
* Finds a TransactionState by its name.
* @param name The name of the state.
* @return An Optional containing the TransactionState if found, otherwise empty.
*/
public static Optional<TransactionState> fromName(String name) {
return Optional.ofNullable(NAME_TO_ENUM.get(name));
}
/**
* Finds a TransactionState by its ID.
* @param id The byte ID of the state.
* @return The TransactionState corresponding to the ID.
* @throws IllegalStateException if the ID is unknown.
*/
public static TransactionState fromId(byte id) {
TransactionState state = ID_TO_ENUM.get(id);
if (state == null) {
throw new IllegalStateException("Unknown transaction state id " + id + " from the transaction status message");
}
return state;
}
}