KAFKA-18884 Move TransactionMetadata to transaction-coordinator module (#19699)

1. Move TransactionMetadata to transaction-coordinator module.
2. Rewrite TransactionMetadata in Java.
3. The `topicPartitions` field uses `HashSet` instead of `Set`, because
it's mutable field.
4. In Scala, when calling `prepare*` methods, they can use current value
as default input in `prepareTransitionTo`. However, in Java, it doesn't
support function default input value. To avoid a lot of duplicated code
or assign value to wrong field, we add a private class `TransitionData`.
It can get current `TransactionMetadata` value as default value and
`prepare*` methods just need to assign updated value.

Reviewers: Justine Olshan <jolshan@confluent.io>, Artem Livshits
 <alivshits@confluent.io>, Chia-Ping Tsai <chia7712@gmail.com>
This commit is contained in:
PoAn Yang 2025-08-16 02:10:52 +08:00 committed by GitHub
parent 27647c7c7c
commit 990cb5c06c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 1316 additions and 1119 deletions

View File

@ -27,14 +27,15 @@ import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult}
import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time}
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionLogConfig, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata}
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionLogConfig, TransactionMetadata, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata}
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.{RequestLocal, TransactionVersion}
import org.apache.kafka.server.util.Scheduler
import java.util
import java.util.Properties
import java.util.concurrent.atomic.AtomicBoolean
import scala.jdk.CollectionConverters._
import scala.jdk.OptionConverters._
object TransactionCoordinator {
@ -147,17 +148,18 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
val coordinatorEpochAndMetadata = txnManager.getTransactionState(transactionalId).flatMap {
case None =>
try {
val createdMetadata = new TransactionMetadata(transactionalId = transactionalId,
producerId = producerIdManager.generateProducerId(),
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = resolvedTxnTimeoutMs,
state = TransactionState.EMPTY,
topicPartitions = collection.mutable.Set.empty[TopicPartition],
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TransactionVersion.TV_0)
val createdMetadata = new TransactionMetadata(transactionalId,
producerIdManager.generateProducerId(),
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH,
RecordBatch.NO_PRODUCER_EPOCH,
resolvedTxnTimeoutMs,
TransactionState.EMPTY,
util.Set.of(),
-1,
time.milliseconds(),
TransactionVersion.TV_0)
txnManager.putTransactionStateIfNotExists(createdMetadata)
} catch {
case e: Exception => Left(Errors.forException(e))
@ -171,10 +173,10 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
val coordinatorEpoch = existingEpochAndMetadata.coordinatorEpoch
val txnMetadata = existingEpochAndMetadata.transactionMetadata
txnMetadata.inLock {
txnMetadata.inLock(() =>
prepareInitProducerIdTransit(transactionalId, resolvedTxnTimeoutMs, coordinatorEpoch, txnMetadata,
expectedProducerIdAndEpoch)
}
)
}
result match {
@ -256,17 +258,16 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT | TransactionState.EMPTY =>
val transitMetadataResult =
// If the epoch is exhausted and the expected epoch (if provided) matches it, generate a new producer ID
if (txnMetadata.isProducerEpochExhausted &&
expectedProducerIdAndEpoch.forall(_.epoch == txnMetadata.producerEpoch)) {
try {
try {
if (txnMetadata.isProducerEpochExhausted &&
expectedProducerIdAndEpoch.forall(_.epoch == txnMetadata.producerEpoch))
Right(txnMetadata.prepareProducerIdRotation(producerIdManager.generateProducerId(), transactionTimeoutMs, time.milliseconds(),
expectedProducerIdAndEpoch.isDefined))
} catch {
case e: Exception => Left(Errors.forException(e))
}
} else {
txnMetadata.prepareIncrementProducerEpoch(transactionTimeoutMs, expectedProducerIdAndEpoch.map(_.epoch),
time.milliseconds())
else
Right(txnMetadata.prepareIncrementProducerEpoch(transactionTimeoutMs, expectedProducerIdAndEpoch.map(e => Short.box(e.epoch)).toJava,
time.milliseconds()))
} catch {
case e: Exception => Left(Errors.forException(e))
}
transitMetadataResult match {
@ -326,12 +327,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
transactionState.setErrorCode(Errors.TRANSACTIONAL_ID_NOT_FOUND.code)
case Right(Some(coordinatorEpochAndMetadata)) =>
val txnMetadata = coordinatorEpochAndMetadata.transactionMetadata
txnMetadata.inLock {
txnMetadata.inLock(() => {
if (txnMetadata.state == TransactionState.DEAD) {
// The transaction state is being expired, so ignore it
transactionState.setErrorCode(Errors.TRANSACTIONAL_ID_NOT_FOUND.code)
} else {
txnMetadata.topicPartitions.foreach { topicPartition =>
txnMetadata.topicPartitions.forEach(topicPartition => {
var topicData = transactionState.topics.find(topicPartition.topic)
if (topicData == null) {
topicData = new DescribeTransactionsResponseData.TopicData()
@ -339,7 +340,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
transactionState.topics.add(topicData)
}
topicData.partitions.add(topicPartition.partition)
}
})
transactionState
.setErrorCode(Errors.NONE.code)
@ -349,7 +350,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
.setTransactionTimeoutMs(txnMetadata.txnTimeoutMs)
.setTransactionStartTimeMs(txnMetadata.txnStartTimestamp)
}
}
})
}
}
}
@ -357,13 +358,15 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
def handleVerifyPartitionsInTransaction(transactionalId: String,
producerId: Long,
producerEpoch: Short,
partitions: collection.Set[TopicPartition],
partitions: util.Set[TopicPartition],
responseCallback: VerifyPartitionsCallback): Unit = {
if (transactionalId == null || transactionalId.isEmpty) {
debug(s"Returning ${Errors.INVALID_REQUEST} error code to client for $transactionalId's AddPartitions request for verification")
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitions.map(_ -> Errors.INVALID_REQUEST).toMap.asJava))
val errors = new util.HashMap[TopicPartition, Errors]()
partitions.forEach(partition => errors.put(partition, Errors.INVALID_REQUEST))
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, errors))
} else {
val result: ApiResult[Map[TopicPartition, Errors]] =
val result: ApiResult[util.Map[TopicPartition, Errors]] =
txnManager.getTransactionState(transactionalId).flatMap {
case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING)
@ -373,7 +376,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
// Given the txnMetadata is valid, we check if the partitions are in the transaction.
// Pending state is not checked since there is a final validation on the append to the log.
// Partitions are added to metadata when the add partitions state is persisted, and removed when the end marker is persisted.
txnMetadata.inLock {
txnMetadata.inLock(() => {
if (txnMetadata.producerId != producerId) {
Left(Errors.INVALID_PRODUCER_ID_MAPPING)
} else if (txnMetadata.producerEpoch != producerEpoch) {
@ -381,23 +384,27 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
} else if (txnMetadata.state == TransactionState.PREPARE_COMMIT || txnMetadata.state == TransactionState.PREPARE_ABORT) {
Left(Errors.CONCURRENT_TRANSACTIONS)
} else {
Right(partitions.map { part =>
val errors = new util.HashMap[TopicPartition, Errors]()
partitions.forEach(part => {
if (txnMetadata.topicPartitions.contains(part))
(part, Errors.NONE)
errors.put(part, Errors.NONE)
else
(part, Errors.TRANSACTION_ABORTABLE)
}.toMap)
errors.put(part, Errors.TRANSACTION_ABORTABLE)
})
Right(errors)
}
}
})
}
result match {
case Left(err) =>
debug(s"Returning $err error code to client for $transactionalId's AddPartitions request for verification")
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitions.map(_ -> err).toMap.asJava))
val errors = new util.HashMap[TopicPartition, Errors]()
partitions.forEach(partition => errors.put(partition, err))
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, errors))
case Right(errors) =>
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, errors.asJava))
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, errors))
}
}
@ -406,7 +413,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
def handleAddPartitionsToTransaction(transactionalId: String,
producerId: Long,
producerEpoch: Short,
partitions: collection.Set[TopicPartition],
partitions: util.Set[TopicPartition],
responseCallback: AddPartitionsCallback,
clientTransactionVersion: TransactionVersion,
requestLocal: RequestLocal = RequestLocal.noCaching): Unit = {
@ -424,7 +431,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
val txnMetadata = epochAndMetadata.transactionMetadata
// generate the new transaction metadata with added partitions
txnMetadata.inLock {
txnMetadata.inLock(() => {
if (txnMetadata.pendingTransitionInProgress) {
// return a retriable exception to let the client backoff and retry
// This check is performed first so that the pending transition can complete before subsequent checks.
@ -437,13 +444,13 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
Left(Errors.PRODUCER_FENCED)
} else if (txnMetadata.state == TransactionState.PREPARE_COMMIT || txnMetadata.state == TransactionState.PREPARE_ABORT) {
Left(Errors.CONCURRENT_TRANSACTIONS)
} else if (txnMetadata.state == TransactionState.ONGOING && partitions.subsetOf(txnMetadata.topicPartitions)) {
} else if (txnMetadata.state == TransactionState.ONGOING && txnMetadata.topicPartitions.containsAll(partitions)) {
// this is an optimization: if the partitions are already in the metadata reply OK immediately
Left(Errors.NONE)
} else {
Right(coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds(), clientTransactionVersion))
Right(coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions, time.milliseconds(), clientTransactionVersion))
}
}
})
}
result match {
@ -549,7 +556,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
val txnMetadata = epochAndTxnMetadata.transactionMetadata
val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch
txnMetadata.inLock {
txnMetadata.inLock(() => {
if (txnMetadata.producerId != producerId)
Left(Errors.INVALID_PRODUCER_ID_MAPPING)
// Strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch.
@ -564,13 +571,13 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
else
TransactionState.PREPARE_ABORT
if (nextState == TransactionState.PREPARE_ABORT && txnMetadata.pendingState.contains(TransactionState.PREPARE_EPOCH_FENCE)) {
if (nextState == TransactionState.PREPARE_ABORT && txnMetadata.pendingState.filter(s => s == TransactionState.PREPARE_EPOCH_FENCE).isPresent) {
// We should clear the pending state to make way for the transition to PrepareAbort and also bump
// the epoch in the transaction metadata we are about to append.
isEpochFence = true
txnMetadata.pendingState = None
txnMetadata.producerEpoch = producerEpoch
txnMetadata.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH
txnMetadata.pendingState(util.Optional.empty())
txnMetadata.setProducerEpoch(producerEpoch)
txnMetadata.setLastProducerEpoch(RecordBatch.NO_PRODUCER_EPOCH)
}
Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, TransactionVersion.fromFeatureLevel(0), RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false))
@ -602,7 +609,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
fatal(errorMsg)
throw new IllegalStateException(errorMsg)
}
}
})
}
preAppendResult match {
@ -623,7 +630,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case Some(epochAndMetadata) =>
if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) {
val txnMetadata = epochAndMetadata.transactionMetadata
txnMetadata.inLock {
txnMetadata.inLock(() => {
if (txnMetadata.producerId != producerId)
Left(Errors.INVALID_PRODUCER_ID_MAPPING)
else if (txnMetadata.producerEpoch != producerEpoch)
@ -649,7 +656,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
fatal(errorMsg)
throw new IllegalStateException(errorMsg)
}
}
})
} else {
debug(s"The transaction coordinator epoch has changed to ${epochAndMetadata.coordinatorEpoch} after $txnMarkerResult was " +
s"successfully appended to the log for $transactionalId with old epoch $coordinatorEpoch")
@ -682,7 +689,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case Some(epochAndMetadata) =>
if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) {
// This was attempted epoch fence that failed, so mark this state on the metadata
epochAndMetadata.transactionMetadata.hasFailedEpochFence = true
epochAndMetadata.transactionMetadata.hasFailedEpochFence(true)
warn(s"The coordinator failed to write an epoch fence transition for producer $transactionalId to the transaction log " +
s"with error $error. The epoch was increased to ${newMetadata.producerEpoch} but not returned to the client")
}
@ -771,12 +778,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
val txnMetadata = epochAndTxnMetadata.transactionMetadata
val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch
txnMetadata.inLock {
txnMetadata.inLock(() => {
producerIdCopy = txnMetadata.producerId
producerEpochCopy = txnMetadata.producerEpoch
// PrepareEpochFence has slightly different epoch bumping logic so don't include it here.
// Note that, it can only happen when the current state is Ongoing.
isEpochFence = txnMetadata.pendingState.contains(TransactionState.PREPARE_EPOCH_FENCE)
isEpochFence = txnMetadata.pendingState.filter(s => s == TransactionState.PREPARE_EPOCH_FENCE).isPresent
// True if the client retried a request that had overflowed the epoch, and a new producer ID is stored in the txnMetadata
val retryOnOverflow = !isEpochFence && txnMetadata.prevProducerId == producerId &&
producerEpoch == Short.MaxValue - 1 && txnMetadata.producerEpoch == 0
@ -820,7 +827,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
if (nextState == TransactionState.PREPARE_ABORT && isEpochFence) {
// We should clear the pending state to make way for the transition to PrepareAbort
txnMetadata.pendingState = None
txnMetadata.pendingState(util.Optional.empty())
// For TV2+, don't manually set the epoch - let prepareAbortOrCommit handle it naturally.
}
@ -893,7 +900,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
throw new IllegalStateException(errorMsg)
}
}
})
}
preAppendResult match {
@ -918,7 +925,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case Some(epochAndMetadata) =>
if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) {
val txnMetadata = epochAndMetadata.transactionMetadata
txnMetadata.inLock {
txnMetadata.inLock(() => {
if (txnMetadata.producerId != producerId)
Left(Errors.INVALID_PRODUCER_ID_MAPPING)
else if (txnMetadata.producerEpoch != producerEpoch && txnMetadata.producerEpoch != producerEpoch + 1)
@ -945,7 +952,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
throw new IllegalStateException(errorMsg)
}
}
})
} else {
debug(s"The transaction coordinator epoch has changed to ${epochAndMetadata.coordinatorEpoch} after $txnMarkerResult was " +
s"successfully appended to the log for $transactionalId with old epoch $coordinatorEpoch")
@ -1026,7 +1033,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case Some(epochAndTxnMetadata) =>
val txnMetadata = epochAndTxnMetadata.transactionMetadata
val transitMetadataOpt = txnMetadata.inLock {
val transitMetadataOpt = txnMetadata.inLock(() => {
if (txnMetadata.producerId != txnIdAndPidEpoch.producerId) {
error(s"Found incorrect producerId when expiring transactionalId: ${txnIdAndPidEpoch.transactionalId}. " +
s"Expected producerId: ${txnIdAndPidEpoch.producerId}. Found producerId: " +
@ -1039,7 +1046,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
} else {
Some(txnMetadata.prepareFenceProducerEpoch())
}
}
})
transitMetadataOpt.foreach { txnTransitMetadata =>
endTransaction(txnMetadata.transactionalId,

View File

@ -21,11 +21,12 @@ import org.apache.kafka.common.compress.Compression
import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil}
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.coordinator.transaction.{TransactionState, TxnTransitMetadata}
import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState, TxnTransitMetadata}
import org.apache.kafka.coordinator.transaction.generated.{CoordinatorRecordType, TransactionLogKey, TransactionLogValue}
import org.apache.kafka.server.common.TransactionVersion
import scala.collection.mutable
import java.util
import scala.jdk.CollectionConverters._
/**
@ -115,26 +116,26 @@ object TransactionLog {
if (version >= TransactionLogValue.LOWEST_SUPPORTED_VERSION && version <= TransactionLogValue.HIGHEST_SUPPORTED_VERSION) {
val value = new TransactionLogValue(new ByteBufferAccessor(buffer), version)
val transactionMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = value.producerId,
prevProducerId = value.previousProducerId,
nextProducerId = value.nextProducerId,
producerEpoch = value.producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = value.transactionTimeoutMs,
state = TransactionState.fromId(value.transactionStatus),
topicPartitions = mutable.Set.empty[TopicPartition],
txnStartTimestamp = value.transactionStartTimestampMs,
txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs,
clientTransactionVersion = TransactionVersion.fromFeatureLevel(value.clientTransactionVersion))
transactionalId,
value.producerId,
value.previousProducerId,
value.nextProducerId,
value.producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
value.transactionTimeoutMs,
TransactionState.fromId(value.transactionStatus),
util.Set.of(),
value.transactionStartTimestampMs,
value.transactionLastUpdateTimestampMs,
TransactionVersion.fromFeatureLevel(value.clientTransactionVersion))
if (!transactionMetadata.state.equals(TransactionState.EMPTY))
value.transactionPartitions.forEach(partitionsSchema =>
value.transactionPartitions.forEach(partitionsSchema => {
transactionMetadata.addPartitions(partitionsSchema.partitionIds
.asScala
.map(partitionId => new TopicPartition(partitionsSchema.topic, partitionId))
.toSet)
)
.stream
.map(partitionId => new TopicPartition(partitionsSchema.topic, partitionId.intValue()))
.toList)
})
Some(transactionMetadata)
} else throw new IllegalStateException(s"Unknown version $version from the transaction log message value")
}

View File

@ -32,7 +32,7 @@ import org.apache.kafka.common.requests.{TransactionResult, WriteTxnMarkersReque
import org.apache.kafka.common.security.JaasContext
import org.apache.kafka.common.utils.{LogContext, Time}
import org.apache.kafka.common.{Node, Reconfigurable, TopicPartition}
import org.apache.kafka.coordinator.transaction.TxnTransitMetadata
import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TxnTransitMetadata}
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.RequestLocal
import org.apache.kafka.server.metrics.KafkaMetricsGroup
@ -326,16 +326,16 @@ class TransactionMarkerChannelManager(
info(s"Replaced an existing pending complete txn $prev with $pendingCompleteTxn while adding markers to send.")
}
addTxnMarkersToBrokerQueue(txnMetadata.producerId,
txnMetadata.producerEpoch, txnResult, pendingCompleteTxn, txnMetadata.topicPartitions.toSet)
txnMetadata.producerEpoch, txnResult, pendingCompleteTxn, txnMetadata.topicPartitions.asScala.toSet)
maybeWriteTxnCompletion(transactionalId)
}
def numTxnsWithPendingMarkers: Int = transactionsWithPendingMarkers.size
private def hasPendingMarkersToWrite(txnMetadata: TransactionMetadata): Boolean = {
txnMetadata.inLock {
txnMetadata.topicPartitions.nonEmpty
}
txnMetadata.inLock(() =>
!txnMetadata.topicPartitions.isEmpty
)
}
def maybeWriteTxnCompletion(transactionalId: String): Unit = {
@ -422,9 +422,9 @@ class TransactionMarkerChannelManager(
val txnMetadata = epochAndMetadata.transactionMetadata
txnMetadata.inLock {
txnMetadata.inLock(() =>
topicPartitions.foreach(txnMetadata.removePartition)
}
)
maybeWriteTxnCompletion(transactionalId)
}

View File

@ -131,7 +131,7 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int,
txnMarkerChannelManager.removeMarkersForTxn(pendingCompleteTxn)
abortSending = true
} else {
txnMetadata.inLock {
txnMetadata.inLock(() => {
for ((topicPartition, error) <- errors.asScala) {
error match {
case Errors.NONE =>
@ -178,7 +178,7 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int,
throw new IllegalStateException(s"Unexpected error ${other.exceptionName} while sending txn marker for $transactionalId")
}
}
}
})
}
if (!abortSending) {

View File

@ -1,492 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.coordinator.transaction
import java.util.concurrent.locks.ReentrantLock
import kafka.utils.{CoreUtils, Logging, nonthreadsafe}
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.coordinator.transaction.{TransactionState, TxnTransitMetadata}
import org.apache.kafka.server.common.TransactionVersion
import scala.collection.{immutable, mutable}
import scala.jdk.CollectionConverters._
private[transaction] object TransactionMetadata {
def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1
}
/**
*
* @param producerId producer id
* @param prevProducerId producer id for the last committed transaction with this transactional ID
* @param nextProducerId Latest producer ID sent to the producer for the given transactional ID
* @param producerEpoch current epoch of the producer
* @param lastProducerEpoch last epoch of the producer
* @param txnTimeoutMs timeout to be used to abort long running transactions
* @param state current state of the transaction
* @param topicPartitions current set of partitions that are part of this transaction
* @param txnStartTimestamp time the transaction was started, i.e., when first partition is added
* @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration
* @param clientTransactionVersion TransactionVersion used by the client when the state was transitioned
*/
@nonthreadsafe
private[transaction] class TransactionMetadata(val transactionalId: String,
var producerId: Long,
var prevProducerId: Long,
var nextProducerId: Long,
var producerEpoch: Short,
var lastProducerEpoch: Short,
var txnTimeoutMs: Int,
var state: TransactionState,
var topicPartitions: mutable.Set[TopicPartition],
@volatile var txnStartTimestamp: Long = -1,
@volatile var txnLastUpdateTimestamp: Long,
var clientTransactionVersion: TransactionVersion) extends Logging {
// pending state is used to indicate the state that this transaction is going to
// transit to, and for blocking future attempts to transit it again if it is not legal;
// initialized as the same as the current state
var pendingState: Option[TransactionState] = None
// Indicates that during a previous attempt to fence a producer, the bumped epoch may not have been
// successfully written to the log. If this is true, we will not bump the epoch again when fencing
var hasFailedEpochFence: Boolean = false
private[transaction] val lock = new ReentrantLock
def inLock[T](fun: => T): T = CoreUtils.inLock(lock)(fun)
def addPartitions(partitions: collection.Set[TopicPartition]): Unit = {
topicPartitions ++= partitions
}
def removePartition(topicPartition: TopicPartition): Unit = {
if (state != TransactionState.PREPARE_COMMIT && state != TransactionState.PREPARE_ABORT)
throw new IllegalStateException(s"Transaction metadata's current state is $state, and its pending state is $pendingState " +
s"while trying to remove partitions whose txn marker has been sent, this is not expected")
topicPartitions -= topicPartition
}
// this is visible for test only
def prepareNoTransit(): TxnTransitMetadata = {
// do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state
new TxnTransitMetadata(producerId, prevProducerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.clone().asJava,
txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion)
}
def prepareFenceProducerEpoch(): TxnTransitMetadata = {
if (producerEpoch == Short.MaxValue)
throw new IllegalStateException(s"Cannot fence producer with epoch equal to Short.MaxValue since this would overflow")
// If we've already failed to fence an epoch (because the write to the log failed), we don't increase it again.
// This is safe because we never return the epoch to client if we fail to fence the epoch
val bumpedEpoch = if (hasFailedEpochFence) producerEpoch else (producerEpoch + 1).toShort
prepareTransitionTo(
state = TransactionState.PREPARE_EPOCH_FENCE,
producerEpoch = bumpedEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH
)
}
def prepareIncrementProducerEpoch(newTxnTimeoutMs: Int,
expectedProducerEpoch: Option[Short],
updateTimestamp: Long): Either[Errors, TxnTransitMetadata] = {
if (isProducerEpochExhausted)
throw new IllegalStateException(s"Cannot allocate any more producer epochs for producerId $producerId")
val bumpedEpoch = (producerEpoch + 1).toShort
val epochBumpResult: Either[Errors, (Short, Short)] = expectedProducerEpoch match {
case None =>
// If no expected epoch was provided by the producer, bump the current epoch and set the last epoch to -1
// In the case of a new producer, producerEpoch will be -1 and bumpedEpoch will be 0
Right(bumpedEpoch, RecordBatch.NO_PRODUCER_EPOCH)
case Some(expectedEpoch) =>
if (producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || expectedEpoch == producerEpoch)
// If the expected epoch matches the current epoch, or if there is no current epoch, the producer is attempting
// to continue after an error and no other producer has been initialized. Bump the current and last epochs.
// The no current epoch case means this is a new producer; producerEpoch will be -1 and bumpedEpoch will be 0
Right(bumpedEpoch, producerEpoch)
else if (expectedEpoch == lastProducerEpoch)
// If the expected epoch matches the previous epoch, it is a retry of a successful call, so just return the
// current epoch without bumping. There is no danger of this producer being fenced, because a new producer
// calling InitProducerId would have caused the last epoch to be set to -1.
// Note that if the IBP is prior to 2.4.IV1, the lastProducerId and lastProducerEpoch will not be written to
// the transaction log, so a retry that spans a coordinator change will fail. We expect this to be a rare case.
Right(producerEpoch, lastProducerEpoch)
else {
// Otherwise, the producer has a fenced epoch and should receive an PRODUCER_FENCED error
info(s"Expected producer epoch $expectedEpoch does not match current " +
s"producer epoch $producerEpoch or previous producer epoch $lastProducerEpoch")
Left(Errors.PRODUCER_FENCED)
}
}
epochBumpResult match {
case Right((nextEpoch, lastEpoch)) => Right(prepareTransitionTo(
state = TransactionState.EMPTY,
producerEpoch = nextEpoch,
lastProducerEpoch = lastEpoch,
txnTimeoutMs = newTxnTimeoutMs,
topicPartitions = mutable.Set.empty[TopicPartition],
txnStartTimestamp = -1,
txnLastUpdateTimestamp = updateTimestamp
))
case Left(err) => Left(err)
}
}
def prepareProducerIdRotation(newProducerId: Long,
newTxnTimeoutMs: Int,
updateTimestamp: Long,
recordLastEpoch: Boolean): TxnTransitMetadata = {
if (hasPendingTransaction)
throw new IllegalStateException("Cannot rotate producer ids while a transaction is still pending")
prepareTransitionTo(
state = TransactionState.EMPTY,
producerId = newProducerId,
producerEpoch = 0,
lastProducerEpoch = if (recordLastEpoch) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = newTxnTimeoutMs,
topicPartitions = mutable.Set.empty[TopicPartition],
txnStartTimestamp = -1,
txnLastUpdateTimestamp = updateTimestamp
)
}
def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition], updateTimestamp: Long, clientTransactionVersion: TransactionVersion): TxnTransitMetadata = {
val newTxnStartTimestamp = state match {
case TransactionState.EMPTY | TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT => updateTimestamp
case _ => txnStartTimestamp
}
prepareTransitionTo(
state = TransactionState.ONGOING,
topicPartitions = (topicPartitions ++ addedTopicPartitions),
txnStartTimestamp = newTxnStartTimestamp,
txnLastUpdateTimestamp = updateTimestamp,
clientTransactionVersion = clientTransactionVersion
)
}
def prepareAbortOrCommit(newState: TransactionState, clientTransactionVersion: TransactionVersion, nextProducerId: Long, updateTimestamp: Long, noPartitionAdded: Boolean): TxnTransitMetadata = {
val (updatedProducerEpoch, updatedLastProducerEpoch) = if (clientTransactionVersion.supportsEpochBump()) {
// We already ensured that we do not overflow here. MAX_SHORT is the highest possible value.
((producerEpoch + 1).toShort, producerEpoch)
} else {
(producerEpoch, lastProducerEpoch)
}
// With transaction V2, it is allowed to abort the transaction without adding any partitions. Then, the transaction
// start time is uncertain but it is still required. So we can use the update time as the transaction start time.
val newTxnStartTimestamp = if (noPartitionAdded) updateTimestamp else txnStartTimestamp
prepareTransitionTo(
state = newState,
nextProducerId = nextProducerId,
producerEpoch = updatedProducerEpoch,
lastProducerEpoch = updatedLastProducerEpoch,
txnStartTimestamp = newTxnStartTimestamp,
txnLastUpdateTimestamp = updateTimestamp,
clientTransactionVersion = clientTransactionVersion
)
}
def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = {
val newState = if (state == TransactionState.PREPARE_COMMIT) TransactionState.COMPLETE_COMMIT else TransactionState.COMPLETE_ABORT
// Since the state change was successfully written to the log, unset the flag for a failed epoch fence
hasFailedEpochFence = false
val (updatedProducerId, updatedProducerEpoch) =
// In the prepareComplete transition for the overflow case, the lastProducerEpoch is kept at MAX-1,
// which is the last epoch visible to the client.
// Internally, however, during the transition between prepareAbort/prepareCommit and prepareComplete, the producer epoch
// reaches MAX but the client only sees the transition as MAX-1 followed by 0.
// When an epoch overflow occurs, we set the producerId to nextProducerId and reset the epoch to 0,
// but lastProducerEpoch remains MAX-1 to maintain consistency with what the client last saw.
if (clientTransactionVersion.supportsEpochBump() && nextProducerId != RecordBatch.NO_PRODUCER_ID) {
(nextProducerId, 0.toShort)
} else {
(producerId, producerEpoch)
}
prepareTransitionTo(
state = newState,
producerId = updatedProducerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = updatedProducerEpoch,
topicPartitions = mutable.Set.empty[TopicPartition],
txnLastUpdateTimestamp = updateTimestamp
)
}
def prepareDead(): TxnTransitMetadata = {
prepareTransitionTo(
state = TransactionState.DEAD,
topicPartitions = mutable.Set.empty[TopicPartition]
)
}
/**
* Check if the epochs have been exhausted for the current producerId. We do not allow the client to use an
* epoch equal to Short.MaxValue to ensure that the coordinator will always be able to fence an existing producer.
*/
def isProducerEpochExhausted: Boolean = TransactionMetadata.isEpochExhausted(producerEpoch)
/**
* Check if this is a distributed two phase commit transaction.
* Such transactions have no timeout (identified by maximum value for timeout).
*/
def isDistributedTwoPhaseCommitTxn: Boolean = txnTimeoutMs == Int.MaxValue
private def hasPendingTransaction: Boolean = {
state match {
case TransactionState.ONGOING | TransactionState.PREPARE_ABORT | TransactionState.PREPARE_COMMIT => true
case _ => false
}
}
private def prepareTransitionTo(state: TransactionState,
producerId: Long = this.producerId,
nextProducerId: Long = this.nextProducerId,
producerEpoch: Short = this.producerEpoch,
lastProducerEpoch: Short = this.lastProducerEpoch,
txnTimeoutMs: Int = this.txnTimeoutMs,
topicPartitions: mutable.Set[TopicPartition] = this.topicPartitions,
txnStartTimestamp: Long = this.txnStartTimestamp,
txnLastUpdateTimestamp: Long = this.txnLastUpdateTimestamp,
clientTransactionVersion: TransactionVersion = this.clientTransactionVersion): TxnTransitMetadata = {
if (pendingState.isDefined)
throw new IllegalStateException(s"Preparing transaction state transition to $state " +
s"while it already a pending state ${pendingState.get}")
if (producerId < 0)
throw new IllegalArgumentException(s"Illegal new producer id $producerId")
// The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata
// is created for the first time and it could stay like this until transitioning
// to Dead.
if (state != TransactionState.DEAD && producerEpoch < 0)
throw new IllegalArgumentException(s"Illegal new producer epoch $producerEpoch")
// check that the new state transition is valid and update the pending state if necessary
if (state.validPreviousStates.contains(this.state)) {
val transitMetadata = new TxnTransitMetadata(producerId, this.producerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state,
topicPartitions.asJava, txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion)
debug(s"TransactionalId ${this.transactionalId} prepare transition from ${this.state} to $transitMetadata")
pendingState = Some(state)
transitMetadata
} else {
throw new IllegalStateException(s"Preparing transaction state transition to $state failed since the target state" +
s" $state is not a valid previous state of the current state ${this.state}")
}
}
def completeTransitionTo(transitMetadata: TxnTransitMetadata): Unit = {
// metadata transition is valid only if all the following conditions are met:
//
// 1. the new state is already indicated in the pending state.
// 2. the epoch should be either the same value, the old value + 1, or 0 if we have a new producerId.
// 3. the last update time is no smaller than the old value.
// 4. the old partitions set is a subset of the new partitions set.
//
// plus, we should only try to update the metadata after the corresponding log entry has been successfully
// written and replicated (see TransactionStateManager#appendTransactionToLog)
//
// if valid, transition is done via overwriting the whole object to ensure synchronization
val toState = pendingState.getOrElse {
fatal(s"$this's transition to $transitMetadata failed since pendingState is not defined: this should not happen")
throw new IllegalStateException(s"TransactionalId $transactionalId " +
"completing transaction state transition while it does not have a pending state")
}
if (toState != transitMetadata.txnState) {
throwStateTransitionFailure(transitMetadata)
} else {
toState match {
case TransactionState.EMPTY => // from initPid
if ((producerEpoch != transitMetadata.producerEpoch && !validProducerEpochBump(transitMetadata)) ||
!transitMetadata.topicPartitions.isEmpty ||
transitMetadata.txnStartTimestamp != -1) {
throwStateTransitionFailure(transitMetadata)
}
case TransactionState.ONGOING => // from addPartitions
if (!validProducerEpoch(transitMetadata) ||
!topicPartitions.subsetOf(transitMetadata.topicPartitions.asScala) ||
txnTimeoutMs != transitMetadata.txnTimeoutMs) {
throwStateTransitionFailure(transitMetadata)
}
case TransactionState.PREPARE_ABORT | TransactionState.PREPARE_COMMIT => // from endTxn
// In V2, we allow state transits from Empty, CompleteCommit and CompleteAbort to PrepareAbort. It is possible
// their updated start time is not equal to the current start time.
val allowedEmptyAbort = toState == TransactionState.PREPARE_ABORT && transitMetadata.clientTransactionVersion.supportsEpochBump() &&
(state == TransactionState.EMPTY || state == TransactionState.COMPLETE_COMMIT || state == TransactionState.COMPLETE_ABORT)
val validTimestamp = txnStartTimestamp == transitMetadata.txnStartTimestamp || allowedEmptyAbort
if (!validProducerEpoch(transitMetadata) ||
!topicPartitions.equals(transitMetadata.topicPartitions.asScala) ||
txnTimeoutMs != transitMetadata.txnTimeoutMs || !validTimestamp) {
throwStateTransitionFailure(transitMetadata)
}
case TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT => // from write markers
if (!validProducerEpoch(transitMetadata) ||
txnTimeoutMs != transitMetadata.txnTimeoutMs ||
transitMetadata.txnStartTimestamp == -1) {
throwStateTransitionFailure(transitMetadata)
}
case TransactionState.PREPARE_EPOCH_FENCE =>
// We should never get here, since once we prepare to fence the epoch, we immediately set the pending state
// to PrepareAbort, and then consequently to CompleteAbort after the markers are written.. So we should never
// ever try to complete a transition to PrepareEpochFence, as it is not a valid previous state for any other state, and hence
// can never be transitioned out of.
throwStateTransitionFailure(transitMetadata)
case TransactionState.DEAD =>
// The transactionalId was being expired. The completion of the operation should result in removal of the
// the metadata from the cache, so we should never realistically transition to the dead state.
throw new IllegalStateException(s"TransactionalId $transactionalId is trying to complete a transition to " +
s"$toState. This means that the transactionalId was being expired, and the only acceptable completion of " +
s"this operation is to remove the transaction metadata from the cache, not to persist the $toState in the log.")
}
debug(s"TransactionalId $transactionalId complete transition from $state to $transitMetadata")
producerId = transitMetadata.producerId
prevProducerId = transitMetadata.prevProducerId
nextProducerId = transitMetadata.nextProducerId
producerEpoch = transitMetadata.producerEpoch
lastProducerEpoch = transitMetadata.lastProducerEpoch
txnTimeoutMs = transitMetadata.txnTimeoutMs
topicPartitions = transitMetadata.topicPartitions.asScala
txnStartTimestamp = transitMetadata.txnStartTimestamp
txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp
clientTransactionVersion = transitMetadata.clientTransactionVersion
pendingState = None
state = toState
}
}
/**
* Validates the producer epoch and ID based on transaction state and version.
*
* Logic:
* * 1. **Overflow Case in Transactions V2:**
* * - During overflow (epoch reset to 0), we compare both `lastProducerEpoch` values since it
* * does not change during completion.
* * - For PrepareComplete, the producer ID has been updated. We ensure that the `prevProducerID`
* * in the transit metadata matches the current producer ID, confirming the change.
* *
* * 2. **Epoch Bump Case in Transactions V2:**
* * - For PrepareCommit or PrepareAbort, the producer epoch has been bumped. We ensure the `lastProducerEpoch`
* * in transit metadata matches the current producer epoch, confirming the bump.
* * - We also verify that the producer ID remains the same.
* *
* * 3. **Other Cases:**
* * - For other states and versions, check if the producer epoch and ID match the current values.
*
* @param transitMetadata The transaction transition metadata containing state, producer epoch, and ID.
* @return true if the producer epoch and ID are valid; false otherwise.
*/
private def validProducerEpoch(transitMetadata: TxnTransitMetadata): Boolean = {
val isAtLeastTransactionsV2 = transitMetadata.clientTransactionVersion.supportsEpochBump()
val txnState = transitMetadata.txnState
val transitProducerEpoch = transitMetadata.producerEpoch
val transitProducerId = transitMetadata.producerId
val transitLastProducerEpoch = transitMetadata.lastProducerEpoch
(isAtLeastTransactionsV2, txnState, transitProducerEpoch) match {
case (true, TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT, epoch) if epoch == 0.toShort =>
transitLastProducerEpoch == lastProducerEpoch &&
transitMetadata.prevProducerId == producerId
case (true, TransactionState.PREPARE_COMMIT | TransactionState.PREPARE_ABORT, _) =>
transitLastProducerEpoch == producerEpoch &&
transitProducerId == producerId
case _ =>
transitProducerEpoch == producerEpoch &&
transitProducerId == producerId
}
}
private def validProducerEpochBump(transitMetadata: TxnTransitMetadata): Boolean = {
val transitEpoch = transitMetadata.producerEpoch
val transitProducerId = transitMetadata.producerId
transitEpoch == producerEpoch + 1 || (transitEpoch == 0 && transitProducerId != producerId)
}
private def throwStateTransitionFailure(txnTransitMetadata: TxnTransitMetadata): Unit = {
fatal(s"${this.toString}'s transition to $txnTransitMetadata failed: this should not happen")
throw new IllegalStateException(s"TransactionalId $transactionalId failed transition to state $txnTransitMetadata " +
"due to unexpected metadata")
}
def pendingTransitionInProgress: Boolean = pendingState.isDefined
override def toString: String = {
"TransactionMetadata(" +
s"transactionalId=$transactionalId, " +
s"producerId=$producerId, " +
s"prevProducerId=$prevProducerId, " +
s"nextProducerId=$nextProducerId, " +
s"producerEpoch=$producerEpoch, " +
s"lastProducerEpoch=$lastProducerEpoch, " +
s"txnTimeoutMs=$txnTimeoutMs, " +
s"state=$state, " +
s"pendingState=$pendingState, " +
s"topicPartitions=$topicPartitions, " +
s"txnStartTimestamp=$txnStartTimestamp, " +
s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp, " +
s"clientTransactionVersion=$clientTransactionVersion)"
}
override def equals(that: Any): Boolean = that match {
case other: TransactionMetadata =>
transactionalId == other.transactionalId &&
producerId == other.producerId &&
producerEpoch == other.producerEpoch &&
lastProducerEpoch == other.lastProducerEpoch &&
txnTimeoutMs == other.txnTimeoutMs &&
state.equals(other.state) &&
topicPartitions.equals(other.topicPartitions) &&
txnStartTimestamp == other.txnStartTimestamp &&
txnLastUpdateTimestamp == other.txnLastUpdateTimestamp &&
clientTransactionVersion == other.clientTransactionVersion
case _ => false
}
override def hashCode(): Int = {
val fields = Seq(transactionalId, producerId, producerEpoch, txnTimeoutMs, state, topicPartitions,
txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion)
fields.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
}
}

