diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala index 14802861814..a200cd3d6ae 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala @@ -27,7 +27,7 @@ import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult} import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time} -import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionLogConfig, TransactionStateManagerConfig} +import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionLogConfig, TransactionState, TransactionStateManagerConfig} import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.{RequestLocal, TransactionVersion} import org.apache.kafka.server.util.Scheduler @@ -154,7 +154,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = resolvedTxnTimeoutMs, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = collection.mutable.Set.empty[TopicPartition], txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TransactionVersion.TV_0) @@ -182,7 +182,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, responseCallback(initTransactionError(error)) case Right((coordinatorEpoch, newMetadata)) => - if (newMetadata.txnState == PrepareEpochFence) { + if (newMetadata.txnState == TransactionState.PREPARE_EPOCH_FENCE) { // abort the ongoing transaction and then return CONCURRENT_TRANSACTIONS to let client wait and retry def sendRetriableErrorCallback(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = { if (error != Errors.NONE) { @@ -249,11 +249,11 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } else { // caller should have synchronized on txnMetadata already txnMetadata.state match { - case PrepareAbort | PrepareCommit => + case TransactionState.PREPARE_ABORT | TransactionState.PREPARE_COMMIT => // reply to client and let it backoff and retry Left(Errors.CONCURRENT_TRANSACTIONS) - case CompleteAbort | CompleteCommit | Empty => + case TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT | TransactionState.EMPTY => val transitMetadataResult = // If the epoch is exhausted and the expected epoch (if provided) matches it, generate a new producer ID if (txnMetadata.isProducerEpochExhausted && @@ -274,7 +274,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case Left(err) => Left(err) } - case Ongoing => + case TransactionState.ONGOING => // indicate to abort the current ongoing txn first. Note that this epoch is never returned to the // user. We will abort the ongoing transaction and return CONCURRENT_TRANSACTIONS to the client. // This forces the client to retry, which will ensure that the epoch is bumped a second time. In @@ -282,7 +282,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, // then when the client retries, we will generate a new producerId. Right(coordinatorEpoch, txnMetadata.prepareFenceProducerEpoch()) - case Dead | PrepareEpochFence => + case TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE => val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " + s"This is illegal as we should never have transitioned to this state." fatal(errorMsg) @@ -327,7 +327,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case Right(Some(coordinatorEpochAndMetadata)) => val txnMetadata = coordinatorEpochAndMetadata.transactionMetadata txnMetadata.inLock { - if (txnMetadata.state == Dead) { + if (txnMetadata.state == TransactionState.DEAD) { // The transaction state is being expired, so ignore it transactionState.setErrorCode(Errors.TRANSACTIONAL_ID_NOT_FOUND.code) } else { @@ -345,7 +345,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, .setErrorCode(Errors.NONE.code) .setProducerId(txnMetadata.producerId) .setProducerEpoch(txnMetadata.producerEpoch) - .setTransactionState(txnMetadata.state.name) + .setTransactionState(txnMetadata.state.stateName) .setTransactionTimeoutMs(txnMetadata.txnTimeoutMs) .setTransactionStartTimeMs(txnMetadata.txnStartTimestamp) } @@ -378,7 +378,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, Left(Errors.INVALID_PRODUCER_ID_MAPPING) } else if (txnMetadata.producerEpoch != producerEpoch) { Left(Errors.PRODUCER_FENCED) - } else if (txnMetadata.state == PrepareCommit || txnMetadata.state == PrepareAbort) { + } else if (txnMetadata.state == TransactionState.PREPARE_COMMIT || txnMetadata.state == TransactionState.PREPARE_ABORT) { Left(Errors.CONCURRENT_TRANSACTIONS) } else { Right(partitions.map { part => @@ -435,9 +435,9 @@ class TransactionCoordinator(txnConfig: TransactionConfig, Left(Errors.INVALID_PRODUCER_ID_MAPPING) } else if (txnMetadata.producerEpoch != producerEpoch) { Left(Errors.PRODUCER_FENCED) - } else if (txnMetadata.state == PrepareCommit || txnMetadata.state == PrepareAbort) { + } else if (txnMetadata.state == TransactionState.PREPARE_COMMIT || txnMetadata.state == TransactionState.PREPARE_ABORT) { Left(Errors.CONCURRENT_TRANSACTIONS) - } else if (txnMetadata.state == Ongoing && partitions.subsetOf(txnMetadata.topicPartitions)) { + } else if (txnMetadata.state == TransactionState.ONGOING && partitions.subsetOf(txnMetadata.topicPartitions)) { // this is an optimization: if the partitions are already in the metadata reply OK immediately Left(Errors.NONE) } else { @@ -555,16 +555,16 @@ class TransactionCoordinator(txnConfig: TransactionConfig, // Strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch. else if ((isFromClient && producerEpoch != txnMetadata.producerEpoch) || producerEpoch < txnMetadata.producerEpoch) Left(Errors.PRODUCER_FENCED) - else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence) + else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != TransactionState.PREPARE_EPOCH_FENCE) Left(Errors.CONCURRENT_TRANSACTIONS) else txnMetadata.state match { - case Ongoing => + case TransactionState.ONGOING => val nextState = if (txnMarkerResult == TransactionResult.COMMIT) - PrepareCommit + TransactionState.PREPARE_COMMIT else - PrepareAbort + TransactionState.PREPARE_ABORT - if (nextState == PrepareAbort && txnMetadata.pendingState.contains(PrepareEpochFence)) { + if (nextState == TransactionState.PREPARE_ABORT && txnMetadata.pendingState.contains(TransactionState.PREPARE_EPOCH_FENCE)) { // We should clear the pending state to make way for the transition to PrepareAbort and also bump // the epoch in the transaction metadata we are about to append. isEpochFence = true @@ -574,29 +574,29 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, TransactionVersion.fromFeatureLevel(0), RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)) - case CompleteCommit => + case TransactionState.COMPLETE_COMMIT => if (txnMarkerResult == TransactionResult.COMMIT) Left(Errors.NONE) else logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) - case CompleteAbort => + case TransactionState.COMPLETE_ABORT => if (txnMarkerResult == TransactionResult.ABORT) Left(Errors.NONE) else logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) - case PrepareCommit => + case TransactionState.PREPARE_COMMIT => if (txnMarkerResult == TransactionResult.COMMIT) Left(Errors.CONCURRENT_TRANSACTIONS) else logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) - case PrepareAbort => + case TransactionState.PREPARE_ABORT => if (txnMarkerResult == TransactionResult.ABORT) Left(Errors.CONCURRENT_TRANSACTIONS) else logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) - case Empty => + case TransactionState.EMPTY => logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) - case Dead | PrepareEpochFence => + case TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE => val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " + s"This is illegal as we should never have transitioned to this state." fatal(errorMsg) @@ -631,19 +631,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig, else if (txnMetadata.pendingTransitionInProgress) Left(Errors.CONCURRENT_TRANSACTIONS) else txnMetadata.state match { - case Empty| Ongoing | CompleteCommit | CompleteAbort => + case TransactionState.EMPTY| TransactionState.ONGOING | TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT => logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) - case PrepareCommit => + case TransactionState.PREPARE_COMMIT => if (txnMarkerResult != TransactionResult.COMMIT) logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) else Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds())) - case PrepareAbort => + case TransactionState.PREPARE_ABORT => if (txnMarkerResult != TransactionResult.ABORT) logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) else Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds())) - case Dead | PrepareEpochFence => + case TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE => val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " + s"This is illegal as we should never have transitioned to this state." fatal(errorMsg) @@ -776,7 +776,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, producerEpochCopy = txnMetadata.producerEpoch // PrepareEpochFence has slightly different epoch bumping logic so don't include it here. // Note that, it can only happen when the current state is Ongoing. - isEpochFence = txnMetadata.pendingState.contains(PrepareEpochFence) + isEpochFence = txnMetadata.pendingState.contains(TransactionState.PREPARE_EPOCH_FENCE) // True if the client retried a request that had overflowed the epoch, and a new producer ID is stored in the txnMetadata val retryOnOverflow = !isEpochFence && txnMetadata.prevProducerId == producerId && producerEpoch == Short.MaxValue - 1 && txnMetadata.producerEpoch == 0 @@ -790,11 +790,11 @@ class TransactionCoordinator(txnConfig: TransactionConfig, // Return producer fenced even in the cases where the epoch is higher and could indicate an invalid state transition. // Use the following criteria to determine if a v2 retry is valid: txnMetadata.state match { - case Ongoing | Empty | Dead | PrepareEpochFence => + case TransactionState.ONGOING | TransactionState.EMPTY | TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE => producerEpoch == txnMetadata.producerEpoch - case PrepareCommit | PrepareAbort => + case TransactionState.PREPARE_COMMIT | TransactionState.PREPARE_ABORT => retryOnEpochBump - case CompleteCommit | CompleteAbort => + case TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT => retryOnEpochBump || retryOnOverflow || producerEpoch == txnMetadata.producerEpoch } } else { @@ -818,7 +818,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, Right(RecordBatch.NO_PRODUCER_ID) } - if (nextState == PrepareAbort && isEpochFence) { + if (nextState == TransactionState.PREPARE_ABORT && isEpochFence) { // We should clear the pending state to make way for the transition to PrepareAbort and also bump // the epoch in the transaction metadata we are about to append. txnMetadata.pendingState = None @@ -832,7 +832,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } } - if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence) { + if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != TransactionState.PREPARE_EPOCH_FENCE) { // This check is performed first so that the pending transition can complete before the next checks. // With TV2, we may be transitioning over a producer epoch overflow, and the producer may be using the // new producer ID that is still only in pending state. @@ -842,14 +842,14 @@ class TransactionCoordinator(txnConfig: TransactionConfig, else if (!isValidEpoch) Left(Errors.PRODUCER_FENCED) else txnMetadata.state match { - case Ongoing => + case TransactionState.ONGOING => val nextState = if (txnMarkerResult == TransactionResult.COMMIT) - PrepareCommit + TransactionState.PREPARE_COMMIT else - PrepareAbort + TransactionState.PREPARE_ABORT generateTxnTransitMetadataForTxnCompletion(nextState, false) - case CompleteCommit => + case TransactionState.COMPLETE_COMMIT => if (txnMarkerResult == TransactionResult.COMMIT) { if (isRetry) Left(Errors.NONE) @@ -860,35 +860,35 @@ class TransactionCoordinator(txnConfig: TransactionConfig, if (isRetry) logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) else - generateTxnTransitMetadataForTxnCompletion(PrepareAbort, true) + generateTxnTransitMetadataForTxnCompletion(TransactionState.PREPARE_ABORT, true) } - case CompleteAbort => + case TransactionState.COMPLETE_ABORT => if (txnMarkerResult == TransactionResult.ABORT) { if (isRetry) Left(Errors.NONE) else - generateTxnTransitMetadataForTxnCompletion(PrepareAbort, true) + generateTxnTransitMetadataForTxnCompletion(TransactionState.PREPARE_ABORT, true) } else { // Commit. logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) } - case PrepareCommit => + case TransactionState.PREPARE_COMMIT => if (txnMarkerResult == TransactionResult.COMMIT) Left(Errors.CONCURRENT_TRANSACTIONS) else logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) - case PrepareAbort => + case TransactionState.PREPARE_ABORT => if (txnMarkerResult == TransactionResult.ABORT) Left(Errors.CONCURRENT_TRANSACTIONS) else logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) - case Empty => + case TransactionState.EMPTY => if (txnMarkerResult == TransactionResult.ABORT) { - generateTxnTransitMetadataForTxnCompletion(PrepareAbort, true) + generateTxnTransitMetadataForTxnCompletion(TransactionState.PREPARE_ABORT, true) } else { logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) } - case Dead | PrepareEpochFence => + case TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE => val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " + s"This is illegal as we should never have transitioned to this state." fatal(errorMsg) @@ -928,19 +928,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig, else if (txnMetadata.pendingTransitionInProgress) Left(Errors.CONCURRENT_TRANSACTIONS) else txnMetadata.state match { - case Empty| Ongoing | CompleteCommit | CompleteAbort => + case TransactionState.EMPTY | TransactionState.ONGOING | TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT => logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) - case PrepareCommit => + case TransactionState.PREPARE_COMMIT => if (txnMarkerResult != TransactionResult.COMMIT) logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) else Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds())) - case PrepareAbort => + case TransactionState.PREPARE_ABORT => if (txnMarkerResult != TransactionResult.ABORT) logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) else Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds())) - case Dead | PrepareEpochFence => + case TransactionState.DEAD | TransactionState.PREPARE_EPOCH_FENCE => val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " + s"This is illegal as we should never have transitioned to this state." fatal(errorMsg) diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala index 5972418d0c1..631a432d3be 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala @@ -21,6 +21,7 @@ import org.apache.kafka.common.compress.Compression import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.TopicPartition +import org.apache.kafka.coordinator.transaction.TransactionState import org.apache.kafka.coordinator.transaction.generated.{CoordinatorRecordType, TransactionLogKey, TransactionLogValue} import org.apache.kafka.server.common.TransactionVersion @@ -61,10 +62,10 @@ object TransactionLog { */ private[transaction] def valueToBytes(txnMetadata: TxnTransitMetadata, transactionVersionLevel: TransactionVersion): Array[Byte] = { - if (txnMetadata.txnState == Empty && txnMetadata.topicPartitions.nonEmpty) + if (txnMetadata.txnState == TransactionState.EMPTY && txnMetadata.topicPartitions.nonEmpty) throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata") - val transactionPartitions = if (txnMetadata.txnState == Empty) null + val transactionPartitions = if (txnMetadata.txnState == TransactionState.EMPTY) null else txnMetadata.topicPartitions .groupBy(_.topic) .map { case (topic, partitions) => @@ -127,7 +128,7 @@ object TransactionLog { txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs, clientTransactionVersion = TransactionVersion.fromFeatureLevel(value.clientTransactionVersion)) - if (!transactionMetadata.state.equals(Empty)) + if (!transactionMetadata.state.equals(TransactionState.EMPTY)) value.transactionPartitions.forEach(partitionsSchema => transactionMetadata.addPartitions(partitionsSchema.partitionIds .asScala diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala index aff68749513..aa8c871b7de 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala @@ -21,151 +21,11 @@ import kafka.utils.{CoreUtils, Logging, nonthreadsafe} import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.coordinator.transaction.TransactionState import org.apache.kafka.server.common.TransactionVersion import scala.collection.{immutable, mutable} - -object TransactionState { - val AllStates: Set[TransactionState] = Set( - Empty, - Ongoing, - PrepareCommit, - PrepareAbort, - CompleteCommit, - CompleteAbort, - Dead, - PrepareEpochFence - ) - - def fromName(name: String): Option[TransactionState] = { - AllStates.find(_.name == name) - } - - def fromId(id: Byte): TransactionState = { - id match { - case 0 => Empty - case 1 => Ongoing - case 2 => PrepareCommit - case 3 => PrepareAbort - case 4 => CompleteCommit - case 5 => CompleteAbort - case 6 => Dead - case 7 => PrepareEpochFence - case _ => throw new IllegalStateException(s"Unknown transaction state id $id from the transaction status message") - } - } -} - -private[transaction] sealed trait TransactionState { - def id: Byte - - /** - * Get the name of this state. This is exposed through the `DescribeTransactions` API. - */ - def name: String - - def validPreviousStates: Set[TransactionState] - - def isExpirationAllowed: Boolean = false -} - -/** - * Transaction has not existed yet - * - * transition: received AddPartitionsToTxnRequest => Ongoing - * received AddOffsetsToTxnRequest => Ongoing - * received EndTxnRequest with abort and TransactionV2 enabled => PrepareAbort - */ -private[transaction] case object Empty extends TransactionState { - val id: Byte = 0 - val name: String = "Empty" - val validPreviousStates: Set[TransactionState] = Set(Empty, CompleteCommit, CompleteAbort) - override def isExpirationAllowed: Boolean = true -} - -/** - * Transaction has started and ongoing - * - * transition: received EndTxnRequest with commit => PrepareCommit - * received EndTxnRequest with abort => PrepareAbort - * received AddPartitionsToTxnRequest => Ongoing - * received AddOffsetsToTxnRequest => Ongoing - */ -private[transaction] case object Ongoing extends TransactionState { - val id: Byte = 1 - val name: String = "Ongoing" - val validPreviousStates: Set[TransactionState] = Set(Ongoing, Empty, CompleteCommit, CompleteAbort) -} - -/** - * Group is preparing to commit - * - * transition: received acks from all partitions => CompleteCommit - */ -private[transaction] case object PrepareCommit extends TransactionState { - val id: Byte = 2 - val name: String = "PrepareCommit" - val validPreviousStates: Set[TransactionState] = Set(Ongoing) -} - -/** - * Group is preparing to abort - * - * transition: received acks from all partitions => CompleteAbort - * - * Note, In transaction v2, we allow Empty, CompleteCommit, CompleteAbort to transition to PrepareAbort. because the - * client may not know the txn state on the server side, it needs to send endTxn request when uncertain. - */ -private[transaction] case object PrepareAbort extends TransactionState { - val id: Byte = 3 - val name: String = "PrepareAbort" - val validPreviousStates: Set[TransactionState] = Set(Ongoing, PrepareEpochFence, Empty, CompleteCommit, CompleteAbort) -} - -/** - * Group has completed commit - * - * Will soon be removed from the ongoing transaction cache - */ -private[transaction] case object CompleteCommit extends TransactionState { - val id: Byte = 4 - val name: String = "CompleteCommit" - val validPreviousStates: Set[TransactionState] = Set(PrepareCommit) - override def isExpirationAllowed: Boolean = true -} - -/** - * Group has completed abort - * - * Will soon be removed from the ongoing transaction cache - */ -private[transaction] case object CompleteAbort extends TransactionState { - val id: Byte = 5 - val name: String = "CompleteAbort" - val validPreviousStates: Set[TransactionState] = Set(PrepareAbort) - override def isExpirationAllowed: Boolean = true -} - -/** - * TransactionalId has expired and is about to be removed from the transaction cache - */ -private[transaction] case object Dead extends TransactionState { - val id: Byte = 6 - val name: String = "Dead" - val validPreviousStates: Set[TransactionState] = Set(Empty, CompleteAbort, CompleteCommit) -} - -/** - * We are in the middle of bumping the epoch and fencing out older producers. - */ - -private[transaction] case object PrepareEpochFence extends TransactionState { - val id: Byte = 7 - val name: String = "PrepareEpochFence" - val validPreviousStates: Set[TransactionState] = Set(Ongoing) -} - private[transaction] object TransactionMetadata { def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1 } @@ -244,7 +104,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, } def removePartition(topicPartition: TopicPartition): Unit = { - if (state != PrepareCommit && state != PrepareAbort) + if (state != TransactionState.PREPARE_COMMIT && state != TransactionState.PREPARE_ABORT) throw new IllegalStateException(s"Transaction metadata's current state is $state, and its pending state is $pendingState " + s"while trying to remove partitions whose txn marker has been sent, this is not expected") @@ -267,7 +127,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, val bumpedEpoch = if (hasFailedEpochFence) producerEpoch else (producerEpoch + 1).toShort prepareTransitionTo( - state = PrepareEpochFence, + state = TransactionState.PREPARE_EPOCH_FENCE, producerEpoch = bumpedEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH ) @@ -309,7 +169,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, epochBumpResult match { case Right((nextEpoch, lastEpoch)) => Right(prepareTransitionTo( - state = Empty, + state = TransactionState.EMPTY, producerEpoch = nextEpoch, lastProducerEpoch = lastEpoch, txnTimeoutMs = newTxnTimeoutMs, @@ -330,7 +190,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, throw new IllegalStateException("Cannot rotate producer ids while a transaction is still pending") prepareTransitionTo( - state = Empty, + state = TransactionState.EMPTY, producerId = newProducerId, producerEpoch = 0, lastProducerEpoch = if (recordLastEpoch) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH, @@ -343,12 +203,12 @@ private[transaction] class TransactionMetadata(val transactionalId: String, def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition], updateTimestamp: Long, clientTransactionVersion: TransactionVersion): TxnTransitMetadata = { val newTxnStartTimestamp = state match { - case Empty | CompleteAbort | CompleteCommit => updateTimestamp + case TransactionState.EMPTY | TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT => updateTimestamp case _ => txnStartTimestamp } prepareTransitionTo( - state = Ongoing, + state = TransactionState.ONGOING, topicPartitions = (topicPartitions ++ addedTopicPartitions), txnStartTimestamp = newTxnStartTimestamp, txnLastUpdateTimestamp = updateTimestamp, @@ -379,7 +239,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, } def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = { - val newState = if (state == PrepareCommit) CompleteCommit else CompleteAbort + val newState = if (state == TransactionState.PREPARE_COMMIT) TransactionState.COMPLETE_COMMIT else TransactionState.COMPLETE_ABORT // Since the state change was successfully written to the log, unset the flag for a failed epoch fence hasFailedEpochFence = false @@ -408,7 +268,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, def prepareDead(): TxnTransitMetadata = { prepareTransitionTo( - state = Dead, + state = TransactionState.DEAD, topicPartitions = mutable.Set.empty[TopicPartition] ) } @@ -427,7 +287,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, private def hasPendingTransaction: Boolean = { state match { - case Ongoing | PrepareAbort | PrepareCommit => true + case TransactionState.ONGOING | TransactionState.PREPARE_ABORT | TransactionState.PREPARE_COMMIT => true case _ => false } } @@ -452,7 +312,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, // The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata // is created for the first time and it could stay like this until transitioning // to Dead. - if (state != Dead && producerEpoch < 0) + if (state != TransactionState.DEAD && producerEpoch < 0) throw new IllegalArgumentException(s"Illegal new producer epoch $producerEpoch") // check that the new state transition is valid and update the pending state if necessary @@ -492,7 +352,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, throwStateTransitionFailure(transitMetadata) } else { toState match { - case Empty => // from initPid + case TransactionState.EMPTY => // from initPid if ((producerEpoch != transitMetadata.producerEpoch && !validProducerEpochBump(transitMetadata)) || transitMetadata.topicPartitions.nonEmpty || transitMetadata.txnStartTimestamp != -1) { @@ -500,7 +360,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, throwStateTransitionFailure(transitMetadata) } - case Ongoing => // from addPartitions + case TransactionState.ONGOING => // from addPartitions if (!validProducerEpoch(transitMetadata) || !topicPartitions.subsetOf(transitMetadata.topicPartitions) || txnTimeoutMs != transitMetadata.txnTimeoutMs) { @@ -508,11 +368,11 @@ private[transaction] class TransactionMetadata(val transactionalId: String, throwStateTransitionFailure(transitMetadata) } - case PrepareAbort | PrepareCommit => // from endTxn + case TransactionState.PREPARE_ABORT | TransactionState.PREPARE_COMMIT => // from endTxn // In V2, we allow state transits from Empty, CompleteCommit and CompleteAbort to PrepareAbort. It is possible // their updated start time is not equal to the current start time. - val allowedEmptyAbort = toState == PrepareAbort && transitMetadata.clientTransactionVersion.supportsEpochBump() && - (state == Empty || state == CompleteCommit || state == CompleteAbort) + val allowedEmptyAbort = toState == TransactionState.PREPARE_ABORT && transitMetadata.clientTransactionVersion.supportsEpochBump() && + (state == TransactionState.EMPTY || state == TransactionState.COMPLETE_COMMIT || state == TransactionState.COMPLETE_ABORT) val validTimestamp = txnStartTimestamp == transitMetadata.txnStartTimestamp || allowedEmptyAbort if (!validProducerEpoch(transitMetadata) || !topicPartitions.equals(transitMetadata.topicPartitions) || @@ -521,14 +381,14 @@ private[transaction] class TransactionMetadata(val transactionalId: String, throwStateTransitionFailure(transitMetadata) } - case CompleteAbort | CompleteCommit => // from write markers + case TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT => // from write markers if (!validProducerEpoch(transitMetadata) || txnTimeoutMs != transitMetadata.txnTimeoutMs || transitMetadata.txnStartTimestamp == -1) { throwStateTransitionFailure(transitMetadata) } - case PrepareEpochFence => + case TransactionState.PREPARE_EPOCH_FENCE => // We should never get here, since once we prepare to fence the epoch, we immediately set the pending state // to PrepareAbort, and then consequently to CompleteAbort after the markers are written.. So we should never // ever try to complete a transition to PrepareEpochFence, as it is not a valid previous state for any other state, and hence @@ -536,7 +396,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String, throwStateTransitionFailure(transitMetadata) - case Dead => + case TransactionState.DEAD => // The transactionalId was being expired. The completion of the operation should result in removal of the // the metadata from the cache, so we should never realistically transition to the dead state. throw new IllegalStateException(s"TransactionalId $transactionalId is trying to complete a transition to " + @@ -590,11 +450,11 @@ private[transaction] class TransactionMetadata(val transactionalId: String, val transitLastProducerEpoch = transitMetadata.lastProducerEpoch (isAtLeastTransactionsV2, txnState, transitProducerEpoch) match { - case (true, CompleteCommit | CompleteAbort, epoch) if epoch == 0.toShort => + case (true, TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT, epoch) if epoch == 0.toShort => transitLastProducerEpoch == lastProducerEpoch && transitMetadata.prevProducerId == producerId - case (true, PrepareCommit | PrepareAbort, _) => + case (true, TransactionState.PREPARE_COMMIT | TransactionState.PREPARE_ABORT, _) => transitLastProducerEpoch == producerEpoch && transitProducerId == producerId diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala index 994fb170973..b859ed003ad 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala @@ -35,7 +35,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests.TransactionResult import org.apache.kafka.common.utils.{Time, Utils} import org.apache.kafka.common.{KafkaException, TopicIdPartition, TopicPartition} -import org.apache.kafka.coordinator.transaction.{TransactionLogConfig, TransactionStateManagerConfig} +import org.apache.kafka.coordinator.transaction.{TransactionLogConfig, TransactionState, TransactionStateManagerConfig} import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.{RequestLocal, TransactionVersion} import org.apache.kafka.server.config.ServerConfigs @@ -134,7 +134,7 @@ class TransactionStateManager(brokerId: Int, false } else { txnMetadata.state match { - case Ongoing => + case TransactionState.ONGOING => // Do not apply timeout to distributed two phase commit transactions. (!txnMetadata.isDistributedTwoPhaseCommitTxn) && (txnMetadata.txnStartTimestamp + txnMetadata.txnTimeoutMs < now) @@ -265,7 +265,7 @@ class TransactionStateManager(brokerId: Int, val txnMetadata = txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId) txnMetadata.inLock { if (txnMetadataCacheEntry.coordinatorEpoch == idCoordinatorEpochAndMetadata.coordinatorEpoch - && txnMetadata.pendingState.contains(Dead) + && txnMetadata.pendingState.contains(TransactionState.DEAD) && txnMetadata.producerEpoch == idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch && response.error == Errors.NONE) { txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId) @@ -328,15 +328,15 @@ class TransactionStateManager(brokerId: Int, } else { val filterStates = mutable.Set.empty[TransactionState] filterStateNames.foreach { stateName => - TransactionState.fromName(stateName) match { - case Some(state) => filterStates += state - case None => response.unknownStateFilters.add(stateName) - } + TransactionState.fromName(stateName).ifPresentOrElse( + state => filterStates += state, + () => response.unknownStateFilters.add(stateName) + ) } val now : Long = time.milliseconds() def shouldInclude(txnMetadata: TransactionMetadata, pattern: Pattern): Boolean = { - if (txnMetadata.state == Dead) { + if (txnMetadata.state == TransactionState.DEAD) { // We filter the `Dead` state since it is a transient state which // indicates that the transactionalId and its metadata are in the // process of expiration and removal. @@ -371,7 +371,7 @@ class TransactionStateManager(brokerId: Int, states.add(new ListTransactionsResponseData.TransactionState() .setTransactionalId(txnMetadata.transactionalId) .setProducerId(txnMetadata.producerId) - .setTransactionState(txnMetadata.state.name) + .setTransactionState(txnMetadata.state.stateName) ) } } @@ -568,10 +568,10 @@ class TransactionStateManager(brokerId: Int, txnMetadata.inLock { // if state is PrepareCommit or PrepareAbort we need to complete the transaction txnMetadata.state match { - case PrepareAbort => + case TransactionState.PREPARE_ABORT => transactionsPendingForCompletion += TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.ABORT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds())) - case PrepareCommit => + case TransactionState.PREPARE_COMMIT => transactionsPendingForCompletion += TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.COMMIT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds())) case _ => diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala index 0bfd67d31c3..5db59dd51fe 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala @@ -28,8 +28,8 @@ import kafka.server.KafkaConfig import kafka.utils.TestUtils import org.apache.kafka.clients.{ClientResponse, NetworkClient} import org.apache.kafka.common.internals.Topic -import org.apache.kafka.common.compress.Compression import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME +import org.apache.kafka.common.compress.Compression import org.apache.kafka.common.metrics.Metrics import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.protocol.{ApiKeys, Errors} @@ -37,7 +37,7 @@ import org.apache.kafka.common.record.{FileRecords, MemoryRecords, RecordBatch, import org.apache.kafka.common.requests._ import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch} import org.apache.kafka.common.{Node, TopicPartition, Uuid} -import org.apache.kafka.coordinator.transaction.ProducerIdManager +import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionState} import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion} import org.apache.kafka.server.storage.log.FetchIsolation @@ -468,7 +468,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn")) txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2)) - txnMetadata.state = PrepareCommit + txnMetadata.state = TransactionState.PREPARE_COMMIT txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2)) prepareTxnLog(partitionId) @@ -513,7 +513,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren producerEpoch = (Short.MaxValue - 1).toShort, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 60000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = collection.mutable.Set.empty[TopicPartition], txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TransactionVersion.TV_0) @@ -544,7 +544,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren override def awaitAndVerify(txn: Transaction): Unit = { val initPidResult = result.getOrElse(throw new IllegalStateException("InitProducerId has not completed")) assertEquals(Errors.NONE, initPidResult.error) - verifyTransaction(txn, Empty) + verifyTransaction(txn, TransactionState.EMPTY) } } @@ -564,7 +564,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren override def awaitAndVerify(txn: Transaction): Unit = { val error = result.getOrElse(throw new IllegalStateException("AddPartitionsToTransaction has not completed")) assertEquals(Errors.NONE, error) - verifyTransaction(txn, Ongoing) + verifyTransaction(txn, TransactionState.ONGOING) } } @@ -585,7 +585,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren if (!txn.ended) { txn.ended = true assertEquals(Errors.NONE, error) - val expectedState = if (transactionResult(txn) == TransactionResult.COMMIT) CompleteCommit else CompleteAbort + val expectedState = if (transactionResult(txn) == TransactionResult.COMMIT) TransactionState.COMPLETE_COMMIT else TransactionState.COMPLETE_ABORT verifyTransaction(txn, expectedState) } else assertEquals(Errors.INVALID_TXN_STATE, error) @@ -606,7 +606,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren override def await(): Unit = { allTransactions.foreach { txn => if (txnStateManager.partitionFor(txn.transactionalId) == txnTopicPartitionId) { - verifyTransaction(txn, CompleteCommit) + verifyTransaction(txn, TransactionState.COMPLETE_COMMIT) } } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala index 94ccd6dc03d..bad1d6e91f6 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala @@ -22,7 +22,7 @@ import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult} import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch} -import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionStateManagerConfig} +import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionState, TransactionStateManagerConfig} import org.apache.kafka.server.common.TransactionVersion import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.apache.kafka.server.util.MockScheduler @@ -197,7 +197,7 @@ class TransactionCoordinatorTest { initPidGenericMocks(transactionalId) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0) + (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.EMPTY, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -229,11 +229,11 @@ class TransactionCoordinatorTest { initPidGenericMocks(transactionalId) val txnMetadata1 = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2) - // We start with txnMetadata1 so we can transform the metadata to PrepareCommit. + (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2) + // We start with txnMetadata1 so we can transform the metadata to TransactionState.PREPARE_COMMIT. val txnMetadata2 = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2) - val transitMetadata = txnMetadata2.prepareAbortOrCommit(PrepareCommit, TV_2, producerId2, time.milliseconds(), false) + (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2) + val transitMetadata = txnMetadata2.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, producerId2, time.milliseconds(), false) txnMetadata2.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata2.producerId) @@ -342,7 +342,7 @@ class TransactionCoordinatorTest { } // If producer ID is not the same, return INVALID_PRODUCER_ID_MAPPING val wrongPidTxnMetadata = new TransactionMetadata(transactionalId, 1, 0, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.PREPARE_COMMIT, partitions, 0, 0, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, wrongPidTxnMetadata)))) @@ -353,7 +353,7 @@ class TransactionCoordinatorTest { // If producer epoch is not equal, return PRODUCER_FENCED val oldEpochTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.PREPARE_COMMIT, partitions, 0, 0, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, oldEpochTxnMetadata)))) @@ -364,7 +364,7 @@ class TransactionCoordinatorTest { // If the txn state is Prepare or AbortCommit, we return CONCURRENT_TRANSACTIONS val emptyTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.PREPARE_COMMIT, partitions, 0, 0, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, emptyTxnMetadata)))) @@ -375,8 +375,8 @@ class TransactionCoordinatorTest { // Pending state does not matter, we will just check if the partitions are in the txnMetadata. val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, mutable.Set.empty, 0, 0, TV_0) - ongoingTxnMetadata.pendingState = Some(CompleteCommit) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, mutable.Set.empty, 0, 0, TV_0) + ongoingTxnMetadata.pendingState = Some(TransactionState.COMPLETE_COMMIT) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, ongoingTxnMetadata)))) @@ -388,16 +388,16 @@ class TransactionCoordinatorTest { @Test def shouldRespondWithConcurrentTransactionsOnAddPartitionsWhenStateIsPrepareCommit(): Unit = { - validateConcurrentTransactions(PrepareCommit) + validateConcurrentTransactions(TransactionState.PREPARE_COMMIT) } @Test def shouldRespondWithConcurrentTransactionOnAddPartitionsWhenStateIsPrepareAbort(): Unit = { - validateConcurrentTransactions(PrepareAbort) + validateConcurrentTransactions(TransactionState.PREPARE_ABORT) } def validateConcurrentTransactions(state: TransactionState): Unit = { - // Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort. + // Since the clientTransactionVersion doesn't matter, use 2 since the states are TransactionState.PREPARE_COMMIT and TransactionState.PREPARE_ABORT. when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, @@ -409,11 +409,11 @@ class TransactionCoordinatorTest { @Test def shouldRespondWithProducerFencedOnAddPartitionsWhenEpochsAreDifferent(): Unit = { - // Since the clientTransactionVersion doesn't matter, use 2 since the state is PrepareCommit. + // Since the clientTransactionVersion doesn't matter, use 2 since the state is TransactionState.PREPARE_COMMIT. when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, - 10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0, TV_2))))) + 10, 9, 0, TransactionState.PREPARE_COMMIT, mutable.Set.empty, 0, 0, TV_2))))) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_2) assertEquals(Errors.PRODUCER_FENCED, error) @@ -421,24 +421,24 @@ class TransactionCoordinatorTest { @Test def shouldAppendNewMetadataToLogOnAddPartitionsWhenPartitionsAdded(): Unit = { - validateSuccessfulAddPartitions(Empty, 0) + validateSuccessfulAddPartitions(TransactionState.EMPTY, 0) } @Test def shouldRespondWithSuccessOnAddPartitionsWhenStateIsOngoing(): Unit = { - validateSuccessfulAddPartitions(Ongoing, 0) + validateSuccessfulAddPartitions(TransactionState.ONGOING, 0) } @ParameterizedTest @ValueSource(shorts = Array(0, 2)) def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(clientTransactionVersion: Short): Unit = { - validateSuccessfulAddPartitions(CompleteCommit, clientTransactionVersion) + validateSuccessfulAddPartitions(TransactionState.COMPLETE_COMMIT, clientTransactionVersion) } @ParameterizedTest @ValueSource(shorts = Array(0, 2)) def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(clientTransactionVersion: Short): Unit = { - validateSuccessfulAddPartitions(CompleteAbort, clientTransactionVersion) + validateSuccessfulAddPartitions(TransactionState.COMPLETE_ABORT, clientTransactionVersion) } def validateSuccessfulAddPartitions(previousState: TransactionState, transactionVersion: Short): Unit = { @@ -467,7 +467,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0))))) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.EMPTY, partitions, 0, 0, TV_0))))) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_0) assertEquals(Errors.NONE, error) @@ -484,7 +484,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0, TV_0))))) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, partitions, 0, 0, TV_0))))) coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, partitions, verifyPartitionsInTxnCallback) errors.foreach { case (_, error) => @@ -503,7 +503,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0))))) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.EMPTY, partitions, 0, 0, TV_0))))) val extraPartitions = partitions ++ Set(new TopicPartition("topic2", 0)) @@ -532,7 +532,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 10, 10, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) @@ -546,7 +546,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, - (producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) + (producerEpoch - 1).toShort, 1, TransactionState.ONGOING, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.PRODUCER_FENCED, error) @@ -560,7 +560,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, - (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) val epoch = if (isRetry) producerEpoch - 1 else producerEpoch coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) @@ -587,7 +587,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, - (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) val epoch = if (isRetry) producerEpoch - 1 else producerEpoch coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) @@ -604,7 +604,7 @@ class TransactionCoordinatorTest { def testEndTxnWhenStatusIsCompleteAbortAndResultIsAbortInV1(isRetry: Boolean): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -623,7 +623,7 @@ class TransactionCoordinatorTest { def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbortInV2(isRetry: Boolean): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -660,7 +660,7 @@ class TransactionCoordinatorTest { def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(transactionVersion: Short): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -673,7 +673,7 @@ class TransactionCoordinatorTest { def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort,1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -687,7 +687,7 @@ class TransactionCoordinatorTest { def testEndTxnRequestWhenStatusIsCompleteCommitAndResultIsAbortInV1(isRetry: Boolean): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -706,7 +706,7 @@ class TransactionCoordinatorTest { def testEndTxnRequestWhenStatusIsCompleteCommitAndResultIsAbortInV2(isRetry: Boolean): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -737,7 +737,7 @@ class TransactionCoordinatorTest { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) @@ -750,7 +750,7 @@ class TransactionCoordinatorTest { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_TXN_STATE, error) @@ -762,7 +762,7 @@ class TransactionCoordinatorTest { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, Empty, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_TXN_STATE, error) @@ -775,7 +775,7 @@ class TransactionCoordinatorTest { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, Empty, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) val epoch = if (isRetry) producerEpoch - 1 else producerEpoch coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) @@ -804,7 +804,7 @@ class TransactionCoordinatorTest { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, Empty, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) val epoch = if (isRetry) producerEpoch - 1 else producerEpoch coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) @@ -820,7 +820,7 @@ class TransactionCoordinatorTest { def shouldReturnInvalidTxnRequestOnEndTxnV2IfNotEndTxnV2Retry(): Unit = { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) // If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED. coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) @@ -829,7 +829,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) // If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return INVALID_TXN_STATE. coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) @@ -841,7 +841,7 @@ class TransactionCoordinatorTest { def shouldReturnOkOnEndTxnV2IfEndTxnV2RetryEpochOverflow(): Unit = { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) // Return CONCURRENT_TRANSACTIONS while transaction is still completing coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback) @@ -850,7 +850,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId2, producerId, - RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback) assertEquals(Errors.NONE, error) @@ -863,7 +863,7 @@ class TransactionCoordinatorTest { @Test def shouldReturnConcurrentTxnOnAddPartitionsIfEndTxnV2EpochOverflowAndNotComplete(): Unit = { val prepareWithPending = new TransactionMetadata(transactionalId, producerId, producerId, - producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2) + producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2) val txnTransitMetadata = prepareWithPending.prepareComplete(time.milliseconds()) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) @@ -875,7 +875,7 @@ class TransactionCoordinatorTest { verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) prepareWithPending.completeTransitionTo(txnTransitMetadata) - assertEquals(CompleteCommit, prepareWithPending.state) + assertEquals(TransactionState.COMPLETE_COMMIT, prepareWithPending.state) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, prepareWithPending)))) when(transactionManager.appendTransactionToLog( @@ -897,7 +897,7 @@ class TransactionCoordinatorTest { @ValueSource(shorts = Array(0, 2)) def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(transactionVersion: Short): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) - mockPrepare(PrepareCommit, clientTransactionVersion) + mockPrepare(TransactionState.PREPARE_COMMIT, clientTransactionVersion) coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) @@ -914,7 +914,7 @@ class TransactionCoordinatorTest { @ValueSource(shorts = Array(0, 2)) def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(transactionVersion: Short): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) - mockPrepare(PrepareAbort, clientTransactionVersion) + mockPrepare(TransactionState.PREPARE_ABORT, clientTransactionVersion) coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) @@ -989,7 +989,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, metadataEpoch, 1, - 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.PRODUCER_FENCED, error) @@ -998,29 +998,29 @@ class TransactionCoordinatorTest { @Test def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingEmptyTransaction(): Unit = { - validateIncrementEpochAndUpdateMetadata(Empty, 0) + validateIncrementEpochAndUpdateMetadata(TransactionState.EMPTY, 0) } @ParameterizedTest @ValueSource(shorts = Array(0, 2)) def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(clientTransactionVersion: Short): Unit = { - validateIncrementEpochAndUpdateMetadata(CompleteAbort, clientTransactionVersion) + validateIncrementEpochAndUpdateMetadata(TransactionState.COMPLETE_ABORT, clientTransactionVersion) } @ParameterizedTest @ValueSource(shorts = Array(0, 2)) def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(clientTransactionVersion: Short): Unit = { - validateIncrementEpochAndUpdateMetadata(CompleteCommit, clientTransactionVersion) + validateIncrementEpochAndUpdateMetadata(TransactionState.COMPLETE_COMMIT, clientTransactionVersion) } @Test def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareCommitState(): Unit = { - validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareCommit) + validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(TransactionState.PREPARE_COMMIT) } @Test def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareAbortState(): Unit = { - validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareAbort) + validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(TransactionState.PREPARE_ABORT) } @ParameterizedTest(name = "enableTwoPCFlag={0}, keepPreparedTxn={1}") @@ -1030,7 +1030,7 @@ class TransactionCoordinatorTest { keepPreparedTxn: Boolean ): Unit = { val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt())) .thenReturn(true) @@ -1041,7 +1041,7 @@ class TransactionCoordinatorTest { when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), @@ -1066,7 +1066,7 @@ class TransactionCoordinatorTest { verify(transactionManager).appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), - ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)), + ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)), any(), any(), any()) @@ -1075,13 +1075,13 @@ class TransactionCoordinatorTest { @Test def shouldFailToAbortTransactionOnHandleInitPidWhenProducerEpochIsSmaller(): Unit = { val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt())) .thenReturn(true) val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - (producerEpoch + 2).toShort, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + (producerEpoch + 2).toShort, (producerEpoch - 1).toShort, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -1106,7 +1106,7 @@ class TransactionCoordinatorTest { @Test def shouldNotRepeatedlyBumpEpochDueToInitPidDuringOngoingTxnIfAppendToLogFails(): Unit = { val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt())) .thenReturn(true) @@ -1120,8 +1120,8 @@ class TransactionCoordinatorTest { when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) - val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false) + (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0) + val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false) when(transactionManager.appendTransactionToLog( ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(coordinatorEpoch), @@ -1198,14 +1198,14 @@ class TransactionCoordinatorTest { @Test def shouldUseLastEpochToFenceWhenEpochsAreExhausted(): Unit = { val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt())) .thenReturn(true) val postFenceTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - Short.MaxValue, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds(), TV_0) + Short.MaxValue, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -1224,7 +1224,7 @@ class TransactionCoordinatorTest { producerEpoch = Short.MaxValue, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = txnTimeoutMs, - txnState = PrepareAbort, + txnState = TransactionState.PREPARE_ABORT, topicPartitions = partitions.clone, txnStartTimestamp = time.milliseconds(), txnLastUpdateTimestamp = time.milliseconds(), @@ -1257,7 +1257,7 @@ class TransactionCoordinatorTest { producerEpoch = Short.MaxValue, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = txnTimeoutMs, - txnState = PrepareAbort, + txnState = TransactionState.PREPARE_ABORT, topicPartitions = partitions.clone, txnStartTimestamp = time.milliseconds(), txnLastUpdateTimestamp = time.milliseconds(), @@ -1272,7 +1272,7 @@ class TransactionCoordinatorTest { // If the metadata doesn't include the previous producer data (for example, if it was written to the log by a broker // on an old version), the retry case should fail val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, - RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt())) .thenReturn(true) @@ -1295,7 +1295,7 @@ class TransactionCoordinatorTest { def testFenceProducerWhenMappingExistsWithDifferentProducerId(): Unit = { // Existing transaction ID maps to new producer ID val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId, - RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt())) .thenReturn(true) @@ -1319,7 +1319,7 @@ class TransactionCoordinatorTest { mockPidGenerator() val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt())) .thenReturn(true) @@ -1366,7 +1366,7 @@ class TransactionCoordinatorTest { mockPidGenerator() val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0) when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt())) .thenReturn(true) @@ -1415,7 +1415,7 @@ class TransactionCoordinatorTest { def testRetryInitProducerIdAfterProducerIdRotation(): Unit = { // Existing transaction ID maps to new producer ID val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0) when(pidGenerator.generateProducerId()) .thenReturn(producerId + 1) @@ -1468,7 +1468,7 @@ class TransactionCoordinatorTest { def testInitProducerIdWithInvalidEpochAfterProducerIdRotation(): Unit = { // Existing transaction ID maps to new producer ID val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.EMPTY, partitions, time.milliseconds, time.milliseconds, TV_0) when(pidGenerator.generateProducerId()) .thenReturn(producerId + 1) @@ -1528,7 +1528,7 @@ class TransactionCoordinatorTest { def shouldAbortExpiredTransactionsInOngoingStateAndBumpEpoch(): Unit = { val now = time.milliseconds() val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) @@ -1537,7 +1537,7 @@ class TransactionCoordinatorTest { // Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0. val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.clone, now, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions.clone, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0) when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) @@ -1567,7 +1567,7 @@ class TransactionCoordinatorTest { def shouldNotAcceptSmallerEpochDuringTransactionExpiration(): Unit = { val now = time.milliseconds() val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) @@ -1577,7 +1577,7 @@ class TransactionCoordinatorTest { when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 2).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 2).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, bumpedTxnMetadata)))) @@ -1593,8 +1593,8 @@ class TransactionCoordinatorTest { @Test def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = { val metadata = new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) - metadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false) + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0) + metadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) @@ -1612,13 +1612,13 @@ class TransactionCoordinatorTest { def shouldNotBumpEpochWhenAbortingExpiredTransactionIfAppendToLogFails(): Unit = { val now = time.milliseconds() val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0) when(transactionManager.timedOutTransactions()) .thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -1627,7 +1627,7 @@ class TransactionCoordinatorTest { // Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0. val bumpedEpoch = (producerEpoch + 1).toShort val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, bumpedEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.clone, now, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions.clone, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0) when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) @@ -1660,8 +1660,8 @@ class TransactionCoordinatorTest { @Test def shouldNotBumpEpochWithPendingTransaction(): Unit = { val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) - txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false) + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0) + txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false) when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt())) .thenReturn(true) @@ -1695,7 +1695,7 @@ class TransactionCoordinatorTest { coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Dead, mutable.Set.empty, time.milliseconds(), + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.DEAD, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -1721,7 +1721,7 @@ class TransactionCoordinatorTest { @Test def testDescribeTransactions(): Unit = { val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0) + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -1747,7 +1747,7 @@ class TransactionCoordinatorTest { when(transactionManager.validateTransactionTimeoutMs(anyBoolean(), anyInt())) .thenReturn(true) - // Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort. + // Since the clientTransactionVersion doesn't matter, use 2 since the states are TransactionState.PREPARE_COMMIT and TransactionState.PREPARE_ABORT. val metadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0, TV_2) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) @@ -1799,7 +1799,7 @@ class TransactionCoordinatorTest { private def mockPrepare(transactionState: TransactionState, clientTransactionVersion: TransactionVersion, runCallback: Boolean = false): TransactionMetadata = { val now = time.milliseconds() val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, - producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0) + producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0) val transition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions.clone, now, now, clientTransactionVersion) diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala index 8a852d70cbe..d139f1d3b7f 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala @@ -23,6 +23,7 @@ import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type} import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, SimpleRecord} +import org.apache.kafka.coordinator.transaction.TransactionState import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue} import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue, fail} @@ -49,7 +50,7 @@ class TransactionLogTest { val producerId = 23423L val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) txnMetadata.addPartitions(topicPartitions) assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)) @@ -64,19 +65,19 @@ class TransactionLogTest { "four" -> 4L, "five" -> 5L) - val transactionStates = Map[Long, TransactionState](0L -> Empty, - 1L -> Ongoing, - 2L -> PrepareCommit, - 3L -> CompleteCommit, - 4L -> PrepareAbort, - 5L -> CompleteAbort) + val transactionStates = Map[Long, TransactionState](0L -> TransactionState.EMPTY, + 1L -> TransactionState.ONGOING, + 2L -> TransactionState.PREPARE_COMMIT, + 3L -> TransactionState.COMPLETE_COMMIT, + 4L -> TransactionState.PREPARE_ABORT, + 5L -> TransactionState.COMPLETE_ABORT) // generate transaction log messages val txnRecords = pidMappings.map { case (transactionalId, producerId) => val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) - if (!txnMetadata.state.equals(Empty)) + if (!txnMetadata.state.equals(TransactionState.EMPTY)) txnMetadata.addPartitions(topicPartitions) val keyBytes = TransactionLog.keyToBytes(transactionalId) @@ -99,7 +100,7 @@ class TransactionLogTest { assertEquals(transactionTimeoutMs, txnMetadata.txnTimeoutMs) assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state) - if (txnMetadata.state.equals(Empty)) + if (txnMetadata.state.equals(TransactionState.EMPTY)) assertEquals(Set.empty[TopicPartition], txnMetadata.topicPartitions) else assertEquals(topicPartitions, txnMetadata.topicPartitions) @@ -113,14 +114,14 @@ class TransactionLogTest { @Test def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = { - val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, mutable.Set.empty, 500, 500, TV_0) + val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, mutable.Set.empty, 500, 500, TV_0) val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0)) assertEquals(0, txnLogValueBuffer.getShort) } @Test def testSerializeTransactionLogValueToFlexibleVersion(): Unit = { - val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, mutable.Set.empty, 500, 500, TV_2) + val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, mutable.Set.empty, 500, 500, TV_2) val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2)) assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort) } @@ -134,7 +135,7 @@ class TransactionLogTest { val txnLogValue = new TransactionLogValue() .setProducerId(100) .setProducerEpoch(50.toShort) - .setTransactionStatus(CompleteCommit.id) + .setTransactionStatus(TransactionState.COMPLETE_COMMIT.id) .setTransactionStartTimestampMs(750L) .setTransactionLastUpdateTimestampMs(1000L) .setTransactionTimeoutMs(500) @@ -145,7 +146,7 @@ class TransactionLogTest { assertEquals(100, deserialized.producerId) assertEquals(50, deserialized.producerEpoch) - assertEquals(CompleteCommit, deserialized.state) + assertEquals(TransactionState.COMPLETE_COMMIT, deserialized.state) assertEquals(750L, deserialized.txnStartTimestamp) assertEquals(1000L, deserialized.txnLastUpdateTimestamp) assertEquals(500, deserialized.txnTimeoutMs) @@ -198,7 +199,7 @@ class TransactionLogTest { transactionLogValue.set("producer_id", 1000L) transactionLogValue.set("producer_epoch", 100.toShort) transactionLogValue.set("transaction_timeout_ms", 1000) - transactionLogValue.set("transaction_status", CompleteCommit.id) + transactionLogValue.set("transaction_status", TransactionState.COMPLETE_COMMIT.id) transactionLogValue.set("transaction_partitions", Array(txnPartitions)) transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L) transactionLogValue.set("transaction_start_timestamp_ms", 3000L) @@ -227,7 +228,7 @@ class TransactionLogTest { assertEquals(1000L, txnMetadata.producerId) assertEquals(100, txnMetadata.producerEpoch) assertEquals(1000L, txnMetadata.txnTimeoutMs) - assertEquals(CompleteCommit, txnMetadata.state) + assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata.state) assertEquals(Set(new TopicPartition("topic", 1)), txnMetadata.topicPartitions) assertEquals(2000L, txnMetadata.txnLastUpdateTimestamp) assertEquals(3000L, txnMetadata.txnStartTimestamp) diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala index 321e6e793f4..4131f564d63 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala @@ -29,6 +29,7 @@ import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} import org.apache.kafka.common.utils.MockTime import org.apache.kafka.common.{Node, TopicPartition} +import org.apache.kafka.coordinator.transaction.TransactionState import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.{MetadataVersion, TransactionVersion} import org.apache.kafka.server.metrics.{KafkaMetricsGroup, KafkaYammerMetrics} @@ -68,9 +69,9 @@ class TransactionMarkerChannelManagerTest { private val txnTimeoutMs = 0 private val txnResult = TransactionResult.COMMIT private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, RecordBatch.NO_PRODUCER_ID, - producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L, TransactionVersion.TV_2) + producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L, TransactionVersion.TV_2) private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, RecordBatch.NO_PRODUCER_ID, - producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L, TransactionVersion.TV_2) + producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](partition1), 0L, 0L, TransactionVersion.TV_2) private val capturedErrorsCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit]) private val time = new MockTime @@ -480,7 +481,7 @@ class TransactionMarkerChannelManagerTest { assertEquals(0, channelManager.numTxnsWithPendingMarkers) assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) assertEquals(None, txnMetadata2.pendingState) - assertEquals(CompleteCommit, txnMetadata2.state) + assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata2.state) } @Test @@ -533,7 +534,7 @@ class TransactionMarkerChannelManagerTest { assertEquals(0, channelManager.numTxnsWithPendingMarkers) assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) assertEquals(None, txnMetadata2.pendingState) - assertEquals(PrepareCommit, txnMetadata2.state) + assertEquals(TransactionState.PREPARE_COMMIT, txnMetadata2.state) } @ParameterizedTest @@ -594,7 +595,7 @@ class TransactionMarkerChannelManagerTest { assertEquals(0, channelManager.numTxnsWithPendingMarkers) assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) assertEquals(None, txnMetadata2.pendingState) - assertEquals(CompleteCommit, txnMetadata2.state) + assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata2.state) } private def createPidErrorMap(errors: Errors): util.HashMap[java.lang.Long, util.Map[TopicPartition, Errors]] = { diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala index 72ffa5629c0..b34b0725020 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala @@ -23,6 +23,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.{ApiKeys, Errors} import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} +import org.apache.kafka.coordinator.transaction.TransactionState import org.apache.kafka.server.common.TransactionVersion import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test @@ -44,7 +45,7 @@ class TransactionMarkerRequestCompletionHandlerTest { private val txnResult = TransactionResult.COMMIT private val topicPartition = new TopicPartition("topic1", 0) private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L, TransactionVersion.TV_2) + producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](topicPartition), 0L, 0L, TransactionVersion.TV_2) private val pendingCompleteTxnAndMarkers = asList( PendingCompleteTxnAndMarkerEntry( PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, txnMetadata.prepareComplete(42)), diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala index 12536cddff7..56d595d24ed 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala @@ -19,6 +19,7 @@ package kafka.coordinator.transaction import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.coordinator.transaction.TransactionState import org.apache.kafka.server.common.TransactionVersion import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.apache.kafka.server.util.MockTime @@ -27,7 +28,10 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ValueSource +import java.util.Optional + import scala.collection.mutable +import scala.jdk.CollectionConverters._ class TransactionMetadataTest { @@ -47,7 +51,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -71,7 +75,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -95,7 +99,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -117,13 +121,13 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnStartTimestamp = -1, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_2) - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) @@ -142,13 +146,13 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = CompleteAbort, + state = TransactionState.COMPLETE_ABORT, topicPartitions = mutable.Set.empty, txnStartTimestamp = time.milliseconds() - 1, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_2) - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) @@ -167,13 +171,13 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = CompleteCommit, + state = TransactionState.COMPLETE_COMMIT, topicPartitions = mutable.Set.empty, txnStartTimestamp = time.milliseconds() - 1, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_2) - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) @@ -191,7 +195,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, txnLastUpdateTimestamp = time.milliseconds(), @@ -219,7 +223,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, txnLastUpdateTimestamp = time.milliseconds(), @@ -246,13 +250,13 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnStartTimestamp = time.milliseconds(), txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) - // let new time be smaller; when transiting from Empty the start time would be updated to the update-time + // let new time be smaller; when transiting from TransactionState.EMPTY the start time would be updated to the update-time var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1, TV_0) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0)), txnMetadata.topicPartitions) @@ -262,7 +266,7 @@ class TransactionMetadataTest { assertEquals(time.milliseconds() - 1, txnMetadata.txnStartTimestamp) assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) - // add another partition, check that in Ongoing state the start timestamp would not change to update time + // add another partition, check that in TransactionState.ONGOING state the start timestamp would not change to update time transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds() - 2, TV_0) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)), txnMetadata.topicPartitions) @@ -284,16 +288,16 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Ongoing, + state = TransactionState.ONGOING, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) // let new time be smaller - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false) txnMetadata.completeTransitionTo(transitMetadata) - assertEquals(PrepareCommit, txnMetadata.state) + assertEquals(TransactionState.PREPARE_COMMIT, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch) @@ -312,16 +316,16 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Ongoing, + state = TransactionState.ONGOING, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) // let new time be smaller - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false) txnMetadata.completeTransitionTo(transitMetadata) - assertEquals(PrepareAbort, txnMetadata.state) + assertEquals(TransactionState.PREPARE_ABORT, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch) @@ -343,7 +347,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = lastProducerEpoch, txnTimeoutMs = 30000, - state = PrepareCommit, + state = TransactionState.PREPARE_COMMIT, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, txnLastUpdateTimestamp = time.milliseconds(), @@ -354,7 +358,7 @@ class TransactionMetadataTest { val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1) txnMetadata.completeTransitionTo(transitMetadata) - assertEquals(CompleteCommit, txnMetadata.state) + assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) assertEquals(lastProducerEpoch, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch) @@ -376,7 +380,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = lastProducerEpoch, txnTimeoutMs = 30000, - state = PrepareAbort, + state = TransactionState.PREPARE_ABORT, topicPartitions = mutable.Set.empty, txnStartTimestamp = 1L, txnLastUpdateTimestamp = time.milliseconds(), @@ -387,7 +391,7 @@ class TransactionMetadataTest { val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1) txnMetadata.completeTransitionTo(transitMetadata) - assertEquals(CompleteAbort, txnMetadata.state) + assertEquals(TransactionState.COMPLETE_ABORT, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) assertEquals(lastProducerEpoch, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch) @@ -407,7 +411,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Ongoing, + state = TransactionState.ONGOING, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -416,12 +420,12 @@ class TransactionMetadataTest { val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch() assertEquals(Short.MaxValue, fencingTransitMetadata.producerEpoch) assertEquals(RecordBatch.NO_PRODUCER_EPOCH, fencingTransitMetadata.lastProducerEpoch) - assertEquals(Some(PrepareEpochFence), txnMetadata.pendingState) + assertEquals(Some(TransactionState.PREPARE_EPOCH_FENCE), txnMetadata.pendingState) // We should reset the pending state to make way for the abort transition. txnMetadata.pendingState = None - val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, transitMetadata.producerId) } @@ -438,7 +442,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = CompleteCommit, + state = TransactionState.COMPLETE_COMMIT, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -459,7 +463,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = CompleteAbort, + state = TransactionState.COMPLETE_ABORT, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -480,7 +484,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Ongoing, + state = TransactionState.ONGOING, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -500,7 +504,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -527,13 +531,13 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Ongoing, + state = TransactionState.ONGOING, topicPartitions = mutable.Set.empty, txnStartTimestamp = time.milliseconds(), txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_2) - var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false) + var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch) @@ -559,7 +563,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Ongoing, + state = TransactionState.ONGOING, topicPartitions = mutable.Set.empty, txnStartTimestamp = time.milliseconds(), txnLastUpdateTimestamp = time.milliseconds(), @@ -567,7 +571,7 @@ class TransactionMetadataTest { assertTrue(txnMetadata.isProducerEpochExhausted) val newProducerId = 9893L - var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, newProducerId, time.milliseconds() - 1, false) + var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, newProducerId, time.milliseconds() - 1, false) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(Short.MaxValue, txnMetadata.producerEpoch) @@ -584,21 +588,21 @@ class TransactionMetadataTest { @Test def testRotateProducerIdInOngoingState(): Unit = { - assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing, TV_0)) + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(TransactionState.ONGOING, TV_0)) } @ParameterizedTest @ValueSource(shorts = Array(0, 2)) def testRotateProducerIdInPrepareAbortState(transactionVersion: Short): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) - assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort, clientTransactionVersion)) + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(TransactionState.PREPARE_ABORT, clientTransactionVersion)) } @ParameterizedTest @ValueSource(shorts = Array(0, 2)) def testRotateProducerIdInPrepareCommitState(transactionVersion: Short): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) - assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareCommit, clientTransactionVersion)) + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(TransactionState.PREPARE_COMMIT, clientTransactionVersion)) } @Test @@ -613,7 +617,7 @@ class TransactionMetadataTest { producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -637,7 +641,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -662,7 +666,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = lastProducerEpoch, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -687,7 +691,7 @@ class TransactionMetadataTest { producerEpoch = producerEpoch, lastProducerEpoch = lastProducerEpoch, txnTimeoutMs = 30000, - state = Empty, + state = TransactionState.EMPTY, topicPartitions = mutable.Set.empty, txnLastUpdateTimestamp = time.milliseconds(), clientTransactionVersion = TV_0) @@ -699,13 +703,13 @@ class TransactionMetadataTest { @Test def testTransactionStateIdAndNameMapping(): Unit = { - for (state <- TransactionState.AllStates) { + for (state <- TransactionState.ALL_STATES.asScala) { assertEquals(state, TransactionState.fromId(state.id)) - assertEquals(Some(state), TransactionState.fromName(state.name)) + assertEquals(Optional.of(state), TransactionState.fromName(state.stateName)) - if (state != Dead) { - val clientTransactionState = org.apache.kafka.clients.admin.TransactionState.parse(state.name) - assertEquals(state.name, clientTransactionState.toString) + if (state != TransactionState.DEAD) { + val clientTransactionState = org.apache.kafka.clients.admin.TransactionState.parse(state.stateName) + assertEquals(state.stateName, clientTransactionState.toString) assertNotEquals(org.apache.kafka.clients.admin.TransactionState.UNKNOWN, clientTransactionState) } } @@ -714,27 +718,27 @@ class TransactionMetadataTest { @Test def testAllTransactionStatesAreMapped(): Unit = { val unmatchedStates = mutable.Set( - Empty, - Ongoing, - PrepareCommit, - PrepareAbort, - CompleteCommit, - CompleteAbort, - PrepareEpochFence, - Dead + TransactionState.EMPTY, + TransactionState.ONGOING, + TransactionState.PREPARE_COMMIT, + TransactionState.PREPARE_ABORT, + TransactionState.COMPLETE_COMMIT, + TransactionState.COMPLETE_ABORT, + TransactionState.PREPARE_EPOCH_FENCE, + TransactionState.DEAD ) // The exhaustive match is intentional here to ensure that we are // forced to update the test case if a new state is added. - TransactionState.AllStates.foreach { - case Empty => assertTrue(unmatchedStates.remove(Empty)) - case Ongoing => assertTrue(unmatchedStates.remove(Ongoing)) - case PrepareCommit => assertTrue(unmatchedStates.remove(PrepareCommit)) - case PrepareAbort => assertTrue(unmatchedStates.remove(PrepareAbort)) - case CompleteCommit => assertTrue(unmatchedStates.remove(CompleteCommit)) - case CompleteAbort => assertTrue(unmatchedStates.remove(CompleteAbort)) - case PrepareEpochFence => assertTrue(unmatchedStates.remove(PrepareEpochFence)) - case Dead => assertTrue(unmatchedStates.remove(Dead)) + TransactionState.ALL_STATES.asScala.foreach { + case TransactionState.EMPTY => assertTrue(unmatchedStates.remove(TransactionState.EMPTY)) + case TransactionState.ONGOING => assertTrue(unmatchedStates.remove(TransactionState.ONGOING)) + case TransactionState.PREPARE_COMMIT => assertTrue(unmatchedStates.remove(TransactionState.PREPARE_COMMIT)) + case TransactionState.PREPARE_ABORT => assertTrue(unmatchedStates.remove(TransactionState.PREPARE_ABORT)) + case TransactionState.COMPLETE_COMMIT => assertTrue(unmatchedStates.remove(TransactionState.COMPLETE_COMMIT)) + case TransactionState.COMPLETE_ABORT => assertTrue(unmatchedStates.remove(TransactionState.COMPLETE_ABORT)) + case TransactionState.PREPARE_EPOCH_FENCE => assertTrue(unmatchedStates.remove(TransactionState.PREPARE_EPOCH_FENCE)) + case TransactionState.DEAD => assertTrue(unmatchedStates.remove(TransactionState.DEAD)) } assertEquals(Set.empty, unmatchedStates) diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala index 629e22e2c64..2eb0cbf630e 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala @@ -33,6 +33,7 @@ import org.apache.kafka.common.record._ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests.TransactionResult import org.apache.kafka.common.utils.MockTime +import org.apache.kafka.coordinator.transaction.TransactionState import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion} import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} @@ -181,7 +182,7 @@ class TransactionStateManagerTest { ).thenReturn(new FetchDataInfo(new LogOffsetMetadata(startOffset), fileRecordsMock)) when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some(endOffset)) - txnMetadata1.state = PrepareCommit + txnMetadata1.state = TransactionState.PREPARE_COMMIT txnMetadata1.addPartitions(Set[TopicPartition]( new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) @@ -240,7 +241,7 @@ class TransactionStateManagerTest { ).thenReturn(new FetchDataInfo(new LogOffsetMetadata(startOffset), fileRecordsMock)) when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some(endOffset)) - txnMetadata1.state = PrepareCommit + txnMetadata1.state = TransactionState.PREPARE_COMMIT txnMetadata1.addPartitions(Set[TopicPartition]( new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) @@ -285,7 +286,7 @@ class TransactionStateManagerTest { // generate transaction log messages for two pids traces: // pid1's transaction started with two partitions - txnMetadata1.state = Ongoing + txnMetadata1.state = TransactionState.ONGOING txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) @@ -299,12 +300,12 @@ class TransactionStateManagerTest { txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid1's transaction is preparing to commit - txnMetadata1.state = PrepareCommit + txnMetadata1.state = TransactionState.PREPARE_COMMIT txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid2's transaction started with three partitions - txnMetadata2.state = Ongoing + txnMetadata2.state = TransactionState.ONGOING txnMetadata2.addPartitions(Set[TopicPartition](new TopicPartition("topic3", 0), new TopicPartition("topic3", 1), new TopicPartition("topic3", 2))) @@ -312,17 +313,17 @@ class TransactionStateManagerTest { txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's transaction is preparing to abort - txnMetadata2.state = PrepareAbort + txnMetadata2.state = TransactionState.PREPARE_ABORT txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's transaction has aborted - txnMetadata2.state = CompleteAbort + txnMetadata2.state = TransactionState.COMPLETE_ABORT txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's epoch has advanced, with no ongoing transaction yet - txnMetadata2.state = Empty + txnMetadata2.state = TransactionState.EMPTY txnMetadata2.topicPartitions.clear() txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) @@ -511,7 +512,7 @@ class TransactionStateManagerTest { prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION) transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, _ => true, RequestLocal.withThreadConfinedCaching) assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) - assertEquals(Some(Ongoing), txnMetadata1.pendingState) + assertEquals(Some(TransactionState.ONGOING), txnMetadata1.pendingState) } @Test @@ -591,25 +592,25 @@ class TransactionStateManagerTest { } } - putTransaction(transactionalId = "t0", producerId = 0, state = Ongoing) - putTransaction(transactionalId = "t1", producerId = 1, state = Ongoing) - putTransaction(transactionalId = "my-special-0", producerId = 0, state = Ongoing) + putTransaction(transactionalId = "t0", producerId = 0, state = TransactionState.ONGOING) + putTransaction(transactionalId = "t1", producerId = 1, state = TransactionState.ONGOING) + putTransaction(transactionalId = "my-special-0", producerId = 0, state = TransactionState.ONGOING) // update time to create transactions with various durations time.sleep(1000) - putTransaction(transactionalId = "t2", producerId = 2, state = PrepareCommit) - putTransaction(transactionalId = "t3", producerId = 3, state = PrepareAbort) - putTransaction(transactionalId = "your-special-1", producerId = 0, state = PrepareAbort) + putTransaction(transactionalId = "t2", producerId = 2, state = TransactionState.PREPARE_COMMIT) + putTransaction(transactionalId = "t3", producerId = 3, state = TransactionState.PREPARE_ABORT) + putTransaction(transactionalId = "your-special-1", producerId = 0, state = TransactionState.PREPARE_ABORT) time.sleep(1000) - putTransaction(transactionalId = "t4", producerId = 4, state = CompleteCommit) - putTransaction(transactionalId = "t5", producerId = 5, state = CompleteAbort) - putTransaction(transactionalId = "t6", producerId = 6, state = CompleteAbort) - putTransaction(transactionalId = "t7", producerId = 7, state = PrepareEpochFence) - putTransaction(transactionalId = "their-special-2", producerId = 7, state = CompleteAbort) + putTransaction(transactionalId = "t4", producerId = 4, state = TransactionState.COMPLETE_COMMIT) + putTransaction(transactionalId = "t5", producerId = 5, state = TransactionState.COMPLETE_ABORT) + putTransaction(transactionalId = "t6", producerId = 6, state = TransactionState.COMPLETE_ABORT) + putTransaction(transactionalId = "t7", producerId = 7, state = TransactionState.PREPARE_EPOCH_FENCE) + putTransaction(transactionalId = "their-special-2", producerId = 7, state = TransactionState.COMPLETE_ABORT) time.sleep(1000) - // Note that `Dead` transactions are never returned. This is a transient state + // Note that `TransactionState.DEAD` transactions are never returned. This is a transient state // which is used when the transaction state is in the process of being deleted // (whether though expiration or coordinator unloading). - putTransaction(transactionalId = "t8", producerId = 8, state = Dead) + putTransaction(transactionalId = "t8", producerId = 8, state = TransactionState.DEAD) def assertListTransactions( expectedTransactionalIds: Set[String], @@ -657,12 +658,12 @@ class TransactionStateManagerTest { transactionManager.addLoadedTransactionsToCache(partitionId, 0, new ConcurrentHashMap[String, TransactionMetadata]()) } - transactionManager.putTransactionStateIfNotExists(transactionMetadata("ongoing", producerId = 0, state = Ongoing)) - transactionManager.putTransactionStateIfNotExists(transactionMetadata("not-expiring", producerId = 1, state = Ongoing, txnTimeout = 10000)) - transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-commit", producerId = 2, state = PrepareCommit)) - transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-abort", producerId = 3, state = PrepareAbort)) - transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-commit", producerId = 4, state = CompleteCommit)) - transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-abort", producerId = 5, state = CompleteAbort)) + transactionManager.putTransactionStateIfNotExists(transactionMetadata("ongoing", producerId = 0, state = TransactionState.ONGOING)) + transactionManager.putTransactionStateIfNotExists(transactionMetadata("not-expiring", producerId = 1, state = TransactionState.ONGOING, txnTimeout = 10000)) + transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-commit", producerId = 2, state = TransactionState.PREPARE_COMMIT)) + transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-abort", producerId = 3, state = TransactionState.PREPARE_ABORT)) + transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-commit", producerId = 4, state = TransactionState.COMPLETE_COMMIT)) + transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-abort", producerId = 5, state = TransactionState.COMPLETE_ABORT)) time.sleep(2000) val expiring = transactionManager.timedOutTransactions() @@ -671,59 +672,59 @@ class TransactionStateManagerTest { @Test def shouldWriteTxnMarkersForTransactionInPreparedCommitState(): Unit = { - verifyWritesTxnMarkersInPrepareState(PrepareCommit) + verifyWritesTxnMarkersInPrepareState(TransactionState.PREPARE_COMMIT) } @Test def shouldWriteTxnMarkersForTransactionInPreparedAbortState(): Unit = { - verifyWritesTxnMarkersInPrepareState(PrepareAbort) + verifyWritesTxnMarkersInPrepareState(TransactionState.PREPARE_ABORT) } @Test def shouldRemoveCompleteCommitExpiredTransactionalIds(): Unit = { - setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteCommit) + setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.COMPLETE_COMMIT) verifyMetadataDoesntExist(transactionalId1) verifyMetadataDoesExistAndIsUsable(transactionalId2) } @Test def shouldRemoveCompleteAbortExpiredTransactionalIds(): Unit = { - setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteAbort) + setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.COMPLETE_ABORT) verifyMetadataDoesntExist(transactionalId1) verifyMetadataDoesExistAndIsUsable(transactionalId2) } @Test def shouldRemoveEmptyExpiredTransactionalIds(): Unit = { - setupAndRunTransactionalIdExpiration(Errors.NONE, Empty) + setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.EMPTY) verifyMetadataDoesntExist(transactionalId1) verifyMetadataDoesExistAndIsUsable(transactionalId2) } @Test def shouldNotRemoveExpiredTransactionalIdsIfLogAppendFails(): Unit = { - setupAndRunTransactionalIdExpiration(Errors.NOT_ENOUGH_REPLICAS, CompleteAbort) + setupAndRunTransactionalIdExpiration(Errors.NOT_ENOUGH_REPLICAS, TransactionState.COMPLETE_ABORT) verifyMetadataDoesExistAndIsUsable(transactionalId1) verifyMetadataDoesExistAndIsUsable(transactionalId2) } @Test def shouldNotRemoveOngoingTransactionalIds(): Unit = { - setupAndRunTransactionalIdExpiration(Errors.NONE, Ongoing) + setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.ONGOING) verifyMetadataDoesExistAndIsUsable(transactionalId1) verifyMetadataDoesExistAndIsUsable(transactionalId2) } @Test def shouldNotRemovePrepareAbortTransactionalIds(): Unit = { - setupAndRunTransactionalIdExpiration(Errors.NONE, PrepareAbort) + setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.PREPARE_ABORT) verifyMetadataDoesExistAndIsUsable(transactionalId1) verifyMetadataDoesExistAndIsUsable(transactionalId2) } @Test def shouldNotRemovePrepareCommitTransactionalIds(): Unit = { - setupAndRunTransactionalIdExpiration(Errors.NONE, PrepareCommit) + setupAndRunTransactionalIdExpiration(Errors.NONE, TransactionState.PREPARE_COMMIT) verifyMetadataDoesExistAndIsUsable(transactionalId1) verifyMetadataDoesExistAndIsUsable(transactionalId2) } @@ -879,7 +880,7 @@ class TransactionStateManagerTest { // will be expired and it should succeed. val timestamp = time.milliseconds() val txnMetadata = new TransactionMetadata(transactionalId, 1, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, - RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0) transactionManager.putTransactionStateIfNotExists(txnMetadata) time.sleep(txnConfig.transactionalIdExpirationMs + 1) @@ -966,7 +967,7 @@ class TransactionStateManagerTest { @Test def testSuccessfulReimmigration(): Unit = { - txnMetadata1.state = PrepareCommit + txnMetadata1.state = TransactionState.PREPARE_COMMIT txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) @@ -1033,9 +1034,9 @@ class TransactionStateManagerTest { @Test def testLoadTransactionMetadataContainingSegmentEndingWithEmptyBatch(): Unit = { // Simulate a case where a log contains two segments and the first segment ending with an empty batch. - txnMetadata1.state = PrepareCommit + txnMetadata1.state = TransactionState.PREPARE_COMMIT txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0))) - txnMetadata2.state = Ongoing + txnMetadata2.state = TransactionState.ONGOING txnMetadata2.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0))) // Create the first segment which contains two batches. @@ -1176,7 +1177,7 @@ class TransactionStateManagerTest { transactionManager.removeExpiredTransactionalIds() val stateAllowsExpiration = txnState match { - case Empty | CompleteCommit | CompleteAbort => true + case TransactionState.EMPTY | TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT => true case _ => false } @@ -1223,7 +1224,7 @@ class TransactionStateManagerTest { private def transactionMetadata(transactionalId: String, producerId: Long, - state: TransactionState = Empty, + state: TransactionState = TransactionState.EMPTY, txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = { val timestamp = time.milliseconds() new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, 0.toShort, @@ -1300,7 +1301,7 @@ class TransactionStateManagerTest { assertEquals(Double.NaN, partitionLoadTime("partition-load-time-avg"), 0) assertTrue(reporter.containsMbean(mBeanName)) - txnMetadata1.state = Ongoing + txnMetadata1.state = TransactionState.ONGOING txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 1), new TopicPartition("topic1", 1))) @@ -1319,7 +1320,7 @@ class TransactionStateManagerTest { @Test def testIgnoreUnknownRecordType(): Unit = { - txnMetadata1.state = PrepareCommit + txnMetadata1.state = TransactionState.PREPARE_COMMIT txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) diff --git a/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionState.java b/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionState.java new file mode 100644 index 00000000000..c03066945be --- /dev/null +++ b/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionState.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.coordinator.transaction; + +import java.util.Arrays; +import java.util.EnumSet; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Represents the states of a transaction in the transaction coordinator. + * This enum corresponds to the Scala sealed trait TransactionState in kafka.coordinator.transaction. + */ +public enum TransactionState { + /** + * Transaction has not existed yet + *

+ * transition: received AddPartitionsToTxnRequest => Ongoing + * received AddOffsetsToTxnRequest => Ongoing + * received EndTxnRequest with abort and TransactionV2 enabled => PrepareAbort + */ + EMPTY((byte) 0, org.apache.kafka.clients.admin.TransactionState.EMPTY.toString(), true), + /** + * Transaction has started and ongoing + *

+ * transition: received EndTxnRequest with commit => PrepareCommit + * received EndTxnRequest with abort => PrepareAbort + * received AddPartitionsToTxnRequest => Ongoing + * received AddOffsetsToTxnRequest => Ongoing + */ + ONGOING((byte) 1, org.apache.kafka.clients.admin.TransactionState.ONGOING.toString(), false), + /** + * Group is preparing to commit + * transition: received acks from all partitions => CompleteCommit + */ + PREPARE_COMMIT((byte) 2, org.apache.kafka.clients.admin.TransactionState.PREPARE_COMMIT.toString(), false), + /** + * Group is preparing to abort + *

+ * transition: received acks from all partitions => CompleteAbort + *

+ * Note, In transaction v2, we allow Empty, CompleteCommit, CompleteAbort to transition to PrepareAbort. because the + * client may not know the txn state on the server side, it needs to send endTxn request when uncertain. + */ + PREPARE_ABORT((byte) 3, org.apache.kafka.clients.admin.TransactionState.PREPARE_ABORT.toString(), false), + /** + * Group has completed commit + *

+ * Will soon be removed from the ongoing transaction cache + */ + COMPLETE_COMMIT((byte) 4, org.apache.kafka.clients.admin.TransactionState.COMPLETE_COMMIT.toString(), true), + /** + * Group has completed abort + *

+ * Will soon be removed from the ongoing transaction cache + */ + COMPLETE_ABORT((byte) 5, org.apache.kafka.clients.admin.TransactionState.COMPLETE_ABORT.toString(), true), + /** + * TransactionalId has expired and is about to be removed from the transaction cache + */ + DEAD((byte) 6, "Dead", false), + /** + * We are in the middle of bumping the epoch and fencing out older producers. + */ + PREPARE_EPOCH_FENCE((byte) 7, org.apache.kafka.clients.admin.TransactionState.PREPARE_EPOCH_FENCE.toString(), false); + + private static final Map NAME_TO_ENUM = Arrays.stream(values()) + .collect(Collectors.toUnmodifiableMap(TransactionState::stateName, Function.identity())); + + private static final Map ID_TO_ENUM = Arrays.stream(values()) + .collect(Collectors.toUnmodifiableMap(TransactionState::id, Function.identity())); + + public static final Set ALL_STATES = Set.copyOf(EnumSet.allOf(TransactionState.class)); + + private final byte id; + private final String stateName; + public static final Map> VALID_PREVIOUS_STATES = Map.of( + EMPTY, Set.of(EMPTY, COMPLETE_COMMIT, COMPLETE_ABORT), + ONGOING, Set.of(ONGOING, EMPTY, COMPLETE_COMMIT, COMPLETE_ABORT), + PREPARE_COMMIT, Set.of(ONGOING), + PREPARE_ABORT, Set.of(ONGOING, PREPARE_EPOCH_FENCE, EMPTY, COMPLETE_COMMIT, COMPLETE_ABORT), + COMPLETE_COMMIT, Set.of(PREPARE_COMMIT), + COMPLETE_ABORT, Set.of(PREPARE_ABORT), + DEAD, Set.of(EMPTY, COMPLETE_ABORT, COMPLETE_COMMIT), + PREPARE_EPOCH_FENCE, Set.of(ONGOING) + ); + + private final boolean expirationAllowed; + + TransactionState(byte id, String name, boolean expirationAllowed) { + this.id = id; + this.stateName = name; + this.expirationAllowed = expirationAllowed; + } + + /** + * @return The state id byte. + */ + public byte id() { + return id; + } + + /** + * Get the name of this state. This is exposed through the `DescribeTransactions` API. + * @return The state name string. + */ + public String stateName() { + return stateName; + } + + /** + * @return The set of states from which it is valid to transition into this state. + */ + public Set validPreviousStates() { + return VALID_PREVIOUS_STATES.getOrDefault(this, Set.of()); + } + + /** + * @return True if expiration is allowed in this state, false otherwise. + */ + public boolean isExpirationAllowed() { + return expirationAllowed; + } + + /** + * Finds a TransactionState by its name. + * @param name The name of the state. + * @return An Optional containing the TransactionState if found, otherwise empty. + */ + public static Optional fromName(String name) { + return Optional.ofNullable(NAME_TO_ENUM.get(name)); + } + + /** + * Finds a TransactionState by its ID. + * @param id The byte ID of the state. + * @return The TransactionState corresponding to the ID. + * @throws IllegalStateException if the ID is unknown. + */ + public static TransactionState fromId(byte id) { + TransactionState state = ID_TO_ENUM.get(id); + if (state == null) { + throw new IllegalStateException("Unknown transaction state id " + id + " from the transaction status message"); + } + return state; + } +}