View File

@ -35,7 +35,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
import org.apache.kafka.common.requests.TransactionResult
import org.apache.kafka.common.utils.{Time, Utils}
import org.apache.kafka.common.{KafkaException, TopicIdPartition, TopicPartition}
import org.apache.kafka.coordinator.transaction.{TransactionLogConfig, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata}
import org.apache.kafka.coordinator.transaction.{TransactionLogConfig, TransactionMetadata, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata}
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.{RequestLocal, TransactionVersion}
import org.apache.kafka.server.config.ServerConfigs
@ -46,6 +46,8 @@ import org.apache.kafka.storage.internals.log.AppendOrigin
import com.google.re2j.{Pattern, PatternSyntaxException}
import org.apache.kafka.common.errors.InvalidRegularExpression
import java.util.Optional
import scala.jdk.CollectionConverters._
import scala.collection.mutable
@ -176,7 +178,7 @@ class TransactionStateManager(brokerId: Int,
val transactionalId = txnMetadata.transactionalId
var fullBatch = false
txnMetadata.inLock {
txnMetadata.inLock(() => {
if (txnMetadata.pendingState.isEmpty && shouldExpire(txnMetadata, currentTimeMs)) {
if (recordsBuilder == null) {
recordsBuilder = MemoryRecords.builder(
@ -199,7 +201,7 @@ class TransactionStateManager(brokerId: Int,
fullBatch = true
}
}
}
})
if (fullBatch) {
flushRecordsBuilder()
@ -263,9 +265,9 @@ class TransactionStateManager(brokerId: Int,
expiredForPartition.foreach { idCoordinatorEpochAndMetadata =>
val transactionalId = idCoordinatorEpochAndMetadata.transactionalId
val txnMetadata = txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId)
txnMetadata.inLock {
txnMetadata.inLock(() => {
if (txnMetadataCacheEntry.coordinatorEpoch == idCoordinatorEpochAndMetadata.coordinatorEpoch
&& txnMetadata.pendingState.contains(TransactionState.DEAD)
&& txnMetadata.pendingState.filter(s => s == TransactionState.DEAD).isPresent
&& txnMetadata.producerEpoch == idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch
&& response.error == Errors.NONE) {
txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId)
@ -276,9 +278,9 @@ class TransactionStateManager(brokerId: Int,
s" expected producerEpoch: ${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," +
s" coordinatorEpoch: ${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " +
s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}")
txnMetadata.pendingState = None
txnMetadata.pendingState(Optional.empty())
}
}
})
}
}
}
@ -366,7 +368,7 @@ class TransactionStateManager(brokerId: Int,
} else null
transactionMetadataCache.foreachEntry { (_, cache) =>
cache.metadataPerTransactionalId.forEach { (_, txnMetadata) =>
txnMetadata.inLock {
txnMetadata.inLock(() => {
if (shouldInclude(txnMetadata, pattern)) {
states.add(new ListTransactionsResponseData.TransactionState()
.setTransactionalId(txnMetadata.transactionalId)
@ -374,7 +376,7 @@ class TransactionStateManager(brokerId: Int,
.setTransactionState(txnMetadata.state.stateName)
)
}
}
})
}
}
response.setErrorCode(Errors.NONE.code)
@ -565,7 +567,7 @@ class TransactionStateManager(brokerId: Int,
val transactionsPendingForCompletion = new mutable.ListBuffer[TransactionalIdCoordinatorEpochAndTransitMetadata]
loadedTransactions.forEach((transactionalId, txnMetadata) => {
txnMetadata.inLock {
txnMetadata.inLock(() => {
// if state is PrepareCommit or PrepareAbort we need to complete the transaction
txnMetadata.state match {
case TransactionState.PREPARE_ABORT =>
@ -577,7 +579,7 @@ class TransactionStateManager(brokerId: Int,
case _ =>
// nothing needs to be done
}
}
})
})
// we first remove the partition from loading partition then send out the markers for those pending to be
@ -713,7 +715,7 @@ class TransactionStateManager(brokerId: Int,
case Right(Some(epochAndMetadata)) =>
val metadata = epochAndMetadata.transactionMetadata
metadata.inLock {
metadata.inLock(() => {
if (epochAndMetadata.coordinatorEpoch != coordinatorEpoch) {
// the cache may have been changed due to txn topic partition emigration and immigration,
// in this case directly return NOT_COORDINATOR to client and let it to re-discover the transaction coordinator
@ -725,7 +727,7 @@ class TransactionStateManager(brokerId: Int,
metadata.completeTransitionTo(newMetadata)
debug(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch for $transactionalId succeeded")
}
}
})
case Right(None) =>
// this transactional id no longer exists, maybe the corresponding partition has already been migrated out.
@ -740,7 +742,7 @@ class TransactionStateManager(brokerId: Int,
getTransactionState(transactionalId) match {
case Right(Some(epochAndTxnMetadata)) =>
val metadata = epochAndTxnMetadata.transactionMetadata
metadata.inLock {
metadata.inLock(() => {
if (epochAndTxnMetadata.coordinatorEpoch == coordinatorEpoch) {
if (retryOnError(responseError)) {
info(s"TransactionalId ${metadata.transactionalId} append transaction log for $newMetadata transition failed due to $responseError, " +
@ -749,13 +751,13 @@ class TransactionStateManager(brokerId: Int,
info(s"TransactionalId ${metadata.transactionalId} append transaction log for $newMetadata transition failed due to $responseError, " +
s"resetting pending state from ${metadata.pendingState}, aborting state transition and returning $responseError in the callback")
metadata.pendingState = None
metadata.pendingState(Optional.empty())
}
} else {
info(s"TransactionalId ${metadata.transactionalId} append transaction log for $newMetadata transition failed due to $responseError, " +
s"aborting state transition and returning the error in the callback since the coordinator epoch has changed from ${epochAndTxnMetadata.coordinatorEpoch} to $coordinatorEpoch")
}
}
})
case Right(None) =>
// Do nothing here, since we want to return the original append error to the user.
@ -790,7 +792,7 @@ class TransactionStateManager(brokerId: Int,
case Right(Some(epochAndMetadata)) =>
val metadata = epochAndMetadata.transactionMetadata
val append: Boolean = metadata.inLock {
val append: Boolean = metadata.inLock(() => {
if (epochAndMetadata.coordinatorEpoch != coordinatorEpoch) {
// the coordinator epoch has changed, reply to client immediately with NOT_COORDINATOR
responseCallback(Errors.NOT_COORDINATOR)
@ -800,7 +802,7 @@ class TransactionStateManager(brokerId: Int,
// under the same coordinator epoch, so directly append to txn log now
true
}
}
})
if (append) {
replicaManager.appendRecords(
timeout = newMetadata.txnTimeoutMs.toLong,

View File

@ -1970,7 +1970,7 @@ class KafkaApis(val requestChannel: RequestChannel,
} else {
val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]()
val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]()
val authorizedPartitions = mutable.Set[TopicPartition]()
val authorizedPartitions = new util.HashSet[TopicPartition]()
// Only request versions less than 4 need write authorization since they come from clients.
val authorizedTopics =
@ -1992,7 +1992,7 @@ class KafkaApis(val requestChannel: RequestChannel,
// partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded
// the authorization check to indicate that they were not added to the transaction.
val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++
authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED)
authorizedPartitions.asScala.map(_ -> Errors.OPERATION_NOT_ATTEMPTED)
addResultAndMaybeSendResponse(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitionErrors.asJava))
} else {
def sendResponseCallback(error: Errors): Unit = {
@ -2071,7 +2071,7 @@ class KafkaApis(val requestChannel: RequestChannel,
txnCoordinator.handleAddPartitionsToTransaction(transactionalId,
addOffsetsToTxnRequest.data.producerId,
addOffsetsToTxnRequest.data.producerEpoch,
Set(offsetTopicPartition),
util.Set.of(offsetTopicPartition),
sendResponseCallback,
TransactionVersion.TV_0, // This request will always come from the client not using TV 2.
requestLocal)

View File

@ -37,7 +37,7 @@ import org.apache.kafka.common.record.{FileRecords, MemoryRecords, RecordBatch,
import org.apache.kafka.common.requests._
import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch}
import org.apache.kafka.common.{Node, TopicPartition, Uuid}
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionState}
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionMetadata, TransactionState}
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion}
import org.apache.kafka.server.storage.log.FetchIsolation
@ -63,7 +63,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
private val allOperations = Seq(
new InitProducerIdOperation,
new AddPartitionsToTxnOperation(Set(new TopicPartition("topic", 0))),
new AddPartitionsToTxnOperation(util.Set.of(new TopicPartition("topic", 0))),
new EndTxnOperation)
private val allTransactions = mutable.Set[Transaction]()
@ -459,7 +459,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
val partitionId = txnStateManager.partitionFor(txn.transactionalId)
val txnRecords = txnRecordsByPartition(partitionId)
val initPidOp = new InitProducerIdOperation()
val addPartitionsOp = new AddPartitionsToTxnOperation(Set(new TopicPartition("topic", 0)))
val addPartitionsOp = new AddPartitionsToTxnOperation(util.Set.of(new TopicPartition("topic", 0)))
initPidOp.run(txn)
initPidOp.awaitAndVerify(txn)
addPartitionsOp.run(txn)
@ -468,7 +468,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn"))
txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2))
txnMetadata.state = TransactionState.PREPARE_COMMIT
txnMetadata.state(TransactionState.PREPARE_COMMIT)
txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2))
prepareTxnLog(partitionId)
@ -506,17 +506,18 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
}
private def prepareExhaustedEpochTxnMetadata(txn: Transaction): TransactionMetadata = {
new TransactionMetadata(transactionalId = txn.transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = (Short.MaxValue - 1).toShort,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 60000,
state = TransactionState.EMPTY,
topicPartitions = collection.mutable.Set.empty[TopicPartition],
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TransactionVersion.TV_0)
new TransactionMetadata(txn.transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
(Short.MaxValue - 1).toShort,
RecordBatch.NO_PRODUCER_EPOCH,
60000,
TransactionState.EMPTY,
new util.HashSet[TopicPartition](),
-1,
time.milliseconds(),
TransactionVersion.TV_0)
}
abstract class TxnOperation[R] extends Operation {
@ -548,7 +549,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
}
}
class AddPartitionsToTxnOperation(partitions: Set[TopicPartition]) extends TxnOperation[Errors] {
class AddPartitionsToTxnOperation(partitions: util.Set[TopicPartition]) extends TxnOperation[Errors] {
override def run(txn: Transaction): Unit = {
transactionMetadata(txn).foreach { txnMetadata =>
transactionCoordinator.handleAddPartitionsToTransaction(txn.transactionalId,
@ -629,7 +630,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
override def run(): Unit = {
transactions.foreach { txn =>
transactionMetadata(txn).foreach { txnMetadata =>
txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs
txnMetadata.txnLastUpdateTimestamp(time.milliseconds() - txnConfig.transactionalIdExpirationMs)
}
}
txnStateManager.enableTransactionalIdExpiration()

View File

@ -22,7 +22,7 @@ import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult}
import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch}
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata}
import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionMetadata, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata}
import org.apache.kafka.server.common.TransactionVersion
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.apache.kafka.server.util.MockScheduler
@ -33,9 +33,9 @@ import org.junit.jupiter.params.provider.{CsvSource, ValueSource}
import org.mockito.ArgumentMatchers.{any, anyBoolean, anyInt}
import org.mockito.Mockito._
import org.mockito.{ArgumentCaptor, ArgumentMatchers}
import org.mockito.Mockito.doAnswer
import scala.collection.mutable
import java.util
import scala.jdk.CollectionConverters._
class TransactionCoordinatorTest {
@ -57,7 +57,8 @@ class TransactionCoordinatorTest {
private val txnTimeoutMs = 1
private val producerId2 = 11L
private val partitions = mutable.Set[TopicPartition](new TopicPartition("topic1", 0))
private val partitions = new util.HashSet[TopicPartition]()
partitions.add(new TopicPartition("topic1", 0))
private val scheduler = new MockScheduler(time)
val coordinator = new TransactionCoordinator(
@ -198,7 +199,7 @@ class TransactionCoordinatorTest {
initPidGenericMocks(transactionalId)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.EMPTY, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0)
(Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.EMPTY, util.Set.of, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -230,10 +231,10 @@ class TransactionCoordinatorTest {
initPidGenericMocks(transactionalId)
val txnMetadata1 = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2)
(Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, util.Set.of, time.milliseconds(), time.milliseconds(), TV_2)
// We start with txnMetadata1 so we can transform the metadata to TransactionState.PREPARE_COMMIT.
val txnMetadata2 = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2)
(Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, util.Set.of, time.milliseconds(), time.milliseconds(), TV_2)
val transitMetadata = txnMetadata2.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, producerId2, time.milliseconds(), false)
txnMetadata2.completeTransitionTo(transitMetadata)
@ -376,8 +377,8 @@ class TransactionCoordinatorTest {
// Pending state does not matter, we will just check if the partitions are in the txnMetadata.
val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, mutable.Set.empty, 0, 0, TV_0)
ongoingTxnMetadata.pendingState = Some(TransactionState.COMPLETE_COMMIT)
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, util.Set.of, 0, 0, TV_0)
ongoingTxnMetadata.pendingState(util.Optional.of(TransactionState.COMPLETE_COMMIT))
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, ongoingTxnMetadata))))
@ -402,7 +403,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0, TV_2)))))
0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, util.Set.of, 0, 0, TV_2)))))
coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_2)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
@ -414,7 +415,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
10, 9, 0, TransactionState.PREPARE_COMMIT, mutable.Set.empty, 0, 0, TV_2)))))
10, 9, 0, TransactionState.PREPARE_COMMIT, util.Set.of, 0, 0, TV_2)))))
coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_2)
assertEquals(Errors.PRODUCER_FENCED, error)
@ -445,7 +446,7 @@ class TransactionCoordinatorTest {
def validateSuccessfulAddPartitions(previousState: TransactionState, transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, previousState, util.Set.of, time.milliseconds(), time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -506,7 +507,8 @@ class TransactionCoordinatorTest {
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.EMPTY, partitions, 0, 0, TV_0)))))
val extraPartitions = partitions ++ Set(new TopicPartition("topic2", 0))
val extraPartitions = new util.HashSet[TopicPartition](partitions)
extraPartitions.add(new TopicPartition("topic2", 0))
coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, extraPartitions, verifyPartitionsInTxnCallback)
assertEquals(Errors.TRANSACTION_ABORTABLE, errors(new TopicPartition("topic2", 0)))
@ -533,7 +535,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 10, 10, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, util.Set.of, 0, time.milliseconds(), TV_0)))))
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error)
@ -547,7 +549,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
(producerEpoch - 1).toShort, 1, TransactionState.ONGOING, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
(producerEpoch - 1).toShort, 1, TransactionState.ONGOING, util.Set.of, 0, time.milliseconds(), TV_0)))))
coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
@ -561,7 +563,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
(producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
(producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)))))
val epoch = if (isRetry) producerEpoch - 1 else producerEpoch
coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
@ -588,7 +590,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
(producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
(producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)))))
val epoch = if (isRetry) producerEpoch - 1 else producerEpoch
coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
@ -605,7 +607,7 @@ class TransactionCoordinatorTest {
def testEndTxnWhenStatusIsCompleteAbortAndResultIsAbortInV1(isRetry: Boolean): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -624,7 +626,7 @@ class TransactionCoordinatorTest {
def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbortInV2(isRetry: Boolean): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -661,7 +663,7 @@ class TransactionCoordinatorTest {
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -674,7 +676,7 @@ class TransactionCoordinatorTest {
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort,1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort,1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -688,7 +690,7 @@ class TransactionCoordinatorTest {
def testEndTxnRequestWhenStatusIsCompleteCommitAndResultIsAbortInV1(isRetry: Boolean): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -707,7 +709,7 @@ class TransactionCoordinatorTest {
def testEndTxnRequestWhenStatusIsCompleteCommitAndResultIsAbortInV2(isRetry: Boolean): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -738,7 +740,7 @@ class TransactionCoordinatorTest {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.PREPARE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
@ -751,7 +753,7 @@ class TransactionCoordinatorTest {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_ABORT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_TXN_STATE, error)
@ -763,7 +765,7 @@ class TransactionCoordinatorTest {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_TXN_STATE, error)
@ -776,7 +778,7 @@ class TransactionCoordinatorTest {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)))))
val epoch = if (isRetry) producerEpoch - 1 else producerEpoch
coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
@ -805,7 +807,7 @@ class TransactionCoordinatorTest {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)))))
val epoch = if (isRetry) producerEpoch - 1 else producerEpoch
coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
@ -821,7 +823,7 @@ class TransactionCoordinatorTest {
def shouldReturnInvalidTxnRequestOnEndTxnV2IfNotEndTxnV2Retry(): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_COMMIT, util.Set.of, 0, time.milliseconds(), TV_2)))))
// If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED.
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
@ -830,7 +832,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), TV_2)))))
// If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return INVALID_TXN_STATE.
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
@ -842,7 +844,7 @@ class TransactionCoordinatorTest {
def shouldReturnOkOnEndTxnV2IfEndTxnV2RetryEpochOverflow(): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, util.Set.of, 0, time.milliseconds(), TV_2)))))
// Return CONCURRENT_TRANSACTIONS while transaction is still completing
coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback)
@ -851,7 +853,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId2, producerId,
RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), TV_2)))))
coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.NONE, error)
@ -864,7 +866,7 @@ class TransactionCoordinatorTest {
@Test
def shouldReturnConcurrentTxnOnAddPartitionsIfEndTxnV2EpochOverflowAndNotComplete(): Unit = {
val prepareWithPending = new TransactionMetadata(transactionalId, producerId, producerId,
producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)
producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, util.Set.of, 0, time.milliseconds(), TV_2)
val txnTransitMetadata = prepareWithPending.prepareComplete(time.milliseconds())
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
@ -990,7 +992,7 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, metadataEpoch, 1,
1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
@ -1132,10 +1134,10 @@ class TransactionCoordinatorTest {
any())
).thenAnswer(_ => {
capturedErrorsCallback.getValue.apply(Errors.NOT_ENOUGH_REPLICAS)
txnMetadata.pendingState = None
txnMetadata.pendingState(util.Optional.empty())
}).thenAnswer(_ => {
capturedErrorsCallback.getValue.apply(Errors.NOT_ENOUGH_REPLICAS)
txnMetadata.pendingState = None
txnMetadata.pendingState(util.Optional.empty())
}).thenAnswer(_ => {
capturedErrorsCallback.getValue.apply(Errors.NONE)
@ -1226,7 +1228,7 @@ class TransactionCoordinatorTest {
RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs,
TransactionState.PREPARE_ABORT,
partitions.clone.asJava,
partitions,
time.milliseconds(),
time.milliseconds(),
TV_0)),
@ -1259,7 +1261,7 @@ class TransactionCoordinatorTest {
RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs,
TransactionState.PREPARE_ABORT,
partitions.clone.asJava,
partitions,
time.milliseconds(),
time.milliseconds(),
TV_0)),
@ -1334,18 +1336,18 @@ class TransactionCoordinatorTest {
// Create transaction metadata at the epoch boundary that would cause overflow IFF double-incremented
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = epochAtMaxBoundary,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = txnTimeoutMs,
state = TransactionState.ONGOING,
topicPartitions = partitions,
txnStartTimestamp = now,
txnLastUpdateTimestamp = now,
clientTransactionVersion = TV_2
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
epochAtMaxBoundary,
RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs,
TransactionState.ONGOING,
partitions,
now,
now,
TV_2
)
assertTrue(txnMetadata.isProducerEpochExhausted)
@ -1472,7 +1474,7 @@ class TransactionCoordinatorTest {
any())
).thenAnswer(_ => {
capturedErrorsCallback.getValue.apply(Errors.NONE)
txnMetadata.pendingState = None
txnMetadata.pendingState(util.Optional.empty())
})
// Re-initialization should succeed and bump the producer epoch
@ -1520,9 +1522,9 @@ class TransactionCoordinatorTest {
any())
).thenAnswer(_ => {
capturedErrorsCallback.getValue.apply(Errors.NONE)
txnMetadata.pendingState = None
txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch
txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch
txnMetadata.pendingState(util.Optional.empty())
txnMetadata.setProducerEpoch(capturedTxnTransitMetadata.getValue.producerEpoch)
txnMetadata.setLastProducerEpoch(capturedTxnTransitMetadata.getValue.lastProducerEpoch)
})
// With producer epoch at 10, new producer calls InitProducerId and should get epoch 11
@ -1571,11 +1573,11 @@ class TransactionCoordinatorTest {
any())
).thenAnswer(_ => {
capturedErrorsCallback.getValue.apply(Errors.NONE)
txnMetadata.pendingState = None
txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId
txnMetadata.prevProducerId = capturedTxnTransitMetadata.getValue.prevProducerId
txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch
txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch
txnMetadata.pendingState(util.Optional.empty())
txnMetadata.setProducerId(capturedTxnTransitMetadata.getValue.producerId)
txnMetadata.setPrevProducerId(capturedTxnTransitMetadata.getValue.prevProducerId)
txnMetadata.setProducerEpoch(capturedTxnTransitMetadata.getValue.producerEpoch)
txnMetadata.setLastProducerEpoch(capturedTxnTransitMetadata.getValue.lastProducerEpoch)
})
// Bump epoch and cause producer ID to be rotated
@ -1624,11 +1626,11 @@ class TransactionCoordinatorTest {
any())
).thenAnswer(_ => {
capturedErrorsCallback.getValue.apply(Errors.NONE)
txnMetadata.pendingState = None
txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId
txnMetadata.prevProducerId = capturedTxnTransitMetadata.getValue.prevProducerId
txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch
txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch
txnMetadata.pendingState(util.Optional.empty())
txnMetadata.setProducerId(capturedTxnTransitMetadata.getValue.producerId)
txnMetadata.setPrevProducerId(capturedTxnTransitMetadata.getValue.prevProducerId)
txnMetadata.setProducerEpoch(capturedTxnTransitMetadata.getValue.producerEpoch)
txnMetadata.setLastProducerEpoch(capturedTxnTransitMetadata.getValue.lastProducerEpoch)
})
// Bump epoch and cause producer ID to be rotated
@ -1674,7 +1676,7 @@ class TransactionCoordinatorTest {
// Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0.
val expectedTransition = new TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions.clone.asJava, now,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions, now,
now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0)
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
@ -1764,7 +1766,7 @@ class TransactionCoordinatorTest {
// Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0.
val bumpedEpoch = (producerEpoch + 1).toShort
val expectedTransition = new TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, bumpedEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions.clone.asJava, now,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions, now,
now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0)
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
@ -1832,7 +1834,7 @@ class TransactionCoordinatorTest {
coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.DEAD, mutable.Set.empty, time.milliseconds(),
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.DEAD, util.Set.of, time.milliseconds(),
time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -1872,9 +1874,11 @@ class TransactionCoordinatorTest {
assertEquals(txnTimeoutMs, result.transactionTimeoutMs)
assertEquals(time.milliseconds(), result.transactionStartTimeMs)
val addedPartitions = result.topics.asScala.flatMap { topicData =>
topicData.partitions.asScala.map(partition => new TopicPartition(topicData.topic, partition))
}.toSet
val addedPartitions = result.topics.stream.flatMap(topicData =>
topicData.partitions.stream
.map(partition => new TopicPartition(topicData.topic, partition))
)
.collect(util.stream.Collectors.toSet());
assertEquals(partitions, addedPartitions)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
@ -1886,7 +1890,7 @@ class TransactionCoordinatorTest {
// Since the clientTransactionVersion doesn't matter, use 2 since the states are TransactionState.PREPARE_COMMIT and TransactionState.PREPARE_ABORT.
val metadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_EPOCH,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0, TV_2)
0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, util.Set.of[TopicPartition](new TopicPartition("topic", 1)), 0, 0, TV_2)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
@ -1905,7 +1909,7 @@ class TransactionCoordinatorTest {
.thenReturn(true)
val metadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH,
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds(), clientTransactionVersion)
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, util.Set.of, time.milliseconds(), time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
@ -1939,7 +1943,7 @@ class TransactionCoordinatorTest {
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0)
val transition = new TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions.clone.asJava, now, now, clientTransactionVersion)
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions, now, now, clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata))))
@ -2007,7 +2011,7 @@ class TransactionCoordinatorTest {
// Simulate the real TransactionStateManager behavior: reset pendingState on failure
// since handleInitProducerId doesn't provide a custom retryOnError function
txnMetadata.pendingState = None
txnMetadata.pendingState(util.Optional.empty())
// For TV2, hasFailedEpochFence is NOT set to true, allowing epoch bumps on retry
// The epoch remains at its original value (1) since completeTransitionTo was never called
@ -2062,7 +2066,7 @@ class TransactionCoordinatorTest {
// Simulate the completion of transaction markers and the second write
// This would normally happen asynchronously after markers are sent
txnMetadata.completeTransitionTo(newMetadata) // This transitions to COMPLETE_ABORT
txnMetadata.pendingState = None
txnMetadata.pendingState(util.Optional.empty())
null
}).when(transactionMarkerChannelManager).addTxnMarkersToSend(

View File

@ -22,7 +22,7 @@ import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil}
import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection
import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type}
import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, SimpleRecord}
import org.apache.kafka.coordinator.transaction.{TransactionState, TxnTransitMetadata}
import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState, TxnTransitMetadata}
import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue}
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue, fail}
@ -38,7 +38,7 @@ class TransactionLogTest {
val producerEpoch: Short = 0
val transactionTimeoutMs: Int = 1000
val topicPartitions: Set[TopicPartition] = Set[TopicPartition](new TopicPartition("topic1", 0),
val topicPartitions = util.Set.of(new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1),
new TopicPartition("topic2", 0),
new TopicPartition("topic2", 1),
@ -50,7 +50,7 @@ class TransactionLogTest {
val producerId = 23423L
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, util.Set.of, 0, 0, TV_0)
txnMetadata.addPartitions(topicPartitions)
assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2))
@ -75,7 +75,7 @@ class TransactionLogTest {
// generate transaction log messages
val txnRecords = pidMappings.map { case (transactionalId, producerId) =>
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), util.Set.of, 0, 0, TV_0)
if (!txnMetadata.state.equals(TransactionState.EMPTY))
txnMetadata.addPartitions(topicPartitions)
@ -101,7 +101,7 @@ class TransactionLogTest {
assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state)
if (txnMetadata.state.equals(TransactionState.EMPTY))
assertEquals(Set.empty[TopicPartition], txnMetadata.topicPartitions)
assertEquals(util.Set.of, txnMetadata.topicPartitions)
else
assertEquals(topicPartitions, txnMetadata.topicPartitions)
@ -114,14 +114,14 @@ class TransactionLogTest {
@Test
def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = {
val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, util.Set.of, 500, 500, TV_0)
val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, new util.HashSet(), 500, 500, TV_0)
val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0))
assertEquals(0, txnLogValueBuffer.getShort)
}
@Test
def testSerializeTransactionLogValueToFlexibleVersion(): Unit = {
val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, util.Set.of, 500, 500, TV_2)
val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, new util.HashSet(), 500, 500, TV_2)
val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2))
assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort)
}
@ -229,7 +229,7 @@ class TransactionLogTest {
assertEquals(100, txnMetadata.producerEpoch)
assertEquals(1000L, txnMetadata.txnTimeoutMs)
assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata.state)
assertEquals(Set(new TopicPartition("topic", 1)), txnMetadata.topicPartitions)
assertEquals(util.Set.of(new TopicPartition("topic", 1)), txnMetadata.topicPartitions)
assertEquals(2000L, txnMetadata.txnLastUpdateTimestamp)
assertEquals(3000L, txnMetadata.txnStartTimestamp)
}

View File

@ -27,7 +27,7 @@ import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse}
import org.apache.kafka.common.utils.MockTime
import org.apache.kafka.common.{Node, TopicPartition}
import org.apache.kafka.coordinator.transaction.TransactionState
import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState}
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.{MetadataVersion, TransactionVersion}
import org.apache.kafka.server.metrics.{KafkaMetricsGroup, KafkaYammerMetrics}
@ -41,7 +41,6 @@ import org.mockito.{ArgumentCaptor, ArgumentMatchers}
import org.mockito.Mockito.{clearInvocations, mock, mockConstruction, times, verify, verifyNoMoreInteractions, when}
import scala.jdk.CollectionConverters._
import scala.collection.mutable
import scala.util.Try
class TransactionMarkerChannelManagerTest {
@ -67,9 +66,9 @@ class TransactionMarkerChannelManagerTest {
private val txnTimeoutMs = 0
private val txnResult = TransactionResult.COMMIT
private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, RecordBatch.NO_PRODUCER_ID,
producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L, TransactionVersion.TV_2)
producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, util.Set.of(partition1, partition2), 0L, 0L, TransactionVersion.TV_2)
private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, RecordBatch.NO_PRODUCER_ID,
producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](partition1), 0L, 0L, TransactionVersion.TV_2)
producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, util.Set.of(partition1), 0L, 0L, TransactionVersion.TV_2)
private val capturedErrorsCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit])
private val time = new MockTime
@ -145,33 +144,33 @@ class TransactionMarkerChannelManagerTest {
var addMarkerFuture: Future[Try[Unit]] = null
val executor = Executors.newFixedThreadPool(1)
txnMetadata2.lock.lock()
try {
addMarkerFuture = executor.submit((() => {
Try(channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult,
txnMetadata2.inLock(() => {
addMarkerFuture = executor.submit((() => {
Try(channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult,
txnMetadata2, expectedTransition))
}): Callable[Try[Unit]])
}): Callable[Try[Unit]])
val header = new RequestHeader(ApiKeys.WRITE_TXN_MARKERS, 0, "client", 1)
val response = new WriteTxnMarkersResponse(
util.Map.of(producerId2: java.lang.Long, util.Map.of(partition1, Errors.NONE)))
val clientResponse = new ClientResponse(header, null, null,
time.milliseconds(), time.milliseconds(), false, null, null,
response)
val header = new RequestHeader(ApiKeys.WRITE_TXN_MARKERS, 0, "client", 1)
val response = new WriteTxnMarkersResponse(
util.Map.of(producerId2: java.lang.Long, util.Map.of(partition1, Errors.NONE)))
val clientResponse = new ClientResponse(header, null, null,
time.milliseconds(), time.milliseconds(), false, null, null,
response)
TestUtils.waitUntilTrue(() => {
val requests = channelManager.generateRequests().asScala
if (requests.nonEmpty) {
assertEquals(1, requests.size)
val request = requests.head
request.handler.onComplete(clientResponse)
true
} else {
false
}
}, "Timed out waiting for expected WriteTxnMarkers request")
TestUtils.waitUntilTrue(() => {
val requests = channelManager.generateRequests().asScala
if (requests.nonEmpty) {
assertEquals(1, requests.size)
val request = requests.head
request.handler.onComplete(clientResponse)
true
} else {
false
}
}, "Timed out waiting for expected WriteTxnMarkers request")
})
} finally {
txnMetadata2.lock.unlock()
executor.shutdown()
}
@ -478,7 +477,7 @@ class TransactionMarkerChannelManagerTest {
assertEquals(0, channelManager.numTxnsWithPendingMarkers)
assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers)
assertEquals(None, txnMetadata2.pendingState)
assertEquals(Optional.empty(), txnMetadata2.pendingState)
assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata2.state)
}
@ -507,7 +506,7 @@ class TransactionMarkerChannelManagerTest {
any(),
any()))
.thenAnswer(_ => {
txnMetadata2.pendingState = None
txnMetadata2.pendingState(util.Optional.empty())
capturedErrorsCallback.getValue.apply(Errors.NOT_COORDINATOR)
})
@ -531,7 +530,7 @@ class TransactionMarkerChannelManagerTest {
assertEquals(0, channelManager.numTxnsWithPendingMarkers)
assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers)
assertEquals(None, txnMetadata2.pendingState)
assertEquals(Optional.empty(), txnMetadata2.pendingState)
assertEquals(TransactionState.PREPARE_COMMIT, txnMetadata2.state)
}
@ -592,7 +591,7 @@ class TransactionMarkerChannelManagerTest {
assertEquals(0, channelManager.numTxnsWithPendingMarkers)
assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers)
assertEquals(None, txnMetadata2.pendingState)
assertEquals(Optional.empty(), txnMetadata2.pendingState)
assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata2.state)
}
@ -632,11 +631,11 @@ class TransactionMarkerChannelManagerTest {
txnMetadata: TransactionMetadata
): Unit = {
if (isTransactionV2Enabled) {
txnMetadata.clientTransactionVersion = TransactionVersion.TV_2
txnMetadata.producerEpoch = (producerEpoch + 1).toShort
txnMetadata.lastProducerEpoch = producerEpoch
txnMetadata.clientTransactionVersion(TransactionVersion.TV_2)
txnMetadata.setProducerEpoch((producerEpoch + 1).toShort)
txnMetadata.setLastProducerEpoch(producerEpoch)
} else {
txnMetadata.clientTransactionVersion = TransactionVersion.TV_1
txnMetadata.clientTransactionVersion(TransactionVersion.TV_1)
}
}
}

View File

@ -22,15 +22,13 @@ import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.{ApiKeys, Errors}
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse}
import org.apache.kafka.coordinator.transaction.TransactionState
import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState}
import org.apache.kafka.server.common.TransactionVersion
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test
import org.mockito.ArgumentMatchers
import org.mockito.Mockito.{mock, verify, when}
import scala.collection.mutable
class TransactionMarkerRequestCompletionHandlerTest {
private val brokerId = 0
@ -44,7 +42,7 @@ class TransactionMarkerRequestCompletionHandlerTest {
private val txnResult = TransactionResult.COMMIT
private val topicPartition = new TopicPartition("topic1", 0)
private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](topicPartition), 0L, 0L, TransactionVersion.TV_2)
producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, util.Set.of(topicPartition), 0L, 0L, TransactionVersion.TV_2)
private val pendingCompleteTxnAndMarkers = util.List.of(
PendingCompleteTxnAndMarkerEntry(
PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, txnMetadata.prepareComplete(42)),
@ -194,7 +192,7 @@ class TransactionMarkerRequestCompletionHandlerTest {
handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1),
null, null, 0, 0, false, null, null, response))
assertEquals(txnMetadata.topicPartitions, mutable.Set[TopicPartition](topicPartition))
assertEquals(txnMetadata.topicPartitions, util.Set.of(topicPartition))
verify(markerChannelManager).addTxnMarkersToBrokerQueue(producerId,
producerEpoch, txnResult, pendingCompleteTxnAndMarkers.get(0).pendingCompleteTxn,
Set[TopicPartition](topicPartition))

View File

@ -19,7 +19,7 @@ package kafka.coordinator.transaction
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.coordinator.transaction.{TransactionState, TxnTransitMetadata}
import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState, TxnTransitMetadata}
import org.apache.kafka.server.common.TransactionVersion
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.apache.kafka.server.util.MockTime
@ -28,6 +28,7 @@ import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import java.util
import java.util.Optional
import scala.collection.mutable
@ -44,19 +45,20 @@ class TransactionMetadataTest {
val producerEpoch = RecordBatch.NO_PRODUCER_EPOCH
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.EMPTY,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.empty())
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(0, txnMetadata.producerEpoch)
@ -68,19 +70,20 @@ class TransactionMetadataTest {
val producerEpoch = 735.toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.EMPTY,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.empty())
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(producerEpoch + 1, txnMetadata.producerEpoch)
@ -92,21 +95,22 @@ class TransactionMetadataTest {
val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.EMPTY,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareIncrementProducerEpoch(30000,
None, time.milliseconds()))
Optional.empty, time.milliseconds()))
}
@Test
@ -114,20 +118,20 @@ class TransactionMetadataTest {
val producerEpoch = 735.toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = -1,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_2)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.EMPTY,
util.Set.of,
-1,
time.milliseconds(),
TV_2)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, noPartitionAdded = true)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(producerEpoch + 1, txnMetadata.producerEpoch)
@ -139,20 +143,20 @@ class TransactionMetadataTest {
val producerEpoch = 735.toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.COMPLETE_ABORT,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds() - 1,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_2)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.COMPLETE_ABORT,
util.Set.of,
time.milliseconds() - 1,
time.milliseconds(),
TV_2)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, noPartitionAdded = true)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(producerEpoch + 1, txnMetadata.producerEpoch)
@ -164,20 +168,20 @@ class TransactionMetadataTest {
val producerEpoch = 735.toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.COMPLETE_COMMIT,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds() - 1,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_2)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.COMPLETE_COMMIT,
util.Set.of,
time.milliseconds() - 1,
time.milliseconds(),
TV_2)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, noPartitionAdded = true)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(producerEpoch + 1, txnMetadata.producerEpoch)
@ -188,21 +192,21 @@ class TransactionMetadataTest {
def testTolerateUpdateTimeShiftDuringEpochBump(): Unit = {
val producerEpoch: Short = 1
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.EMPTY,
util.Set.of,
1L,
time.milliseconds(),
TV_0)
// let new time be smaller
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Option(producerEpoch),
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.of(producerEpoch),
Some(time.milliseconds() - 1))
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
@ -216,21 +220,21 @@ class TransactionMetadataTest {
def testTolerateUpdateTimeResetDuringProducerIdRotation(): Unit = {
val producerEpoch: Short = 1
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.EMPTY,
util.Set.of,
1L,
time.milliseconds(),
TV_0)
// let new time be smaller
val transitMetadata = txnMetadata.prepareProducerIdRotation(producerId + 1, 30000, time.milliseconds() - 1, recordLastEpoch = true)
val transitMetadata = txnMetadata.prepareProducerIdRotation(producerId + 1, 30000, time.milliseconds() - 1, true)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId + 1, txnMetadata.producerId)
assertEquals(producerEpoch, txnMetadata.lastProducerEpoch)
@ -243,23 +247,23 @@ class TransactionMetadataTest {
def testTolerateTimeShiftDuringAddPartitions(): Unit = {
val producerEpoch: Short = 1
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.EMPTY,
util.Set.of,
time.milliseconds(),
time.milliseconds(),
TV_0)
// let new time be smaller; when transiting from TransactionState.EMPTY the start time would be updated to the update-time
var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1, TV_0)
var transitMetadata = txnMetadata.prepareAddPartitions(util.Set.of(new TopicPartition("topic1", 0)), time.milliseconds() - 1, TV_0)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0)), txnMetadata.topicPartitions)
assertEquals(util.Set.of(new TopicPartition("topic1", 0)), txnMetadata.topicPartitions)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch)
assertEquals(producerEpoch, txnMetadata.producerEpoch)
@ -267,9 +271,9 @@ class TransactionMetadataTest {
assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp)
// add another partition, check that in TransactionState.ONGOING state the start timestamp would not change to update time
transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds() - 2, TV_0)
transitMetadata = txnMetadata.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds() - 2, TV_0)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)), txnMetadata.topicPartitions)
assertEquals(util.Set.of(new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)), txnMetadata.topicPartitions)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch)
assertEquals(producerEpoch, txnMetadata.producerEpoch)
@ -281,21 +285,21 @@ class TransactionMetadataTest {
def testTolerateTimeShiftDuringPrepareCommit(): Unit = {
val producerEpoch: Short = 1
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.ONGOING,
util.Set.of,
1L,
time.milliseconds(),
TV_0)
// let new time be smaller
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, noPartitionAdded = false)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(TransactionState.PREPARE_COMMIT, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId)
@ -309,21 +313,21 @@ class TransactionMetadataTest {
def testTolerateTimeShiftDuringPrepareAbort(): Unit = {
val producerEpoch: Short = 1
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.ONGOING,
util.Set.of,
1L,
time.milliseconds(),
TV_0)
// let new time be smaller
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, noPartitionAdded = false)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(TransactionState.PREPARE_ABORT, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId)
@ -340,18 +344,18 @@ class TransactionMetadataTest {
val producerEpoch: Short = 1
val lastProducerEpoch: Short = 0
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000,
state = TransactionState.PREPARE_COMMIT,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = clientTransactionVersion
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
lastProducerEpoch,
30000,
TransactionState.PREPARE_COMMIT,
util.Set.of(),
1L,
time.milliseconds(),
clientTransactionVersion
)
// let new time be smaller
@ -373,18 +377,18 @@ class TransactionMetadataTest {
val producerEpoch: Short = 1
val lastProducerEpoch: Short = 0
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000,
state = TransactionState.PREPARE_ABORT,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = clientTransactionVersion
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
lastProducerEpoch,
30000,
TransactionState.PREPARE_ABORT,
util.Set.of,
1L,
time.milliseconds(),
clientTransactionVersion
)
// let new time be smaller
@ -404,28 +408,29 @@ class TransactionMetadataTest {
val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.ONGOING,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch()
assertEquals(Short.MaxValue, fencingTransitMetadata.producerEpoch)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, fencingTransitMetadata.lastProducerEpoch)
assertEquals(Some(TransactionState.PREPARE_EPOCH_FENCE), txnMetadata.pendingState)
assertEquals(Optional.of(TransactionState.PREPARE_EPOCH_FENCE), txnMetadata.pendingState)
// We should reset the pending state to make way for the abort transition.
txnMetadata.pendingState = None
txnMetadata.pendingState(Optional.empty())
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), noPartitionAdded = false)
val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, transitMetadata.producerId)
}
@ -435,17 +440,18 @@ class TransactionMetadataTest {
val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.COMPLETE_COMMIT,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.COMPLETE_COMMIT,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch())
@ -456,17 +462,18 @@ class TransactionMetadataTest {
val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.COMPLETE_ABORT,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.COMPLETE_ABORT,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch())
@ -477,17 +484,18 @@ class TransactionMetadataTest {
val producerEpoch = Short.MaxValue
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.ONGOING,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch())
}
@ -497,20 +505,21 @@ class TransactionMetadataTest {
val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.EMPTY,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
val newProducerId = 9893L
val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = true)
val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), true)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(newProducerId, txnMetadata.producerId)
assertEquals(producerId, txnMetadata.prevProducerId)
@ -524,20 +533,20 @@ class TransactionMetadataTest {
val producerEpoch = 10.toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_2)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.ONGOING,
util.Set.of,
time.milliseconds(),
time.milliseconds(),
TV_2)
var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, noPartitionAdded = false)
var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch)
@ -556,22 +565,22 @@ class TransactionMetadataTest {
val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.ONGOING,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_2)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.ONGOING,
util.Set.of,
time.milliseconds(),
time.milliseconds(),
TV_2)
assertTrue(txnMetadata.isProducerEpochExhausted)
val newProducerId = 9893L
var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, newProducerId, time.milliseconds() - 1, noPartitionAdded = false)
var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, newProducerId, time.milliseconds() - 1, false)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(Short.MaxValue, txnMetadata.producerEpoch)
@ -610,19 +619,20 @@ class TransactionMetadataTest {
val producerEpoch = 735.toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.EMPTY,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch))
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.of(producerEpoch))
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(0, txnMetadata.producerEpoch)
@ -634,19 +644,20 @@ class TransactionMetadataTest {
val producerEpoch = 735.toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
TransactionState.EMPTY,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch))
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.of(producerEpoch))
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(producerEpoch + 1, txnMetadata.producerEpoch)
@ -659,19 +670,20 @@ class TransactionMetadataTest {
val lastProducerEpoch = (producerEpoch - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
lastProducerEpoch,
30000,
TransactionState.EMPTY,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(lastProducerEpoch))
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.of(lastProducerEpoch))
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(producerEpoch, txnMetadata.producerEpoch)
@ -684,21 +696,23 @@ class TransactionMetadataTest {
val lastProducerEpoch = (producerEpoch - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = producerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000,
state = TransactionState.EMPTY,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
transactionalId,
producerId,
producerId,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
lastProducerEpoch,
30000,
TransactionState.EMPTY,
util.Set.of,
-1,
time.milliseconds(),
TV_0)
val result = txnMetadata.prepareIncrementProducerEpoch(30000, Some((lastProducerEpoch - 1).toShort),
time.milliseconds())
assertEquals(Left(Errors.PRODUCER_FENCED), result)
assertThrows(Errors.PRODUCER_FENCED.exception().getClass, () =>
txnMetadata.prepareIncrementProducerEpoch(30000, Optional.of((lastProducerEpoch - 1).toShort),
time.milliseconds())
)
}
@Test
@ -748,27 +762,26 @@ class TransactionMetadataTest {
val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
prevProducerId = producerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = state,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = clientTransactionVersion)
transactionalId,
producerId,
producerId,
RecordBatch.NO_PRODUCER_ID,
producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH,
30000,
state,
util.Set.of,
-1,
time.milliseconds(),
clientTransactionVersion)
val newProducerId = 9893L
txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = false)
txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), false)
}
private def prepareSuccessfulIncrementProducerEpoch(txnMetadata: TransactionMetadata,
expectedProducerEpoch: Option[Short],
expectedProducerEpoch: Optional[java.lang.Short],
now: Option[Long] = None): TxnTransitMetadata = {
val result = txnMetadata.prepareIncrementProducerEpoch(30000, expectedProducerEpoch,
now.getOrElse(time.milliseconds()))
result.getOrElse(throw new AssertionError(s"prepareIncrementProducerEpoch failed with $result"))
txnMetadata.prepareIncrementProducerEpoch(30000, expectedProducerEpoch, now.getOrElse(time.milliseconds()))
}
}

View File

@ -18,7 +18,6 @@ package kafka.coordinator.transaction
import java.lang.management.ManagementFactory
import java.nio.ByteBuffer
import java.util
import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}
import javax.management.ObjectName
import kafka.server.ReplicaManager
@ -33,7 +32,7 @@ import org.apache.kafka.common.record._
import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
import org.apache.kafka.common.requests.TransactionResult
import org.apache.kafka.common.utils.MockTime
import org.apache.kafka.coordinator.transaction.{TransactionState, TxnTransitMetadata}
import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState, TxnTransitMetadata}
import org.apache.kafka.metadata.MetadataCache
import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion}
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
@ -50,6 +49,7 @@ import org.mockito.{ArgumentCaptor, ArgumentMatchers}
import org.mockito.ArgumentMatchers.{any, anyInt, anyLong, anyShort}
import org.mockito.Mockito.{atLeastOnce, mock, reset, times, verify, when}
import java.util
import scala.collection.{Map, mutable}
import scala.jdk.CollectionConverters._
@ -181,8 +181,8 @@ class TransactionStateManagerTest {
).thenReturn(new FetchDataInfo(new LogOffsetMetadata(startOffset), fileRecordsMock))
when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some(endOffset))
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.addPartitions(Set[TopicPartition](
txnMetadata1.state(TransactionState.PREPARE_COMMIT)
txnMetadata1.addPartitions(util.Set.of(
new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
val records = MemoryRecords.withRecords(startOffset, Compression.NONE,
@ -240,8 +240,8 @@ class TransactionStateManagerTest {
).thenReturn(new FetchDataInfo(new LogOffsetMetadata(startOffset), fileRecordsMock))
when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some(endOffset))
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.addPartitions(Set[TopicPartition](
txnMetadata1.state(TransactionState.PREPARE_COMMIT)
txnMetadata1.addPartitions(util.Set.of(
new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
val records = MemoryRecords.withRecords(startOffset, Compression.NONE,
@ -285,44 +285,44 @@ class TransactionStateManagerTest {
// generate transaction log messages for two pids traces:
// pid1's transaction started with two partitions
txnMetadata1.state = TransactionState.ONGOING
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
txnMetadata1.state(TransactionState.ONGOING)
txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid1's transaction adds three more partitions
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0),
txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic2", 0),
new TopicPartition("topic2", 1),
new TopicPartition("topic2", 2)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid1's transaction is preparing to commit
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.state(TransactionState.PREPARE_COMMIT)
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid2's transaction started with three partitions
txnMetadata2.state = TransactionState.ONGOING
txnMetadata2.addPartitions(Set[TopicPartition](new TopicPartition("topic3", 0),
txnMetadata2.state(TransactionState.ONGOING)
txnMetadata2.addPartitions(util.Set.of(new TopicPartition("topic3", 0),
new TopicPartition("topic3", 1),
new TopicPartition("topic3", 2)))
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's transaction is preparing to abort
txnMetadata2.state = TransactionState.PREPARE_ABORT
txnMetadata2.state(TransactionState.PREPARE_ABORT)
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's transaction has aborted
txnMetadata2.state = TransactionState.COMPLETE_ABORT
txnMetadata2.state(TransactionState.COMPLETE_ABORT)
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's epoch has advanced, with no ongoing transaction yet
txnMetadata2.state = TransactionState.EMPTY
txnMetadata2.state(TransactionState.EMPTY)
txnMetadata2.topicPartitions.clear()
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
@ -391,7 +391,7 @@ class TransactionStateManagerTest {
expectedError = Errors.NONE
// update the metadata to ongoing with two partitions
val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
val newMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)), time.milliseconds(), TV_0)
// append the new metadata into log
@ -407,7 +407,7 @@ class TransactionStateManagerTest {
transactionManager.putTransactionStateIfNotExists(txnMetadata1)
expectedError = Errors.COORDINATOR_NOT_AVAILABLE
var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
var failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION)
val requestLocal = RequestLocal.withThreadConfinedCaching
@ -415,19 +415,19 @@ class TransactionStateManagerTest {
assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
assertTrue(txnMetadata1.pendingState.isEmpty)
failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS)
transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal)
assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
assertTrue(txnMetadata1.pendingState.isEmpty)
failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND)
transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal)
assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
assertTrue(txnMetadata1.pendingState.isEmpty)
failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
prepareForTxnMessageAppend(Errors.REQUEST_TIMED_OUT)
transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal)
assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
@ -440,7 +440,7 @@ class TransactionStateManagerTest {
transactionManager.putTransactionStateIfNotExists(txnMetadata1)
expectedError = Errors.NOT_COORDINATOR
var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
var failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
prepareForTxnMessageAppend(Errors.NOT_LEADER_OR_FOLLOWER)
val requestLocal = RequestLocal.withThreadConfinedCaching
@ -448,7 +448,7 @@ class TransactionStateManagerTest {
assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
assertTrue(txnMetadata1.pendingState.isEmpty)
failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
prepareForTxnMessageAppend(Errors.NONE)
transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch)
transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal)
@ -471,7 +471,7 @@ class TransactionStateManagerTest {
transactionManager.putTransactionStateIfNotExists(txnMetadata1)
expectedError = Errors.COORDINATOR_LOAD_IN_PROGRESS
val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
val failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
prepareForTxnMessageAppend(Errors.NONE)
transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch)
@ -485,7 +485,7 @@ class TransactionStateManagerTest {
transactionManager.putTransactionStateIfNotExists(txnMetadata1)
expectedError = Errors.UNKNOWN_SERVER_ERROR
var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
var failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
prepareForTxnMessageAppend(Errors.MESSAGE_TOO_LARGE)
val requestLocal = RequestLocal.withThreadConfinedCaching
@ -493,7 +493,7 @@ class TransactionStateManagerTest {
assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
assertTrue(txnMetadata1.pendingState.isEmpty)
failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
prepareForTxnMessageAppend(Errors.RECORD_LIST_TOO_LARGE)
transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal)
assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
@ -506,12 +506,12 @@ class TransactionStateManagerTest {
transactionManager.putTransactionStateIfNotExists(txnMetadata1)
expectedError = Errors.COORDINATOR_NOT_AVAILABLE
val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
val failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0)
prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION)
transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, _ => true, RequestLocal.withThreadConfinedCaching)
assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1))
assertEquals(Some(TransactionState.ONGOING), txnMetadata1.pendingState)
assertEquals(util.Optional.of(TransactionState.ONGOING), txnMetadata1.pendingState)
}
@Test
@ -524,11 +524,11 @@ class TransactionStateManagerTest {
prepareForTxnMessageAppend(Errors.NONE)
expectedError = Errors.NOT_COORDINATOR
val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
val newMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)), time.milliseconds(), TV_0)
// modify the cache while trying to append the new metadata
txnMetadata1.producerEpoch = (txnMetadata1.producerEpoch + 1).toShort
txnMetadata1.setProducerEpoch((txnMetadata1.producerEpoch + 1).toShort)
// append the new metadata into log
transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, newMetadata, assertCallback, requestLocal = RequestLocal.withThreadConfinedCaching)
@ -543,11 +543,11 @@ class TransactionStateManagerTest {
prepareForTxnMessageAppend(Errors.NONE)
expectedError = Errors.INVALID_PRODUCER_EPOCH
val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
val newMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)), time.milliseconds(), TV_0)
// modify the cache while trying to append the new metadata
txnMetadata1.pendingState = None
txnMetadata1.pendingState(util.Optional.empty())
// append the new metadata into log
assertThrows(classOf[IllegalStateException], () => transactionManager.appendTransactionToLog(transactionalId1,
@ -876,7 +876,7 @@ class TransactionStateManagerTest {
// will be expired and it should succeed.
val timestamp = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, 1, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, util.Set.of, timestamp, timestamp, TV_0)
transactionManager.putTransactionStateIfNotExists(txnMetadata)
time.sleep(txnConfig.transactionalIdExpirationMs + 1)
@ -934,7 +934,7 @@ class TransactionStateManagerTest {
val txnlId = s"id_$i"
val producerId = i
val txnMetadata = transactionMetadata(txnlId, producerId)
txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs
txnMetadata.txnLastUpdateTimestamp(time.milliseconds() - txnConfig.transactionalIdExpirationMs)
transactionManager.putTransactionStateIfNotExists(txnMetadata)
allTransactionalIds += txnlId
}
@ -962,8 +962,8 @@ class TransactionStateManagerTest {
@Test
def testSuccessfulReimmigration(): Unit = {
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
txnMetadata1.state(TransactionState.PREPARE_COMMIT)
txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
@ -1029,10 +1029,10 @@ class TransactionStateManagerTest {
@Test
def testLoadTransactionMetadataContainingSegmentEndingWithEmptyBatch(): Unit = {
// Simulate a case where a log contains two segments and the first segment ending with an empty batch.
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)))
txnMetadata2.state = TransactionState.ONGOING
txnMetadata2.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)))
txnMetadata1.state(TransactionState.PREPARE_COMMIT)
txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic1", 0)))
txnMetadata2.state(TransactionState.ONGOING)
txnMetadata2.addPartitions(util.Set.of(new TopicPartition("topic2", 0)))
// Create the first segment which contains two batches.
// The first batch has one transactional record
@ -1158,11 +1158,11 @@ class TransactionStateManagerTest {
loadTransactionsForPartitions(partitionIds)
expectLogConfig(partitionIds, ServerLogConfigs.MAX_MESSAGE_BYTES_DEFAULT)
txnMetadata1.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs
txnMetadata1.state = txnState
txnMetadata1.txnLastUpdateTimestamp(time.milliseconds() - txnConfig.transactionalIdExpirationMs)
txnMetadata1.state(txnState)
transactionManager.putTransactionStateIfNotExists(txnMetadata1)
txnMetadata2.txnLastUpdateTimestamp = time.milliseconds()
txnMetadata2.txnLastUpdateTimestamp(time.milliseconds())
transactionManager.putTransactionStateIfNotExists(txnMetadata2)
val appendedRecords = mutable.Map.empty[TopicIdPartition, mutable.Buffer[MemoryRecords]]
@ -1188,8 +1188,8 @@ class TransactionStateManagerTest {
}
private def verifyWritesTxnMarkersInPrepareState(state: TransactionState): Unit = {
txnMetadata1.state = state
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
txnMetadata1.state(state)
txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
@ -1222,7 +1222,7 @@ class TransactionStateManagerTest {
txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = {
val timestamp = time.milliseconds()
new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, 0.toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeout, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0)
RecordBatch.NO_PRODUCER_EPOCH, txnTimeout, state, util.Set.of, timestamp, timestamp, TV_0)
}
private def prepareTxnLog(topicPartition: TopicPartition,
@ -1294,8 +1294,8 @@ class TransactionStateManagerTest {
assertEquals(Double.NaN, partitionLoadTime("partition-load-time-avg"), 0)
assertTrue(reporter.containsMbean(mBeanName))
txnMetadata1.state = TransactionState.ONGOING
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 1),
txnMetadata1.state(TransactionState.ONGOING)
txnMetadata1.addPartitions(util.List.of(new TopicPartition("topic1", 1),
new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
@ -1313,8 +1313,8 @@ class TransactionStateManagerTest {
@Test
def testIgnoreUnknownRecordType(): Unit = {
txnMetadata1.state = TransactionState.PREPARE_COMMIT
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
txnMetadata1.state(TransactionState.PREPARE_COMMIT)
txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))

View File

@ -1892,7 +1892,7 @@ class KafkaApisTest extends Logging {
ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(producerId),
ArgumentMatchers.eq(epoch),
ArgumentMatchers.eq(Set(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, partition))),
ArgumentMatchers.eq(util.Set.of(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, partition))),
responseCallback.capture(),
ArgumentMatchers.eq(TransactionVersion.TV_0),
ArgumentMatchers.eq(requestLocal)
@ -1951,7 +1951,7 @@ class KafkaApisTest extends Logging {
ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(producerId),
ArgumentMatchers.eq(epoch),
ArgumentMatchers.eq(Set(topicPartition)),
ArgumentMatchers.eq(util.Set.of(topicPartition)),
responseCallback.capture(),
ArgumentMatchers.eq(TransactionVersion.TV_0),
ArgumentMatchers.eq(requestLocal)
@ -2157,7 +2157,7 @@ class KafkaApisTest extends Logging {
ArgumentMatchers.eq(transactionalId1),
ArgumentMatchers.eq(producerId),
ArgumentMatchers.eq(epoch),
ArgumentMatchers.eq(Set(tp0)),
ArgumentMatchers.eq(util.Set.of(tp0)),
responseCallback.capture(),
any[TransactionVersion],
ArgumentMatchers.eq(requestLocal)
@ -2167,7 +2167,7 @@ class KafkaApisTest extends Logging {
ArgumentMatchers.eq(transactionalId2),
ArgumentMatchers.eq(producerId),
ArgumentMatchers.eq(epoch),
ArgumentMatchers.eq(Set(tp1)),
ArgumentMatchers.eq(util.Set.of(tp1)),
verifyPartitionsCallback.capture(),
)).thenAnswer(_ => verifyPartitionsCallback.getValue.apply(AddPartitionsToTxnResponse.resultForTransaction(transactionalId2, util.Map.of(tp1, Errors.PRODUCER_FENCED))))
kafkaApis = createKafkaApis()

View File

@ -573,7 +573,7 @@ For a detailed description of spotbugs bug categories, see https://spotbugs.read
<Class name="org.apache.kafka.streams.state.internals.ThreadCache"/>
<Class name="org.apache.kafka.connect.runtime.WorkerSinkTask"/>
<Class name="org.apache.kafka.tools.VerifiableProducer"/>
<Class name="kafka.coordinator.transaction.TransactionMetadata"/>
<Class name="org.apache.kafka.coordinator.transaction.TransactionMetadata"/>
<Class name="org.apache.kafka.tools.VerifiableShareConsumer"/>
<Class name="org.apache.kafka.server.quota.ClientQuotaManager"/>
<Class name="kafka.log.LogManager"/>
@ -606,7 +606,7 @@ For a detailed description of spotbugs bug categories, see https://spotbugs.read
<Class name="org.apache.kafka.streams.state.internals.InMemoryTimeOrderedKeyValueChangeBuffer"/>
<Class name="org.apache.kafka.connect.runtime.WorkerSinkTask"/>
<Class name="org.apache.kafka.tools.ConsumerPerformance$ConsumerPerfRebListener"/>
<Class name="kafka.coordinator.transaction.TransactionMetadata"/>
<Class name="org.apache.kafka.coordinator.transaction.TransactionMetadata"/>
<Class name="kafka.server.BrokerLifecycleManager"/>
<Class name="kafka.server.CachedPartition"/>
<Class name="kafka.server.ControllerRegistrationManager"/>
@ -670,7 +670,7 @@ For a detailed description of spotbugs bug categories, see https://spotbugs.read
<Class name="org.apache.kafka.connect.runtime.distributed.WorkerGroupMember"/>
<Class name="org.apache.kafka.connect.util.KafkaBasedLog"/>
<Class name="org.apache.kafka.tools.VerifiableProducer"/>
<Class name="kafka.coordinator.transaction.TransactionMetadata"/>
<Class name="org.apache.kafka.coordinator.transaction.TransactionMetadata"/>
<Class name="kafka.network.Acceptor"/>
<Class name="kafka.network.Processor"/>
<Class name="kafka.server.BrokerLifecycleManager"/>

View File

@ -0,0 +1,662 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.coordinator.transaction;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.config.LogLevelConfig;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.server.common.TransactionVersion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MarkerFactory;
import java.util.Collection;
import java.util.HashSet;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Supplier;
public class TransactionMetadata {
private static final Logger LOGGER = LoggerFactory.getLogger(TransactionMetadata.class);
private final String transactionalId;
private long producerId;
private long prevProducerId;
private long nextProducerId;
private short producerEpoch;
private short lastProducerEpoch;
private int txnTimeoutMs;
private TransactionState state;
// The topicPartitions is mutable, so using HashSet, instead of Set.
private HashSet<TopicPartition> topicPartitions;
private volatile long txnStartTimestamp;
private volatile long txnLastUpdateTimestamp;
private TransactionVersion clientTransactionVersion;
// pending state is used to indicate the state that this transaction is going to
// transit to, and for blocking future attempts to transit it again if it is not legal;
// initialized as the same as the current state
private Optional<TransactionState> pendingState;
// Indicates that during a previous attempt to fence a producer, the bumped epoch may not have been
// successfully written to the log. If this is true, we will not bump the epoch again when fencing
private boolean hasFailedEpochFence;
private final ReentrantLock lock;
public static boolean isEpochExhausted(short producerEpoch) {
return producerEpoch >= Short.MAX_VALUE - 1;
}
/**
* @param transactionalId transactional id
* @param producerId producer id
* @param prevProducerId producer id for the last committed transaction with this transactional ID
* @param nextProducerId Latest producer ID sent to the producer for the given transactional ID
* @param producerEpoch current epoch of the producer
* @param lastProducerEpoch last epoch of the producer
* @param txnTimeoutMs timeout to be used to abort long running transactions
* @param state current state of the transaction
* @param topicPartitions current set of partitions that are part of this transaction
* @param txnStartTimestamp time the transaction was started, i.e., when first partition is added
* @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration
* @param clientTransactionVersion TransactionVersion used by the client when the state was transitioned
*/
public TransactionMetadata(String transactionalId,
long producerId,
long prevProducerId,
long nextProducerId,
short producerEpoch,
short lastProducerEpoch,
int txnTimeoutMs,
TransactionState state,
Set<TopicPartition> topicPartitions,
long txnStartTimestamp,
long txnLastUpdateTimestamp,
TransactionVersion clientTransactionVersion) {
this.transactionalId = transactionalId;
this.producerId = producerId;
this.prevProducerId = prevProducerId;
this.nextProducerId = nextProducerId;
this.producerEpoch = producerEpoch;
this.lastProducerEpoch = lastProducerEpoch;
this.txnTimeoutMs = txnTimeoutMs;
this.state = state;
this.topicPartitions = new HashSet<>(topicPartitions);
this.txnStartTimestamp = txnStartTimestamp;
this.txnLastUpdateTimestamp = txnLastUpdateTimestamp;
this.clientTransactionVersion = clientTransactionVersion;
this.pendingState = Optional.empty();
this.hasFailedEpochFence = false;
this.lock = new ReentrantLock();
}
public <T> T inLock(Supplier<T> function) {
lock.lock();
try {
return function.get();
} finally {
lock.unlock();
}
}
public void addPartitions(Collection<TopicPartition> partitions) {
topicPartitions.addAll(partitions);
}
public void removePartition(TopicPartition topicPartition) {
if (state != TransactionState.PREPARE_COMMIT && state != TransactionState.PREPARE_ABORT)
throw new IllegalStateException("Transaction metadata's current state is " + state + ", and its pending state is " +
pendingState + " while trying to remove partitions whose txn marker has been sent, this is not expected");
topicPartitions.remove(topicPartition);
}
// this is visible for test only
public TxnTransitMetadata prepareNoTransit() {
// do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state
return new TxnTransitMetadata(producerId, prevProducerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs,
state, new HashSet<>(topicPartitions), txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion);
}
public TxnTransitMetadata prepareFenceProducerEpoch() {
if (producerEpoch == Short.MAX_VALUE)
throw new IllegalStateException("Cannot fence producer with epoch equal to Short.MaxValue since this would overflow");
// If we've already failed to fence an epoch (because the write to the log failed), we don't increase it again.
// This is safe because we never return the epoch to client if we fail to fence the epoch
short bumpedEpoch = hasFailedEpochFence ? producerEpoch : (short) (producerEpoch + 1);
TransitionData data = new TransitionData(TransactionState.PREPARE_EPOCH_FENCE);
data.producerEpoch = bumpedEpoch;
return prepareTransitionTo(data);
}
public TxnTransitMetadata prepareIncrementProducerEpoch(
int newTxnTimeoutMs,
Optional<Short> expectedProducerEpoch,
long updateTimestamp) {
if (isProducerEpochExhausted())
throw new IllegalStateException("Cannot allocate any more producer epochs for producerId " + producerId);
TransitionData data = new TransitionData(TransactionState.EMPTY);
short bumpedEpoch = (short) (producerEpoch + 1);
if (expectedProducerEpoch.isEmpty()) {
// If no expected epoch was provided by the producer, bump the current epoch and set the last epoch to -1
// In the case of a new producer, producerEpoch will be -1 and bumpedEpoch will be 0
data.producerEpoch = bumpedEpoch;
data.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH;
} else if (producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || expectedProducerEpoch.get() == producerEpoch) {
// If the expected epoch matches the current epoch, or if there is no current epoch, the producer is attempting
// to continue after an error and no other producer has been initialized. Bump the current and last epochs.
// The no current epoch case means this is a new producer; producerEpoch will be -1 and bumpedEpoch will be 0
data.producerEpoch = bumpedEpoch;
data.lastProducerEpoch = producerEpoch;
} else if (expectedProducerEpoch.get() == lastProducerEpoch) {
// If the expected epoch matches the previous epoch, it is a retry of a successful call, so just return the
// current epoch without bumping. There is no danger of this producer being fenced, because a new producer
// calling InitProducerId would have caused the last epoch to be set to -1.
// Note that if the IBP is prior to 2.4.IV1, the lastProducerId and lastProducerEpoch will not be written to
// the transaction log, so a retry that spans a coordinator change will fail. We expect this to be a rare case.
data.producerEpoch = producerEpoch;
data.lastProducerEpoch = lastProducerEpoch;
} else {
// Otherwise, the producer has a fenced epoch and should receive an PRODUCER_FENCED error
LOGGER.info("Expected producer epoch {} does not match current producer epoch {} or previous producer epoch {}",
expectedProducerEpoch.get(), producerEpoch, lastProducerEpoch);
throw Errors.PRODUCER_FENCED.exception();
}
data.txnTimeoutMs = newTxnTimeoutMs;
data.topicPartitions = new HashSet<>();
data.txnStartTimestamp = -1L;
data.txnLastUpdateTimestamp = updateTimestamp;
return prepareTransitionTo(data);
}
public TxnTransitMetadata prepareProducerIdRotation(long newProducerId,
int newTxnTimeoutMs,
long updateTimestamp,
boolean recordLastEpoch) {
if (hasPendingTransaction())
throw new IllegalStateException("Cannot rotate producer ids while a transaction is still pending");
TransitionData data = new TransitionData(TransactionState.EMPTY);
data.producerId = newProducerId;
data.producerEpoch = 0;
data.lastProducerEpoch = recordLastEpoch ? producerEpoch : RecordBatch.NO_PRODUCER_EPOCH;
data.txnTimeoutMs = newTxnTimeoutMs;
data.topicPartitions = new HashSet<>();
data.txnStartTimestamp = -1L;
data.txnLastUpdateTimestamp = updateTimestamp;
return prepareTransitionTo(data);
}
public TxnTransitMetadata prepareAddPartitions(Set<TopicPartition> addedTopicPartitions,
long updateTimestamp,
TransactionVersion clientTransactionVersion) {
long newTxnStartTimestamp;
if (state == TransactionState.EMPTY || state == TransactionState.COMPLETE_ABORT || state == TransactionState.COMPLETE_COMMIT) {
newTxnStartTimestamp = updateTimestamp;
} else {
newTxnStartTimestamp = txnStartTimestamp;
}
HashSet<TopicPartition> newTopicPartitions = new HashSet<>(topicPartitions);
newTopicPartitions.addAll(addedTopicPartitions);
TransitionData data = new TransitionData(TransactionState.ONGOING);
data.topicPartitions = newTopicPartitions;
data.txnStartTimestamp = newTxnStartTimestamp;
data.txnLastUpdateTimestamp = updateTimestamp;
data.clientTransactionVersion = clientTransactionVersion;
return prepareTransitionTo(data);
}
public TxnTransitMetadata prepareAbortOrCommit(TransactionState newState,
TransactionVersion clientTransactionVersion,
long nextProducerId,
long updateTimestamp,
boolean noPartitionAdded) {
TransitionData data = new TransitionData(newState);
if (clientTransactionVersion.supportsEpochBump()) {
// We already ensured that we do not overflow here. MAX_SHORT is the highest possible value.
data.producerEpoch = (short) (producerEpoch + 1);
data.lastProducerEpoch = producerEpoch;
} else {
data.producerEpoch = producerEpoch;
data.lastProducerEpoch = lastProducerEpoch;
}
// With transaction V2, it is allowed to abort the transaction without adding any partitions. Then, the transaction
// start time is uncertain but it is still required. So we can use the update time as the transaction start time.
data.txnStartTimestamp = noPartitionAdded ? updateTimestamp : txnStartTimestamp;
data.nextProducerId = nextProducerId;
data.txnLastUpdateTimestamp = updateTimestamp;
data.clientTransactionVersion = clientTransactionVersion;
return prepareTransitionTo(data);
}
public TxnTransitMetadata prepareComplete(long updateTimestamp) {
// Since the state change was successfully written to the log, unset the flag for a failed epoch fence
hasFailedEpochFence = false;
TransitionData data = new TransitionData(state == TransactionState.PREPARE_COMMIT ?
TransactionState.COMPLETE_COMMIT : TransactionState.COMPLETE_ABORT);
// In the prepareComplete transition for the overflow case, the lastProducerEpoch is kept at MAX-1,
// which is the last epoch visible to the client.
// Internally, however, during the transition between prepareAbort/prepareCommit and prepareComplete, the producer epoch
// reaches MAX but the client only sees the transition as MAX-1 followed by 0.
// When an epoch overflow occurs, we set the producerId to nextProducerId and reset the epoch to 0,
// but lastProducerEpoch remains MAX-1 to maintain consistency with what the client last saw.
if (clientTransactionVersion.supportsEpochBump() && nextProducerId != RecordBatch.NO_PRODUCER_ID) {
data.producerId = nextProducerId;
data.producerEpoch = 0;
} else {
data.producerId = producerId;
data.producerEpoch = producerEpoch;
}
data.nextProducerId = RecordBatch.NO_PRODUCER_ID;
data.topicPartitions = new HashSet<>();
data.txnLastUpdateTimestamp = updateTimestamp;
return prepareTransitionTo(data);
}
public TxnTransitMetadata prepareDead() {
TransitionData data = new TransitionData(TransactionState.DEAD);
data.topicPartitions = new HashSet<>();
return prepareTransitionTo(data);
}
/**
* Check if the epochs have been exhausted for the current producerId. We do not allow the client to use an
* epoch equal to Short.MaxValue to ensure that the coordinator will always be able to fence an existing producer.
*/
public boolean isProducerEpochExhausted() {
return isEpochExhausted(producerEpoch);
}
/**
* Check if this is a distributed two phase commit transaction.
* Such transactions have no timeout (identified by maximum value for timeout).
*/
public boolean isDistributedTwoPhaseCommitTxn() {
return txnTimeoutMs == Integer.MAX_VALUE;
}
private boolean hasPendingTransaction() {
return state == TransactionState.ONGOING ||
state == TransactionState.PREPARE_ABORT ||
state == TransactionState.PREPARE_COMMIT;
}
private TxnTransitMetadata prepareTransitionTo(TransitionData data) {
if (pendingState.isPresent())
throw new IllegalStateException("Preparing transaction state transition to " + state +
" while it already a pending state " + pendingState.get());
if (data.producerId < 0)
throw new IllegalArgumentException("Illegal new producer id " + data.producerId);
// The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata
// is created for the first time and it could stay like this until transitioning
// to Dead.
if (data.state != TransactionState.DEAD && data.producerEpoch < 0)
throw new IllegalArgumentException("Illegal new producer epoch " + data.producerEpoch);
// check that the new state transition is valid and update the pending state if necessary
if (data.state.validPreviousStates().contains(this.state)) {
TxnTransitMetadata transitMetadata = new TxnTransitMetadata(
data.producerId, this.producerId, data.nextProducerId, data.producerEpoch, data.lastProducerEpoch,
data.txnTimeoutMs, data.state, data.topicPartitions,
data.txnStartTimestamp, data.txnLastUpdateTimestamp, data.clientTransactionVersion
);
LOGGER.debug("TransactionalId {} prepare transition from {} to {}", transactionalId, this.state, data.state);
pendingState = Optional.of(data.state);
return transitMetadata;
}
throw new IllegalStateException("Preparing transaction state transition to " + data.state + " failed since the target state " +
data.state + " is not a valid previous state of the current state " + this.state);
}
@SuppressWarnings("CyclomaticComplexity")
public void completeTransitionTo(TxnTransitMetadata transitMetadata) {
// metadata transition is valid only if all the following conditions are met:
//
// 1. the new state is already indicated in the pending state.
// 2. the epoch should be either the same value, the old value + 1, or 0 if we have a new producerId.
// 3. the last update time is no smaller than the old value.
// 4. the old partitions set is a subset of the new partitions set.
//
// plus, we should only try to update the metadata after the corresponding log entry has been successfully
// written and replicated (see TransactionStateManager#appendTransactionToLog)
//
// if valid, transition is done via overwriting the whole object to ensure synchronization
TransactionState toState = pendingState.orElseThrow(() -> {
LOGGER.error(MarkerFactory.getMarker(LogLevelConfig.FATAL_LOG_LEVEL),
"{}'s transition to {} failed since pendingState is not defined: this should not happen", this, transitMetadata);
return new IllegalStateException("TransactionalId " + transactionalId +
" completing transaction state transition while it does not have a pending state");
});
if (!toState.equals(transitMetadata.txnState())) throwStateTransitionFailure(transitMetadata);
switch (toState) {
case EMPTY: // from initPid
if ((producerEpoch != transitMetadata.producerEpoch() && !validProducerEpochBump(transitMetadata)) ||
!transitMetadata.topicPartitions().isEmpty() ||
transitMetadata.txnStartTimestamp() != -1) {
throwStateTransitionFailure(transitMetadata);
}
break;
case ONGOING: // from addPartitions
if (!validProducerEpoch(transitMetadata) ||
!transitMetadata.topicPartitions().containsAll(topicPartitions) ||
txnTimeoutMs != transitMetadata.txnTimeoutMs()) {
throwStateTransitionFailure(transitMetadata);
}
break;
case PREPARE_ABORT: // from endTxn
case PREPARE_COMMIT:
// In V2, we allow state transits from Empty, CompleteCommit and CompleteAbort to PrepareAbort. It is possible
// their updated start time is not equal to the current start time.
boolean allowedEmptyAbort = toState == TransactionState.PREPARE_ABORT && transitMetadata.clientTransactionVersion().supportsEpochBump() &&
(state == TransactionState.EMPTY || state == TransactionState.COMPLETE_COMMIT || state == TransactionState.COMPLETE_ABORT);
boolean validTimestamp = txnStartTimestamp == transitMetadata.txnStartTimestamp() || allowedEmptyAbort;
if (!validProducerEpoch(transitMetadata) ||
!topicPartitions.equals(transitMetadata.topicPartitions()) ||
txnTimeoutMs != transitMetadata.txnTimeoutMs() ||
!validTimestamp) {
throwStateTransitionFailure(transitMetadata);
}
break;
case COMPLETE_ABORT: // from write markers
case COMPLETE_COMMIT:
if (!validProducerEpoch(transitMetadata) ||
txnTimeoutMs != transitMetadata.txnTimeoutMs() ||
transitMetadata.txnStartTimestamp() == -1) {
throwStateTransitionFailure(transitMetadata);
}
break;
case PREPARE_EPOCH_FENCE:
// We should never get here, since once we prepare to fence the epoch, we immediately set the pending state
// to PrepareAbort, and then consequently to CompleteAbort after the markers are written.. So we should never
// ever try to complete a transition to PrepareEpochFence, as it is not a valid previous state for any other state, and hence
// can never be transitioned out of.
throwStateTransitionFailure(transitMetadata);
break;
case DEAD:
// The transactionalId was being expired. The completion of the operation should result in removal of the
// the metadata from the cache, so we should never realistically transition to the dead state.
throw new IllegalStateException("TransactionalId " + transactionalId + " is trying to complete a transition to " +
toState + ". This means that the transactionalId was being expired, and the only acceptable completion of " +
"this operation is to remove the transaction metadata from the cache, not to persist the " + toState + " in the log.");
default:
break;
}
LOGGER.debug("TransactionalId {} complete transition from {} to {}", transactionalId, state, transitMetadata);
producerId = transitMetadata.producerId();
prevProducerId = transitMetadata.prevProducerId();
nextProducerId = transitMetadata.nextProducerId();
producerEpoch = transitMetadata.producerEpoch();
lastProducerEpoch = transitMetadata.lastProducerEpoch();
txnTimeoutMs = transitMetadata.txnTimeoutMs();
topicPartitions = transitMetadata.topicPartitions();
txnStartTimestamp = transitMetadata.txnStartTimestamp();
txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp();
clientTransactionVersion = transitMetadata.clientTransactionVersion();
pendingState = Optional.empty();
state = toState;
}
/**
* Validates the producer epoch and ID based on transaction state and version.
* <p>
* Logic:
* * 1. **Overflow Case in Transactions V2:**
* * - During overflow (epoch reset to 0), we compare both `lastProducerEpoch` values since it
* * does not change during completion.
* * - For PrepareComplete, the producer ID has been updated. We ensure that the `prevProducerID`
* * in the transit metadata matches the current producer ID, confirming the change.
* *
* * 2. **Epoch Bump Case in Transactions V2:**
* * - For PrepareCommit or PrepareAbort, the producer epoch has been bumped. We ensure the `lastProducerEpoch`
* * in transit metadata matches the current producer epoch, confirming the bump.
* * - We also verify that the producer ID remains the same.
* *
* * 3. **Other Cases:**
* * - For other states and versions, check if the producer epoch and ID match the current values.
*
* @param transitMetadata The transaction transition metadata containing state, producer epoch, and ID.
* @return true if the producer epoch and ID are valid; false otherwise.
*/
private boolean validProducerEpoch(TxnTransitMetadata transitMetadata) {
boolean isAtLeastTransactionsV2 = transitMetadata.clientTransactionVersion().supportsEpochBump();
TransactionState txnState = transitMetadata.txnState();
short transitProducerEpoch = transitMetadata.producerEpoch();
long transitProducerId = transitMetadata.producerId();
short transitLastProducerEpoch = transitMetadata.lastProducerEpoch();
if (isAtLeastTransactionsV2 &&
(txnState == TransactionState.COMPLETE_COMMIT || txnState == TransactionState.COMPLETE_ABORT) &&
transitProducerEpoch == 0) {
return transitLastProducerEpoch == lastProducerEpoch && transitMetadata.prevProducerId() == producerId;
}
if (isAtLeastTransactionsV2 &&
(txnState == TransactionState.PREPARE_COMMIT || txnState == TransactionState.PREPARE_ABORT)) {
return transitLastProducerEpoch == producerEpoch && transitProducerId == producerId;
}
return transitProducerEpoch == producerEpoch && transitProducerId == producerId;
}
private boolean validProducerEpochBump(TxnTransitMetadata transitMetadata) {
short transitEpoch = transitMetadata.producerEpoch();
long transitProducerId = transitMetadata.producerId();
return transitEpoch == (short) (producerEpoch + 1) || (transitEpoch == 0 && transitProducerId != producerId);
}
private void throwStateTransitionFailure(TxnTransitMetadata txnTransitMetadata) {
LOGGER.error(MarkerFactory.getMarker(LogLevelConfig.FATAL_LOG_LEVEL),
"{}'s transition to {} failed: this should not happen", this, txnTransitMetadata);
throw new IllegalStateException("TransactionalId " + transactionalId + " failed transition to state " + txnTransitMetadata +
" due to unexpected metadata");
}
public boolean pendingTransitionInProgress() {
return pendingState.isPresent();
}
public String transactionalId() {
return transactionalId;
}
public void setProducerId(long producerId) {
this.producerId = producerId;
}
public long producerId() {
return producerId;
}
public void setPrevProducerId(long prevProducerId) {
this.prevProducerId = prevProducerId;
}
public long prevProducerId() {
return prevProducerId;
}
public void setProducerEpoch(short producerEpoch) {
this.producerEpoch = producerEpoch;
}
public short producerEpoch() {
return producerEpoch;
}
public void setLastProducerEpoch(short lastProducerEpoch) {
this.lastProducerEpoch = lastProducerEpoch;
}
public short lastProducerEpoch() {
return lastProducerEpoch;
}
public int txnTimeoutMs() {
return txnTimeoutMs;
}
public void state(TransactionState state) {
this.state = state;
}
public TransactionState state() {
return state;
}
public Set<TopicPartition> topicPartitions() {
return topicPartitions;
}
public long txnStartTimestamp() {
return txnStartTimestamp;
}
public void txnLastUpdateTimestamp(long txnLastUpdateTimestamp) {
this.txnLastUpdateTimestamp = txnLastUpdateTimestamp;
}
public long txnLastUpdateTimestamp() {
return txnLastUpdateTimestamp;
}
public void clientTransactionVersion(TransactionVersion clientTransactionVersion) {
this.clientTransactionVersion = clientTransactionVersion;
}
public TransactionVersion clientTransactionVersion() {
return clientTransactionVersion;
}
public void pendingState(Optional<TransactionState> pendingState) {
this.pendingState = pendingState;
}
public Optional<TransactionState> pendingState() {
return pendingState;
}
public void hasFailedEpochFence(boolean hasFailedEpochFence) {
this.hasFailedEpochFence = hasFailedEpochFence;
}
public boolean hasFailedEpochFence() {
return hasFailedEpochFence;
}
@Override
public String toString() {
return "TransactionMetadata(" +
"transactionalId=" + transactionalId +
", producerId=" + producerId +
", prevProducerId=" + prevProducerId +
", nextProducerId=" + nextProducerId +
", producerEpoch=" + producerEpoch +
", lastProducerEpoch=" + lastProducerEpoch +
", txnTimeoutMs=" + txnTimeoutMs +
", state=" + state +
", pendingState=" + pendingState +
", topicPartitions=" + topicPartitions +
", txnStartTimestamp=" + txnStartTimestamp +
", txnLastUpdateTimestamp=" + txnLastUpdateTimestamp +
", clientTransactionVersion=" + clientTransactionVersion +
")";
}
@Override
public boolean equals(Object obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
TransactionMetadata other = (TransactionMetadata) obj;
return transactionalId.equals(other.transactionalId) &&
producerId == other.producerId &&
prevProducerId == other.prevProducerId &&
nextProducerId == other.nextProducerId &&
producerEpoch == other.producerEpoch &&
lastProducerEpoch == other.lastProducerEpoch &&
txnTimeoutMs == other.txnTimeoutMs &&
state.equals(other.state) &&
topicPartitions.equals(other.topicPartitions) &&
txnStartTimestamp == other.txnStartTimestamp &&
txnLastUpdateTimestamp == other.txnLastUpdateTimestamp &&
clientTransactionVersion.equals(other.clientTransactionVersion);
}
@Override
public int hashCode() {
return Objects.hash(
transactionalId,
producerId,
prevProducerId,
nextProducerId,
producerEpoch,
lastProducerEpoch,
txnTimeoutMs,
state,
topicPartitions,
txnStartTimestamp,
txnLastUpdateTimestamp,
clientTransactionVersion
);
}
/**
* This class is used to hold the data that is needed to transition the transaction metadata to a new state.
* The data is copied from the current transaction metadata to avoid a lot of duplicated code in the prepare methods.
*/
private class TransitionData {
final TransactionState state;
long producerId = TransactionMetadata.this.producerId;
long nextProducerId = TransactionMetadata.this.nextProducerId;
short producerEpoch = TransactionMetadata.this.producerEpoch;
short lastProducerEpoch = TransactionMetadata.this.lastProducerEpoch;
int txnTimeoutMs = TransactionMetadata.this.txnTimeoutMs;
HashSet<TopicPartition> topicPartitions = TransactionMetadata.this.topicPartitions;
long txnStartTimestamp = TransactionMetadata.this.txnStartTimestamp;
long txnLastUpdateTimestamp = TransactionMetadata.this.txnLastUpdateTimestamp;
TransactionVersion clientTransactionVersion = TransactionMetadata.this.clientTransactionVersion;
private TransitionData(TransactionState state) {
this.state = state;
}
}
}

View File

@ -19,7 +19,7 @@ package org.apache.kafka.coordinator.transaction;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.server.common.TransactionVersion;
import java.util.Set;
import java.util.HashSet;
/**
* Immutable object representing the target transition of the transaction metadata
@ -32,7 +32,9 @@ public record TxnTransitMetadata(
short lastProducerEpoch,
int txnTimeoutMs,
TransactionState txnState,
Set<TopicPartition> topicPartitions,
// The TransactionMetadata#topicPartitions field is mutable.
// To avoid deep copy when assigning value from TxnTransitMetadata to TransactionMetadata, use HashSet here.
HashSet<TopicPartition> topicPartitions,
long txnStartTimestamp,
long txnLastUpdateTimestamp,
TransactionVersion clientTransactionVersion