KAFKA-14562 [1/2]: Implement epoch bump after every transaction (#16719)

Implement server side changes for epoch bump but keep EndTxn as an unstable API until the client side changes are implemented. EndTxnResponse will return the producer ID and epoch for the transaction. Introduces new tagged fields to the TransactionLogValue to persist the clientTransactionVersion, previousProducerId, and nextProducerId to the log so that the state can be reloaded. See KIP-890 for more details.

Small updates to naming of lastProducerId -> PreviousProducerId. Also cleans up the many TransactionMetadata constructors.

Reviewers: Artem Livshits <alivshits@confluent.io>, David Jacot <djacot@confluent.io>
This commit is contained in:
Justine Olshan 2024-09-26 09:37:11 -07:00 committed by GitHub
parent 3a1465e14c
commit ede0c94aaa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 850 additions and 380 deletions

View File

@ -32,7 +32,11 @@ public class EndTxnRequest extends AbstractRequest {
public final EndTxnRequestData data;
public Builder(EndTxnRequestData data) {
super(ApiKeys.END_TXN);
this(data, false);
}
public Builder(EndTxnRequestData data, boolean enableUnstableLastVersion) {
super(ApiKeys.END_TXN, enableUnstableLastVersion);
this.data = data;
}

View File

@ -25,7 +25,10 @@
// Version 3 enables flexible versions.
//
// Version 4 adds support for new error code TRANSACTION_ABORTABLE (KIP-890).
"validVersions": "0-4",
//
// Version 5 enables bumping epoch on every transaction (KIP-890 Part 2)
"latestVersionUnstable": true,
"validVersions": "0-5",
"flexibleVersions": "3+",
"fields": [
{ "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId",

View File

@ -24,12 +24,18 @@
// Version 3 enables flexible versions.
//
// Version 4 adds support for new error code TRANSACTION_ABORTABLE (KIP-890).
"validVersions": "0-4",
//
// Version 5 enables bumping epoch on every transaction (KIP-890 Part 2), so producer ID and epoch are included in the response.
"validVersions": "0-5",
"flexibleVersions": "3+",
"fields": [
{ "name": "ThrottleTimeMs", "type": "int32", "versions": "0+",
"about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." },
{ "name": "ErrorCode", "type": "int16", "versions": "0+",
"about": "The error code, or 0 if there was no error." }
"about": "The error code, or 0 if there was no error." },
{ "name": "ProducerId", "type": "int64", "versions": "5+", "entityType": "producerId", "default": "-1", "ignorable": "true",
"about": "The producer ID." },
{ "name": "ProducerEpoch", "type": "int16", "versions": "5+", "default": "-1", "ignorable": "true",
"about": "The current epoch associated with the producer." }
]
}

View File

@ -29,7 +29,7 @@ import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult}
import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time}
import org.apache.kafka.server.common.RequestLocal
import org.apache.kafka.server.common.{RequestLocal, TransactionVersion}
import org.apache.kafka.server.util.Scheduler
import scala.jdk.CollectionConverters._
@ -98,7 +98,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
private type InitProducerIdCallback = InitProducerIdResult => Unit
private type AddPartitionsCallback = Errors => Unit
private type VerifyPartitionsCallback = AddPartitionsToTxnResult => Unit
private type EndTxnCallback = Errors => Unit
private type EndTxnCallback = (Errors, Long, Short) => Unit
private type ApiResult[T] = Either[Errors, T]
/* Active flag of the coordinator */
@ -135,13 +135,15 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case Success(producerId) =>
val createdMetadata = new TransactionMetadata(transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = transactionTimeoutMs,
state = Empty,
topicPartitions = collection.mutable.Set.empty[TopicPartition],
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TransactionVersion.TV_0)
txnManager.putTransactionStateIfNotExists(createdMetadata)
case Failure(exception) =>
@ -169,7 +171,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
case Right((coordinatorEpoch, newMetadata)) =>
if (newMetadata.txnState == PrepareEpochFence) {
// abort the ongoing transaction and then return CONCURRENT_TRANSACTIONS to let client wait and retry
def sendRetriableErrorCallback(error: Errors): Unit = {
def sendRetriableErrorCallback(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = {
if (error != Errors.NONE) {
responseCallback(initTransactionError(error))
} else {
@ -182,6 +184,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
newMetadata.producerEpoch,
TransactionResult.ABORT,
isFromClient = false,
clientTransactionVersion = txnManager.transactionVersionLevel(), // Since this is not from client, use server TV
sendRetriableErrorCallback,
requestLocal)
} else {
@ -221,7 +224,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
// could be a retry after a valid epoch bump that the producer never received the response for
txnMetadata.producerEpoch == RecordBatch.NO_PRODUCER_EPOCH ||
producerIdAndEpoch.producerId == txnMetadata.producerId ||
(producerIdAndEpoch.producerId == txnMetadata.lastProducerId && TransactionMetadata.isEpochExhausted(producerIdAndEpoch.epoch))
(producerIdAndEpoch.producerId == txnMetadata.previousProducerId && TransactionMetadata.isEpochExhausted(producerIdAndEpoch.epoch))
}
if (txnMetadata.pendingTransitionInProgress) {
@ -487,6 +490,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
producerId: Long,
producerEpoch: Short,
txnMarkerResult: TransactionResult,
clientTransactionVersion: TransactionVersion,
responseCallback: EndTxnCallback,
requestLocal: RequestLocal = RequestLocal.noCaching): Unit = {
endTransaction(transactionalId,
@ -494,6 +498,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
producerEpoch,
txnMarkerResult,
isFromClient = true,
clientTransactionVersion,
responseCallback,
requestLocal)
}
@ -503,12 +508,15 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
producerEpoch: Short,
txnMarkerResult: TransactionResult,
isFromClient: Boolean,
clientTransactionVersion: TransactionVersion,
responseCallback: EndTxnCallback,
requestLocal: RequestLocal): Unit = {
var isEpochFence = false
if (transactionalId == null || transactionalId.isEmpty)
responseCallback(Errors.INVALID_REQUEST)
responseCallback(Errors.INVALID_REQUEST, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH)
else {
var producerIdCopy = RecordBatch.NO_PRODUCER_ID
var producerEpochCopy = RecordBatch.NO_PRODUCER_EPOCH
val preAppendResult: ApiResult[(Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId).flatMap {
case None =>
Left(Errors.INVALID_PRODUCER_ID_MAPPING)
@ -518,10 +526,39 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch
txnMetadata.inLock {
if (txnMetadata.producerId != producerId)
producerIdCopy = txnMetadata.producerId
producerEpochCopy = txnMetadata.producerEpoch
// PrepareEpochFence has slightly different epoch bumping logic so don't include it here.
val currentTxnMetadataIsAtLeastTransactionsV2 = !txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.clientTransactionVersion.supportsEpochBump()
// True if the client used TV_2 and retried a request that had overflowed the epoch, and a new producer ID is stored in the txnMetadata
val retryOnOverflow = currentTxnMetadataIsAtLeastTransactionsV2 &&
txnMetadata.previousProducerId == producerId && producerEpoch == Short.MaxValue - 1 && txnMetadata.producerEpoch == 0
// True if the client used TV_2 and retried an endTxn request, and the bumped producer epoch is stored in the txnMetadata.
val retryOnEpochBump = endTxnEpochBumped(txnMetadata, producerEpoch)
val isValidEpoch = {
if (currentTxnMetadataIsAtLeastTransactionsV2) {
// With transactions V2, state + same epoch is not sufficient to determine if a retry transition is valid. If the epoch is the
// same it actually indicates the next endTransaction call. Instead, we want to check the epoch matches with the epoch in the retry conditions.
// Return producer fenced even in the cases where the epoch is higher and could indicate an invalid state transition.
// Use the following criteria to determine if a v2 retry is valid:
txnMetadata.state match {
case Ongoing | Empty | Dead | PrepareEpochFence =>
producerEpoch == txnMetadata.producerEpoch
case PrepareCommit | PrepareAbort =>
retryOnEpochBump
case CompleteCommit | CompleteAbort =>
retryOnEpochBump || retryOnOverflow
}
} else {
// For transactions V1 strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch without server knowledge.
(!isFromClient || producerEpoch == txnMetadata.producerEpoch) && producerEpoch >= txnMetadata.producerEpoch
}
}
if (txnMetadata.producerId != producerId && !retryOnOverflow)
Left(Errors.INVALID_PRODUCER_ID_MAPPING)
// Strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch.
else if ((isFromClient && producerEpoch != txnMetadata.producerEpoch) || producerEpoch < txnMetadata.producerEpoch)
else if (!isValidEpoch)
Left(Errors.PRODUCER_FENCED)
else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence)
Left(Errors.CONCURRENT_TRANSACTIONS)
@ -532,6 +569,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
else
PrepareAbort
// Maybe allocate new producer ID if we are bumping epoch and epoch is exhausted
val nextProducerIdOrErrors =
if (clientTransactionVersion.supportsEpochBump() && !txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.isProducerEpochExhausted) {
producerIdManager.generateProducerId() match {
case Success(newProducerId) =>
Right(newProducerId)
case Failure(exception) =>
Left(Errors.forException(exception))
}
} else {
Right(RecordBatch.NO_PRODUCER_ID)
}
if (nextState == PrepareAbort && txnMetadata.pendingState.contains(PrepareEpochFence)) {
// We should clear the pending state to make way for the transition to PrepareAbort and also bump
// the epoch in the transaction metadata we are about to append.
@ -541,7 +591,10 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
txnMetadata.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH
}
Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, time.milliseconds()))
nextProducerIdOrErrors.flatMap {
nextProducerId =>
Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, clientTransactionVersion, nextProducerId, time.milliseconds()))
}
case CompleteCommit =>
if (txnMarkerResult == TransactionResult.COMMIT)
Left(Errors.NONE)
@ -576,8 +629,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
preAppendResult match {
case Left(err) =>
if (err == Errors.NONE) {
responseCallback(err, producerIdCopy, producerEpochCopy)
} else {
debug(s"Aborting append of $txnMarkerResult to transaction log with coordinator and returning $err error to client for $transactionalId's EndTransaction request")
responseCallback(err)
responseCallback(err, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH)
}
case Right((coordinatorEpoch, newMetadata)) =>
def sendTxnMarkersCallback(error: Errors): Unit = {
@ -595,7 +652,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
txnMetadata.inLock {
if (txnMetadata.producerId != producerId)
Left(Errors.INVALID_PRODUCER_ID_MAPPING)
else if (txnMetadata.producerEpoch != producerEpoch)
else if (txnMetadata.producerEpoch != producerEpoch && !endTxnEpochBumped(txnMetadata, producerEpoch))
Left(Errors.PRODUCER_FENCED)
else if (txnMetadata.pendingTransitionInProgress)
Left(Errors.CONCURRENT_TRANSACTIONS)
@ -630,12 +687,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
preSendResult match {
case Left(err) =>
info(s"Aborting sending of transaction markers after appended $txnMarkerResult to transaction log and returning $err error to client for $transactionalId's EndTransaction request")
responseCallback(err)
responseCallback(err, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH)
case Right((txnMetadata, newPreSendMetadata)) =>
// we can respond to the client immediately and continue to write the txn markers if
// the log append was successful
responseCallback(Errors.NONE)
responseCallback(Errors.NONE, txnMetadata.producerId, txnMetadata.producerEpoch)
txnMarkerChannelManager.addTxnMarkersToSend(coordinatorEpoch, txnMarkerResult, txnMetadata, newPreSendMetadata)
}
@ -659,7 +716,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
}
}
responseCallback(error)
responseCallback(error, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH)
}
}
@ -669,11 +726,19 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
}
}
// When a client and server support V2, every endTransaction call bumps the producer epoch. When checking epoch, we want to
// check epoch + 1. Epoch bumps from PrepareEpochFence state are handled separately, so this method should not be used to check that case.
// Returns true if the transaction state epoch is the specified producer epoch + 1 and epoch bump on every transaction is expected.
private def endTxnEpochBumped(txnMetadata: TransactionMetadata, producerEpoch: Short): Boolean = {
!txnMetadata.pendingState.contains(PrepareEpochFence) && txnMetadata.clientTransactionVersion.supportsEpochBump() &&
txnMetadata.producerEpoch == producerEpoch + 1
}
def transactionTopicConfigs: Properties = txnManager.transactionTopicConfigs
def partitionFor(transactionalId: String): Int = txnManager.partitionFor(transactionalId)
private def onEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = {
private def onEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = {
error match {
case Errors.NONE =>
info("Completed rollback of ongoing transaction for transactionalId " +
@ -721,6 +786,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
txnTransitMetadata.producerEpoch,
TransactionResult.ABORT,
isFromClient = false,
clientTransactionVersion = txnManager.transactionVersionLevel(), // Since this is not from client, use server TV
onComplete(txnIdAndPidEpoch),
RequestLocal.noCaching)
}

View File

@ -25,6 +25,7 @@ import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil}
import org.apache.kafka.common.record.{Record, RecordBatch}
import org.apache.kafka.common.{MessageFormatter, TopicPartition}
import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue}
import org.apache.kafka.server.common.TransactionVersion
import scala.collection.mutable
import scala.jdk.CollectionConverters._
@ -63,7 +64,7 @@ object TransactionLog {
* @return value payload bytes
*/
private[transaction] def valueToBytes(txnMetadata: TxnTransitMetadata,
usesFlexibleRecords: Boolean): Array[Byte] = {
transactionVersionLevel: TransactionVersion): Array[Byte] = {
if (txnMetadata.txnState == Empty && txnMetadata.topicPartitions.nonEmpty)
throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata")
@ -78,9 +79,7 @@ object TransactionLog {
// Serialize with version 0 (highest non-flexible version) until transaction.version 1 is enabled
// which enables flexible fields in records.
val version: Short =
if (usesFlexibleRecords) 1 else 0
MessageUtil.toVersionPrefixedBytes(version,
MessageUtil.toVersionPrefixedBytes(transactionVersionLevel.transactionLogValueVersion(),
new TransactionLogValue()
.setProducerId(txnMetadata.producerId)
.setProducerEpoch(txnMetadata.producerEpoch)
@ -88,7 +87,8 @@ object TransactionLog {
.setTransactionStatus(txnMetadata.txnState.id)
.setTransactionLastUpdateTimestampMs(txnMetadata.txnLastUpdateTimestamp)
.setTransactionStartTimestampMs(txnMetadata.txnStartTimestamp)
.setTransactionPartitions(transactionPartitions))
.setTransactionPartitions(transactionPartitions)
.setClientTransactionVersion(txnMetadata.clientTransactionVersion.featureLevel()))
}
/**
@ -124,14 +124,16 @@ object TransactionLog {
val transactionMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = value.producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = value.previousProducerId,
nextProducerId = value.nextProducerId,
producerEpoch = value.producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = value.transactionTimeoutMs,
state = TransactionState.fromId(value.transactionStatus),
topicPartitions = mutable.Set.empty[TopicPartition],
txnStartTimestamp = value.transactionStartTimestampMs,
txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs)
txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs,
clientTransactionVersion = TransactionVersion.fromFeatureLevel(value.clientTransactionVersion))
if (!transactionMetadata.state.equals(Empty))
value.transactionPartitions.forEach(partitionsSchema =>

View File

@ -17,11 +17,11 @@
package kafka.coordinator.transaction
import java.util.concurrent.locks.ReentrantLock
import kafka.utils.{CoreUtils, Logging, nonthreadsafe}
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.server.common.TransactionVersion
import scala.collection.{immutable, mutable}
@ -163,51 +163,42 @@ private[transaction] case object PrepareEpochFence extends TransactionState {
}
private[transaction] object TransactionMetadata {
def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, timestamp: Long) =
new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int,
state: TransactionState, timestamp: Long) =
new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
def apply(transactionalId: String, producerId: Long, lastProducerId: Long, producerEpoch: Short,
lastProducerEpoch: Short, txnTimeoutMs: Int, state: TransactionState, timestamp: Long) =
new TransactionMetadata(transactionalId, producerId, lastProducerId, producerEpoch, lastProducerEpoch,
txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp)
def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1
}
// this is a immutable object representing the target transition of the transaction metadata
private[transaction] case class TxnTransitMetadata(producerId: Long,
lastProducerId: Long,
prevProducerId: Long,
nextProducerId: Long,
producerEpoch: Short,
lastProducerEpoch: Short,
txnTimeoutMs: Int,
txnState: TransactionState,
topicPartitions: immutable.Set[TopicPartition],
txnStartTimestamp: Long,
txnLastUpdateTimestamp: Long) {
txnLastUpdateTimestamp: Long,
clientTransactionVersion: TransactionVersion) {
override def toString: String = {
"TxnTransitMetadata(" +
s"producerId=$producerId, " +
s"lastProducerId=$lastProducerId, " +
s"previousProducerId=$prevProducerId, " +
s"nextProducerId=$nextProducerId, " +
s"producerEpoch=$producerEpoch, " +
s"lastProducerEpoch=$lastProducerEpoch, " +
s"txnTimeoutMs=$txnTimeoutMs, " +
s"txnState=$txnState, " +
s"topicPartitions=$topicPartitions, " +
s"txnStartTimestamp=$txnStartTimestamp, " +
s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp)"
s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp, " +
s"clientTransactionVersion=$clientTransactionVersion)"
}
}
/**
*
* @param producerId producer id
* @param lastProducerId last producer id assigned to the producer
* @param previousProducerId producer id for the last committed transaction with this transactional ID
* @param nextProducerId Latest producer ID sent to the producer for the given transactional ID
* @param producerEpoch current epoch of the producer
* @param lastProducerEpoch last epoch of the producer
* @param txnTimeoutMs timeout to be used to abort long running transactions
@ -215,18 +206,21 @@ private[transaction] case class TxnTransitMetadata(producerId: Long,
* @param topicPartitions current set of partitions that are part of this transaction
* @param txnStartTimestamp time the transaction was started, i.e., when first partition is added
* @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration
* @param clientTransactionVersion TransactionVersion used by the client when the state was transitioned
*/
@nonthreadsafe
private[transaction] class TransactionMetadata(val transactionalId: String,
var producerId: Long,
var lastProducerId: Long,
var previousProducerId: Long,
var nextProducerId: Long,
var producerEpoch: Short,
var lastProducerEpoch: Short,
var txnTimeoutMs: Int,
var state: TransactionState,
val topicPartitions: mutable.Set[TopicPartition],
@volatile var txnStartTimestamp: Long = -1,
@volatile var txnLastUpdateTimestamp: Long) extends Logging {
@volatile var txnLastUpdateTimestamp: Long,
var clientTransactionVersion: TransactionVersion) extends Logging {
// pending state is used to indicate the state that this transaction is going to
// transit to, and for blocking future attempts to transit it again if it is not legal;
@ -256,8 +250,8 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
// this is visible for test only
def prepareNoTransit(): TxnTransitMetadata = {
// do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state
TxnTransitMetadata(producerId, lastProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.toSet,
txnStartTimestamp, txnLastUpdateTimestamp)
TxnTransitMetadata(producerId, previousProducerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.toSet,
txnStartTimestamp, txnLastUpdateTimestamp, TransactionVersion.TV_0)
}
def prepareFenceProducerEpoch(): TxnTransitMetadata = {
@ -335,9 +329,16 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
(topicPartitions ++ addedTopicPartitions).toSet, newTxnStartTimestamp, updateTimestamp)
}
def prepareAbortOrCommit(newState: TransactionState, updateTimestamp: Long): TxnTransitMetadata = {
prepareTransitionTo(newState, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, topicPartitions.toSet,
txnStartTimestamp, updateTimestamp)
def prepareAbortOrCommit(newState: TransactionState, clientTransactionVersion: TransactionVersion, nextProducerId: Long, updateTimestamp: Long): TxnTransitMetadata = {
val (updatedProducerEpoch, updatedLastProducerEpoch) = if (clientTransactionVersion.supportsEpochBump()) {
// We already ensured that we do not overflow here. MAX_SHORT is the highest possible value.
((producerEpoch + 1).toShort, producerEpoch)
} else {
(producerEpoch, lastProducerEpoch)
}
prepareTransitionTo(newState, producerId, nextProducerId, updatedProducerEpoch, updatedLastProducerEpoch, txnTimeoutMs, topicPartitions.toSet,
txnStartTimestamp, updateTimestamp, clientTransactionVersion)
}
def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = {
@ -345,8 +346,15 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
// Since the state change was successfully written to the log, unset the flag for a failed epoch fence
hasFailedEpochFence = false
prepareTransitionTo(newState, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, Set.empty[TopicPartition],
txnStartTimestamp, updateTimestamp)
val (updatedProducerId, updatedProducerEpoch) =
// If we overflowed on epoch bump, we have to set it as the producer ID now the marker has been written.
if (clientTransactionVersion.supportsEpochBump() && nextProducerId != RecordBatch.NO_PRODUCER_ID) {
(nextProducerId, 0.toShort)
} else {
(producerId, producerEpoch)
}
prepareTransitionTo(newState, updatedProducerId, RecordBatch.NO_PRODUCER_ID, updatedProducerEpoch, producerEpoch, txnTimeoutMs, Set.empty[TopicPartition],
txnStartTimestamp, updateTimestamp, clientTransactionVersion)
}
def prepareDead(): TxnTransitMetadata = {
@ -367,37 +375,50 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
}
}
private def prepareTransitionTo(newState: TransactionState,
newProducerId: Long,
newEpoch: Short,
newLastEpoch: Short,
newTxnTimeoutMs: Int,
newTopicPartitions: immutable.Set[TopicPartition],
newTxnStartTimestamp: Long,
private def prepareTransitionTo(updatedState: TransactionState,
updatedProducerId: Long,
updatedEpoch: Short,
updatedLastEpoch: Short,
updatedTxnTimeoutMs: Int,
updatedTopicPartitions: immutable.Set[TopicPartition],
updatedTxnStartTimestamp: Long,
updateTimestamp: Long): TxnTransitMetadata = {
prepareTransitionTo(updatedState, updatedProducerId, RecordBatch.NO_PRODUCER_ID, updatedEpoch, updatedLastEpoch, updatedTxnTimeoutMs, updatedTopicPartitions, updatedTxnStartTimestamp, updateTimestamp, TransactionVersion.TV_0)
}
private def prepareTransitionTo(updatedState: TransactionState,
updatedProducerId: Long,
nextProducerId: Long,
updatedEpoch: Short,
updatedLastEpoch: Short,
updatedTxnTimeoutMs: Int,
updatedTopicPartitions: immutable.Set[TopicPartition],
updatedTxnStartTimestamp: Long,
updateTimestamp: Long,
clientTransactionVersion: TransactionVersion): TxnTransitMetadata = {
if (pendingState.isDefined)
throw new IllegalStateException(s"Preparing transaction state transition to $newState " +
throw new IllegalStateException(s"Preparing transaction state transition to $updatedState " +
s"while it already a pending state ${pendingState.get}")
if (newProducerId < 0)
throw new IllegalArgumentException(s"Illegal new producer id $newProducerId")
if (updatedProducerId < 0)
throw new IllegalArgumentException(s"Illegal new producer id $updatedProducerId")
// The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata
// is created for the first time and it could stay like this until transitioning
// to Dead.
if (newState != Dead && newEpoch < 0)
throw new IllegalArgumentException(s"Illegal new producer epoch $newEpoch")
if (updatedState != Dead && updatedEpoch < 0)
throw new IllegalArgumentException(s"Illegal new producer epoch $updatedEpoch")
// check that the new state transition is valid and update the pending state if necessary
if (newState.validPreviousStates.contains(state)) {
val transitMetadata = TxnTransitMetadata(newProducerId, producerId, newEpoch, newLastEpoch, newTxnTimeoutMs, newState,
newTopicPartitions, newTxnStartTimestamp, updateTimestamp)
if (updatedState.validPreviousStates.contains(state)) {
val transitMetadata = TxnTransitMetadata(updatedProducerId, producerId, nextProducerId, updatedEpoch, updatedLastEpoch, updatedTxnTimeoutMs, updatedState,
updatedTopicPartitions, updatedTxnStartTimestamp, updateTimestamp, clientTransactionVersion)
debug(s"TransactionalId $transactionalId prepare transition from $state to $transitMetadata")
pendingState = Some(newState)
pendingState = Some(updatedState)
transitMetadata
} else {
throw new IllegalStateException(s"Preparing transaction state transition to $newState failed since the target state" +
s" $newState is not a valid previous state of the current state $state")
throw new IllegalStateException(s"Preparing transaction state transition to $updatedState failed since the target state" +
s" $updatedState is not a valid previous state of the current state $state")
}
}
@ -436,7 +457,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
producerEpoch = transitMetadata.producerEpoch
lastProducerEpoch = transitMetadata.lastProducerEpoch
producerId = transitMetadata.producerId
lastProducerId = transitMetadata.lastProducerId
previousProducerId = transitMetadata.prevProducerId
}
case Ongoing => // from addPartitions
@ -457,6 +478,10 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
txnStartTimestamp != transitMetadata.txnStartTimestamp) {
throwStateTransitionFailure(transitMetadata)
} else if (transitMetadata.clientTransactionVersion.supportsEpochBump()) {
producerEpoch = transitMetadata.producerEpoch
lastProducerEpoch = transitMetadata.lastProducerEpoch
nextProducerId = transitMetadata.nextProducerId
}
case CompleteAbort | CompleteCommit => // from write markers
@ -468,6 +493,13 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
} else {
txnStartTimestamp = transitMetadata.txnStartTimestamp
topicPartitions.clear()
if (transitMetadata.clientTransactionVersion.supportsEpochBump()) {
producerEpoch = transitMetadata.producerEpoch
lastProducerEpoch = transitMetadata.lastProducerEpoch
previousProducerId = transitMetadata.prevProducerId
producerId = transitMetadata.producerId
nextProducerId = transitMetadata.nextProducerId
}
}
case PrepareEpochFence =>
@ -487,6 +519,7 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
}
debug(s"TransactionalId $transactionalId complete transition from $state to $transitMetadata")
clientTransactionVersion = transitMetadata.clientTransactionVersion
txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp
pendingState = None
state = toState
@ -494,8 +527,14 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
}
private def validProducerEpoch(transitMetadata: TxnTransitMetadata): Boolean = {
val transitEpoch = transitMetadata.producerEpoch
val transitProducerId = transitMetadata.producerId
val isAtLeastTransactionsV2 = transitMetadata.clientTransactionVersion.supportsEpochBump()
val isOverflowComplete = isAtLeastTransactionsV2 && (transitMetadata.txnState == CompleteCommit || transitMetadata.txnState == CompleteAbort) && transitMetadata.producerEpoch == 0
val transitEpoch =
if (isOverflowComplete || (isAtLeastTransactionsV2 && (transitMetadata.txnState == PrepareCommit || transitMetadata.txnState == PrepareAbort)))
transitMetadata.lastProducerEpoch
else
transitMetadata.producerEpoch
val transitProducerId = if (isOverflowComplete) transitMetadata.prevProducerId else transitMetadata.producerId
transitEpoch == producerEpoch && transitProducerId == producerId
}
@ -518,6 +557,8 @@ private[transaction] class TransactionMetadata(val transactionalId: String,
"TransactionMetadata(" +
s"transactionalId=$transactionalId, " +
s"producerId=$producerId, " +
s"previousProducerId=$previousProducerId, "
s"nextProducerId=$nextProducerId, "
s"producerEpoch=$producerEpoch, " +
s"txnTimeoutMs=$txnTimeoutMs, " +
s"state=$state, " +

View File

@ -101,8 +101,10 @@ class TransactionStateManager(brokerId: Int,
TransactionStateManagerConfig.METRICS_GROUP,
"The avg time it took to load the partitions in the last 30sec"), new Avg())
private[transaction] def usesFlexibleRecords(): Boolean = {
metadataCache.features().finalizedFeatures().getOrDefault(TransactionVersion.FEATURE_NAME, 0.toShort) > 0
private[transaction] def transactionVersionLevel(): TransactionVersion = {
val version = TransactionVersion.fromFeatureLevel(metadataCache.features().finalizedFeatures().getOrDefault(
TransactionVersion.FEATURE_NAME, 0.toShort))
version
}
// visible for testing only
@ -624,7 +626,7 @@ class TransactionStateManager(brokerId: Int,
// generate the message for this transaction metadata
val keyBytes = TransactionLog.keyToBytes(transactionalId)
val valueBytes = TransactionLog.valueToBytes(newMetadata, usesFlexibleRecords())
val valueBytes = TransactionLog.valueToBytes(newMetadata, transactionVersionLevel())
val timestamp = time.milliseconds()
val records = MemoryRecords.withRecords(TransactionLog.EnforcedCompression, new SimpleRecord(timestamp, keyBytes, valueBytes))

View File

@ -73,7 +73,7 @@ import org.apache.kafka.coordinator.group.{Group, GroupCoordinator}
import org.apache.kafka.coordinator.share.ShareCoordinator
import org.apache.kafka.server.ClientMetricsManager
import org.apache.kafka.server.authorizer._
import org.apache.kafka.server.common.{GroupVersion, MetadataVersion, RequestLocal}
import org.apache.kafka.server.common.{GroupVersion, MetadataVersion, RequestLocal, TransactionVersion}
import org.apache.kafka.server.common.MetadataVersion.{IBP_0_11_0_IV0, IBP_2_3_IV0}
import org.apache.kafka.server.record.BrokerCompressionType
import org.apache.kafka.server.share.context.ShareFetchContext
@ -2299,7 +2299,7 @@ class KafkaApis(val requestChannel: RequestChannel,
val transactionalId = endTxnRequest.data.transactionalId
if (authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) {
def sendResponseCallback(error: Errors): Unit = {
def sendResponseCallback(error: Errors, newProducerId: Long, newProducerEpoch: Short): Unit = {
def createResponse(requestThrottleMs: Int): AbstractResponse = {
val finalError =
if (endTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) {
@ -2311,6 +2311,8 @@ class KafkaApis(val requestChannel: RequestChannel,
}
val responseBody = new EndTxnResponse(new EndTxnResponseData()
.setErrorCode(finalError.code)
.setProducerId(newProducerId)
.setProducerEpoch(newProducerEpoch)
.setThrottleTimeMs(requestThrottleMs))
trace(s"Completed ${endTxnRequest.data.transactionalId}'s EndTxnRequest " +
s"with committed: ${endTxnRequest.data.committed}, " +
@ -2320,10 +2322,14 @@ class KafkaApis(val requestChannel: RequestChannel,
requestHelper.sendResponseMaybeThrottle(request, createResponse)
}
// If the request is version 4, we know the client supports transaction version 2.
val clientTransactionVersion = if (endTxnRequest.version() > 4) TransactionVersion.TV_2 else TransactionVersion.TV_0
txnCoordinator.handleEndTransaction(endTxnRequest.data.transactionalId,
endTxnRequest.data.producerId,
endTxnRequest.data.producerEpoch,
endTxnRequest.result(),
clientTransactionVersion,
sendResponseCallback,
requestLocal)
} else

View File

@ -464,10 +464,10 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
addPartitionsOp.awaitAndVerify(txn)
val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn"))
txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2))
txnMetadata.state = PrepareCommit
txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2))
prepareTxnLog(partitionId)
}
@ -506,13 +506,15 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
private def prepareExhaustedEpochTxnMetadata(txn: Transaction): TransactionMetadata = {
new TransactionMetadata(transactionalId = txn.transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = (Short.MaxValue - 1).toShort,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 60000,
state = Empty,
topicPartitions = collection.mutable.Set.empty[TopicPartition],
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TransactionVersion.TV_0)
}
abstract class TxnOperation[R] extends Operation {
@ -562,7 +564,8 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
txnMetadata.producerId,
txnMetadata.producerEpoch,
transactionResult(txn),
resultCallback,
TransactionVersion.TV_2,
(r, _, _) => resultCallback(r),
RequestLocal.withThreadConfinedCaching)
}
}

View File

@ -23,9 +23,13 @@ import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult}
import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch}
import org.apache.kafka.coordinator.transaction.TransactionStateManagerConfig
import org.apache.kafka.server.common.TransactionVersion
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.apache.kafka.server.util.MockScheduler
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import org.mockito.{ArgumentCaptor, ArgumentMatchers}
import org.mockito.ArgumentMatchers.{any, anyInt}
import org.mockito.Mockito.{mock, times, verify, when}
@ -51,6 +55,7 @@ class TransactionCoordinatorTest {
private val producerId = 10L
private val producerEpoch: Short = 1
private val txnTimeoutMs = 1
private val producerId2 = 11L
private val partitions = mutable.Set[TopicPartition](new TopicPartition("topic1", 0))
private val scheduler = new MockScheduler(time)
@ -66,6 +71,8 @@ class TransactionCoordinatorTest {
val transactionStatePartitionCount = 1
var result: InitProducerIdResult = _
var error: Errors = Errors.NONE
var newProducerId: Long = RecordBatch.NO_PRODUCER_ID
var newEpoch: Short = RecordBatch.NO_PRODUCER_EPOCH
private def mockPidGenerator(): Unit = {
when(pidGenerator.generateProducerId()).thenAnswer(_ => {
@ -155,8 +162,8 @@ class TransactionCoordinatorTest {
def shouldGenerateNewProducerIdIfEpochsExhausted(): Unit = {
initPidGenericMocks(transactionalId)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds())
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -245,7 +252,8 @@ class TransactionCoordinatorTest {
errors = AddPartitionsToTxnResponse.errorsForTransaction(result.topicResults()).asScala.toMap
}
// If producer ID is not the same, return INVALID_PRODUCER_ID_MAPPING
val wrongPidTxnMetadata = new TransactionMetadata(transactionalId, 1, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0)
val wrongPidTxnMetadata = new TransactionMetadata(transactionalId, 1, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, wrongPidTxnMetadata))))
@ -254,9 +262,9 @@ class TransactionCoordinatorTest {
assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error)
}
// If producer epoch is not equal, return PRODUCER_FENCED
val oldEpochTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0)
val oldEpochTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, oldEpochTxnMetadata))))
@ -266,7 +274,8 @@ class TransactionCoordinatorTest {
}
// If the txn state is Prepare or AbortCommit, we return CONCURRENT_TRANSACTIONS
val emptyTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0)
val emptyTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, PrepareCommit, partitions, 0, 0, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, emptyTxnMetadata))))
@ -276,7 +285,8 @@ class TransactionCoordinatorTest {
}
// Pending state does not matter, we will just check if the partitions are in the txnMetadata.
val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, mutable.Set.empty, 0, 0)
val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, mutable.Set.empty, 0, 0, TV_0)
ongoingTxnMetadata.pendingState = Some(CompleteCommit)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, ongoingTxnMetadata))))
@ -298,9 +308,11 @@ class TransactionCoordinatorTest {
}
def validateConcurrentTransactions(state: TransactionState): Unit = {
// Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort.
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0)))))
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0, TV_2)))))
coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
@ -308,9 +320,11 @@ class TransactionCoordinatorTest {
@Test
def shouldRespondWithProducerFencedOnAddPartitionsWhenEpochsAreDifferent(): Unit = {
// Since the clientTransactionVersion doesn't matter, use 2 since the state is PrepareCommit.
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, 10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0)))))
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0, TV_2)))))
coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
@ -318,27 +332,30 @@ class TransactionCoordinatorTest {
@Test
def shouldAppendNewMetadataToLogOnAddPartitionsWhenPartitionsAdded(): Unit = {
validateSuccessfulAddPartitions(Empty)
validateSuccessfulAddPartitions(Empty, 0)
}
@Test
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsOngoing(): Unit = {
validateSuccessfulAddPartitions(Ongoing)
validateSuccessfulAddPartitions(Ongoing, 0)
}
@Test
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(): Unit = {
validateSuccessfulAddPartitions(CompleteCommit)
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(clientTransactionVersion: Short): Unit = {
validateSuccessfulAddPartitions(CompleteCommit, clientTransactionVersion)
}
@Test
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(): Unit = {
validateSuccessfulAddPartitions(CompleteAbort)
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(clientTransactionVersion: Short): Unit = {
validateSuccessfulAddPartitions(CompleteAbort, clientTransactionVersion)
}
def validateSuccessfulAddPartitions(previousState: TransactionState): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort,
txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds())
def validateSuccessfulAddPartitions(previousState: TransactionState, transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -360,7 +377,8 @@ class TransactionCoordinatorTest {
def shouldRespondWithErrorsNoneOnAddPartitionWhenNoErrorsAndPartitionsTheSame(): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0)))))
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0)))))
coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback)
assertEquals(Errors.NONE, error)
@ -376,7 +394,8 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0)))))
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0, TV_0)))))
coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, partitions, verifyPartitionsInTxnCallback)
errors.foreach { case (_, error) =>
@ -394,7 +413,8 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0)))))
new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0, TV_0)))))
val extraPartitions = partitions ++ Set(new TopicPartition("topic2", 0))
@ -404,107 +424,227 @@ class TransactionCoordinatorTest {
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(None))
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, 10, 10, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
new TransactionMetadata(transactionalId, 10, 10, RecordBatch.NO_PRODUCER_ID,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
(producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(): Unit ={
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch,
(producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.NONE, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(): Unit ={
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.NONE, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_TXN_STATE, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_TXN_STATE, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnConcurrentTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReturnConcurrentTransactionsOnEndTxnRequestWhenStatusIsPrepareCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, 1, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_TXN_STATE, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(): Unit = {
mockPrepare(PrepareCommit)
def shouldReturnWhenTransactionVersionDowngraded(): Unit = {
// State was written when transactions V2
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, errorsCallback)
// Return CONCURRENT_TRANSACTIONS as the transaction is still completing
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_0, endTxnCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
assertEquals(RecordBatch.NO_PRODUCER_ID, newProducerId)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, newEpoch)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
// Recognize the retry and return NONE
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_0, endTxnCallback)
assertEquals(Errors.NONE, error)
assertEquals(producerId, newProducerId)
assertEquals((producerEpoch + 1).toShort, newEpoch) // epoch is bumped since we started as V2
verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnCorrectlyWhenTransactionVersionUpgraded(): Unit = {
// State was written when transactions V0
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
// Transactions V0 throws the concurrent transactions error here.
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
// When the transaction is completed, return and do not throw an error.
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0)))))
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.NONE, error)
assertEquals(producerId, newProducerId)
assertEquals(producerEpoch, newEpoch) // epoch is not bumped since this started as V1
verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnInvalidTxnRequestOnEndTxnV2IfNotEndTxnV2Retry(): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
// If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED.
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
// If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED.
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldReturnOkOnEndTxnV2IfEndTxnV2RetryEpochOverflow(): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId,
producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
// Return CONCURRENT_TRANSACTIONS while transaction is still completing
coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.CONCURRENT_TRANSACTIONS, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId2, producerId,
RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2)))))
coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback)
assertEquals(Errors.NONE, error)
assertNotEquals(RecordBatch.NO_PRODUCER_ID, newProducerId)
assertNotEquals(producerId, newProducerId)
assertEquals(0, newEpoch)
verify(transactionManager, times(2)).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
mockPrepare(PrepareCommit, clientTransactionVersion)
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
verify(transactionManager).appendTransactionToLog(
ArgumentMatchers.eq(transactionalId),
@ -515,11 +655,13 @@ class TransactionCoordinatorTest {
any())
}
@Test
def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(): Unit = {
mockPrepare(PrepareAbort)
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
mockPrepare(PrepareAbort, clientTransactionVersion)
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
verify(transactionManager).appendTransactionToLog(
ArgumentMatchers.eq(transactionalId),
@ -530,74 +672,90 @@ class TransactionCoordinatorTest {
any())
}
@Test
def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(): Unit = {
coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, errorsCallback)
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_REQUEST, error)
}
@Test
def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Left(Errors.NOT_COORDINATOR))
coordinator.handleEndTransaction("", 0, 0, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction("", 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.INVALID_REQUEST, error)
}
@Test
def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Left(Errors.NOT_COORDINATOR))
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.NOT_COORDINATOR, error)
}
@Test
def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS))
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, error)
}
@Test
def shouldReturnInvalidEpochOnEndTxnWhenEpochIsLarger(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReturnInvalidEpochOnEndTxnWhenEpochIsLarger(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val serverProducerEpoch = 1.toShort
verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch + 1).toShort)
verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch + 1).toShort, clientTransactionVersion)
}
@Test
def shouldReturnInvalidEpochOnEndTxnWhenEpochIsSmaller(): Unit = {
val serverProducerEpoch = 1.toShort
verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch - 1).toShort)
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldReturnInvalidEpochOnEndTxnWhenEpochIsSmaller(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val serverProducerEpoch = 2.toShort
// Since we bump epoch in transactionV2 the request should be one producer ID older
verifyEndTxnEpoch(serverProducerEpoch, requestEpoch(clientTransactionVersion), clientTransactionVersion)
}
private def verifyEndTxnEpoch(metadataEpoch: Short, requestEpoch: Short): Unit = {
private def verifyEndTxnEpoch(metadataEpoch: Short, requestEpoch: Short, clientTransactionVersion: TransactionVersion): Unit = {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
new TransactionMetadata(transactionalId, producerId, producerId, metadataEpoch, 0, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds())))))
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, metadataEpoch, 1,
1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion)))))
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, errorsCallback)
coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback)
assertEquals(Errors.PRODUCER_FENCED, error)
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
@Test
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingEmptyTransaction(): Unit = {
validateIncrementEpochAndUpdateMetadata(Empty)
validateIncrementEpochAndUpdateMetadata(Empty, 0)
}
@Test
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(): Unit = {
validateIncrementEpochAndUpdateMetadata(CompleteAbort)
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(clientTransactionVersion: Short): Unit = {
validateIncrementEpochAndUpdateMetadata(CompleteAbort, clientTransactionVersion)
}
@Test
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(): Unit = {
validateIncrementEpochAndUpdateMetadata(CompleteCommit)
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(clientTransactionVersion: Short): Unit = {
validateIncrementEpochAndUpdateMetadata(CompleteCommit, clientTransactionVersion)
}
@Test
@ -612,8 +770,8 @@ class TransactionCoordinatorTest {
@Test
def shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch,
(producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
@ -621,8 +779,10 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.appendTransactionToLog(
ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch),
@ -640,7 +800,7 @@ class TransactionCoordinatorTest {
verify(transactionManager).appendTransactionToLog(
ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch),
ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())),
ArgumentMatchers.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())),
any(),
any(),
any())
@ -648,14 +808,14 @@ class TransactionCoordinatorTest {
@Test
def shouldFailToAbortTransactionOnHandleInitPidWhenProducerEpochIsSmaller(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch,
(producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 2).toShort,
(producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch + 2).toShort, (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -671,8 +831,8 @@ class TransactionCoordinatorTest {
@Test
def shouldNotRepeatedlyBumpEpochDueToInitPidDuringOngoingTxnIfAppendToLogFails(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
@ -683,9 +843,11 @@ class TransactionCoordinatorTest {
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenAnswer(_ => Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())
when(transactionManager.appendTransactionToLog(
ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch),
@ -740,33 +902,38 @@ class TransactionCoordinatorTest {
@Test
def shouldUseLastEpochToFenceWhenEpochsAreExhausted(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
(Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
val postFenceTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, Short.MaxValue,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds())
val postFenceTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
Short.MaxValue, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, postFenceTxnMetadata))))
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
// InitProducerId uses FenceProducerEpoch so clientTransactionVersion is 0.
when(transactionManager.appendTransactionToLog(
ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch),
ArgumentMatchers.eq(TxnTransitMetadata(
producerId = producerId,
lastProducerId = producerId,
prevProducerId = producerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = Short.MaxValue,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = txnTimeoutMs,
txnState = PrepareAbort,
topicPartitions = partitions.toSet,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds())),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)),
capturedErrorsCallback.capture(),
any(),
any())
@ -783,14 +950,16 @@ class TransactionCoordinatorTest {
ArgumentMatchers.eq(coordinatorEpoch),
ArgumentMatchers.eq(TxnTransitMetadata(
producerId = producerId,
lastProducerId = producerId,
prevProducerId = producerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = Short.MaxValue,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = txnTimeoutMs,
txnState = PrepareAbort,
topicPartitions = partitions.toSet,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds())),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)),
any(),
any(),
any())
@ -800,8 +969,8 @@ class TransactionCoordinatorTest {
def testInitProducerIdWithNoLastProducerData(): Unit = {
// If the metadata doesn't include the previous producer data (for example, if it was written to the log by a broker
// on an old version), the retry case should fail
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID,
RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
@ -817,8 +986,8 @@ class TransactionCoordinatorTest {
@Test
def testFenceProducerWhenMappingExistsWithDifferentProducerId(): Unit = {
// Existing transaction ID maps to new producer ID
val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId, producerEpoch,
(producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds)
val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId,
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
@ -835,8 +1004,8 @@ class TransactionCoordinatorTest {
def testInitProducerIdWithCurrentEpochProvided(): Unit = {
mockPidGenerator()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, 10,
9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
@ -870,8 +1039,8 @@ class TransactionCoordinatorTest {
def testInitProducerIdStaleCurrentEpochProvided(): Unit = {
mockPidGenerator()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, 10,
9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, 10, 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
@ -906,8 +1075,8 @@ class TransactionCoordinatorTest {
@Test
def testRetryInitProducerIdAfterProducerIdRotation(): Unit = {
// Existing transaction ID maps to new producer ID
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(pidGenerator.generateProducerId())
.thenReturn(Success(producerId + 1))
@ -928,7 +1097,7 @@ class TransactionCoordinatorTest {
capturedErrorsCallback.getValue.apply(Errors.NONE)
txnMetadata.pendingState = None
txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId
txnMetadata.lastProducerId = capturedTxnTransitMetadata.getValue.lastProducerId
txnMetadata.previousProducerId = capturedTxnTransitMetadata.getValue.prevProducerId
txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch
txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch
})
@ -947,8 +1116,8 @@ class TransactionCoordinatorTest {
@Test
def testInitProducerIdWithInvalidEpochAfterProducerIdRotation(): Unit = {
// Existing transaction ID maps to new producer ID
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort,
(Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, (Short.MaxValue - 1).toShort, (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds, TV_0)
when(pidGenerator.generateProducerId())
.thenReturn(Success(producerId + 1))
@ -969,7 +1138,7 @@ class TransactionCoordinatorTest {
capturedErrorsCallback.getValue.apply(Errors.NONE)
txnMetadata.pendingState = None
txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId
txnMetadata.lastProducerId = capturedTxnTransitMetadata.getValue.lastProducerId
txnMetadata.previousProducerId = capturedTxnTransitMetadata.getValue.prevProducerId
txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch
txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch
})
@ -995,16 +1164,20 @@ class TransactionCoordinatorTest {
@Test
def shouldAbortExpiredTransactionsInOngoingStateAndBumpEpoch(): Unit = {
val now = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
val expectedTransition = TxnTransitMetadata(producerId, producerId, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs, PrepareAbort, partitions.toSet, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT)
// Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0.
val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.toSet, now,
now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0)
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
when(transactionManager.appendTransactionToLog(ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch),
@ -1030,20 +1203,22 @@ class TransactionCoordinatorTest {
@Test
def shouldNotAcceptSmallerEpochDuringTransactionExpiration(): Unit = {
val now = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 2).toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now)
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 2).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, bumpedTxnMetadata))))
def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = {
def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors, producerId: Long, producerEpoch: Short): Unit = {
assertEquals(Errors.PRODUCER_FENCED, error)
}
coordinator.abortTimedOutTransactions(checkOnEndTransactionComplete)
@ -1054,9 +1229,9 @@ class TransactionCoordinatorTest {
@Test
def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = {
val metadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
metadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds())
val metadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
metadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())
when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
@ -1073,22 +1248,26 @@ class TransactionCoordinatorTest {
@Test
def shouldNotBumpEpochWhenAbortingExpiredTransactionIfAppendToLogFails(): Unit = {
val now = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
when(transactionManager.timedOutTransactions())
.thenReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch)))
val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now)
val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId,
RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadataAfterAppendFailure))))
// Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0.
val bumpedEpoch = (producerEpoch + 1).toShort
val expectedTransition = TxnTransitMetadata(producerId, producerId, bumpedEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs,
PrepareAbort, partitions.toSet, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT)
val expectedTransition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, bumpedEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions.toSet, now,
now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0)
when(transactionManager.transactionVersionLevel()).thenReturn(TV_0)
when(transactionManager.appendTransactionToLog(ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(coordinatorEpoch),
@ -1117,9 +1296,9 @@ class TransactionCoordinatorTest {
@Test
def shouldNotBumpEpochWithPendingTransaction(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
txnMetadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds())
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
@ -1146,9 +1325,9 @@ class TransactionCoordinatorTest {
def testDescribeTransactionsWithExpiringTransactionalId(): Unit = {
coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch,
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Dead, mutable.Set.empty, time.milliseconds(),
time.milliseconds())
time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -1172,8 +1351,8 @@ class TransactionCoordinatorTest {
@Test
def testDescribeTransactions(): Unit = {
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds())
val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds(), TV_0)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))))
@ -1199,7 +1378,9 @@ class TransactionCoordinatorTest {
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
val metadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0)
// Since the clientTransactionVersion doesn't matter, use 2 since the states are PrepareCommit and PrepareAbort.
val metadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_EPOCH,
0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0, TV_2)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
@ -1208,14 +1389,16 @@ class TransactionCoordinatorTest {
assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result)
}
private def validateIncrementEpochAndUpdateMetadata(state: TransactionState): Unit = {
private def validateIncrementEpochAndUpdateMetadata(state: TransactionState, transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
when(pidGenerator.generateProducerId())
.thenReturn(Success(producerId))
when(transactionManager.validateTransactionTimeoutMs(anyInt()))
.thenReturn(true)
val metadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds())
val metadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH,
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds(), clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata))))
@ -1242,13 +1425,13 @@ class TransactionCoordinatorTest {
assertEquals(producerId, metadata.producerId)
}
private def mockPrepare(transactionState: TransactionState, runCallback: Boolean = false): TransactionMetadata = {
private def mockPrepare(transactionState: TransactionState, clientTransactionVersion: TransactionVersion, runCallback: Boolean = false): TransactionMetadata = {
val now = time.milliseconds()
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs, Ongoing, partitions, now, now)
val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH,
producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now, TV_0)
val transition = TxnTransitMetadata(producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs,
transactionState, partitions.toSet, now, now)
val transition = TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions.toSet, now, now, clientTransactionVersion)
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
.thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata))))
@ -1264,8 +1447,8 @@ class TransactionCoordinatorTest {
capturedErrorsCallback.getValue.apply(Errors.NONE)
})
new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds())
new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds(), clientTransactionVersion)
}
def initProducerIdMockCallback(ret: InitProducerIdResult): Unit = {
@ -1275,4 +1458,17 @@ class TransactionCoordinatorTest {
def errorsCallback(ret: Errors): Unit = {
error = ret
}
def endTxnCallback(ret: Errors, producerId: Long, epoch: Short): Unit = {
error = ret
newProducerId = producerId
newEpoch = epoch
}
def requestEpoch(clientTransactionVersion: TransactionVersion): Short = {
if (clientTransactionVersion.supportsEpochBump())
(producerEpoch - 1).toShort
else
producerEpoch
}
}

View File

@ -23,8 +23,9 @@ import org.apache.kafka.common.compress.Compression
import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil}
import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection
import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type}
import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord}
import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, SimpleRecord}
import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue}
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue}
import org.junit.jupiter.api.Test
@ -48,10 +49,11 @@ class TransactionLogTest {
val transactionalId = "transactionalId"
val producerId = 23423L
val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs, 0)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
txnMetadata.addPartitions(topicPartitions)
assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true))
assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2))
}
@Test
@ -72,14 +74,14 @@ class TransactionLogTest {
// generate transaction log messages
val txnRecords = pidMappings.map { case (transactionalId, producerId) =>
val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs,
transactionStates(producerId), 0)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
if (!txnMetadata.state.equals(Empty))
txnMetadata.addPartitions(topicPartitions)
val keyBytes = TransactionLog.keyToBytes(transactionalId)
val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)
val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)
new SimpleRecord(keyBytes, valueBytes)
}.toSeq
@ -114,12 +116,12 @@ class TransactionLogTest {
val producerId = 1334L
val topicPartition = new TopicPartition("topic", 0)
val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch,
transactionTimeoutMs, Ongoing, 0)
val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0)
txnMetadata.addPartitions(Set(topicPartition))
val keyBytes = TransactionLog.keyToBytes(transactionalId)
val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), true)
val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)
val transactionMetadataRecord = TestUtils.records(Seq(
new SimpleRecord(keyBytes, valueBytes)
)).records.asScala.head
@ -144,15 +146,15 @@ class TransactionLogTest {
@Test
def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = {
val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500)
val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, false))
val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500, TV_0)
val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0))
assertEquals(0, txnLogValueBuffer.getShort)
}
@Test
def testSerializeTransactionLogValueToFlexibleVersion(): Unit = {
val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500)
val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, true))
val txnTransitMetadata = TxnTransitMetadata(1, 1, 1, 1, 1, 1000, CompleteCommit, Set.empty, 500, 500, TV_2)
val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2))
assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort)
}
@ -194,8 +196,8 @@ class TransactionLogTest {
new Field("topic", Type.COMPACT_STRING, ""),
new Field("partition_ids", new CompactArrayOf(Type.INT32), ""),
TaggedFieldsSection.of(
Int.box(0), new Field("partition_foo", Type.STRING, ""),
Int.box(1), new Field("partition_foo", Type.INT32, "")
Int.box(100), new Field("partition_foo", Type.STRING, ""),
Int.box(101), new Field("partition_foo", Type.INT32, "")
)
)
@ -204,8 +206,8 @@ class TransactionLogTest {
txnPartitions.set("topic", "topic")
txnPartitions.set("partition_ids", Array(Integer.valueOf(1)))
val txnPartitionsTaggedFields = new java.util.TreeMap[Integer, Any]()
txnPartitionsTaggedFields.put(0, "foo")
txnPartitionsTaggedFields.put(1, 4000)
txnPartitionsTaggedFields.put(100, "foo")
txnPartitionsTaggedFields.put(101, 4000)
txnPartitions.set("_tagged_fields", txnPartitionsTaggedFields)
// Copy of TransactionLogValue.SCHEMA_1 with a few
@ -219,8 +221,8 @@ class TransactionLogTest {
new Field("transaction_last_update_timestamp_ms", Type.INT64, ""),
new Field("transaction_start_timestamp_ms", Type.INT64, ""),
TaggedFieldsSection.of(
Int.box(0), new Field("txn_foo", Type.STRING, ""),
Int.box(1), new Field("txn_bar", Type.INT32, "")
Int.box(100), new Field("txn_foo", Type.STRING, ""),
Int.box(101), new Field("txn_bar", Type.INT32, "")
)
)
@ -234,8 +236,8 @@ class TransactionLogTest {
transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L)
transactionLogValue.set("transaction_start_timestamp_ms", 3000L)
val txnLogValueTaggedFields = new java.util.TreeMap[Integer, Any]()
txnLogValueTaggedFields.put(0, "foo")
txnLogValueTaggedFields.put(1, 4000)
txnLogValueTaggedFields.put(100, "foo")
txnLogValueTaggedFields.put(101, 4000)
transactionLogValue.set("_tagged_fields", txnLogValueTaggedFields)
// Prepare the buffer.
@ -249,8 +251,8 @@ class TransactionLogTest {
// fields were read but ignored.
buffer.getShort() // Skip version.
val value = new TransactionLogValue(new ByteBufferAccessor(buffer), 1.toShort)
assertEquals(Seq(0, 1), value.unknownTaggedFields().asScala.map(_.tag))
assertEquals(Seq(0, 1), value.transactionPartitions().get(0).unknownTaggedFields().asScala.map(_.tag))
assertEquals(Seq(100, 101), value.unknownTaggedFields().asScala.map(_.tag))
assertEquals(Seq(100, 101), value.transactionPartitions().get(0).unknownTaggedFields().asScala.map(_.tag))
// Read the buffer with readTxnRecordValue.
buffer.rewind()

View File

@ -28,7 +28,7 @@ import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse}
import org.apache.kafka.common.utils.MockTime
import org.apache.kafka.common.{Node, TopicPartition}
import org.apache.kafka.server.common.MetadataVersion
import org.apache.kafka.server.common.{MetadataVersion, TransactionVersion}
import org.apache.kafka.server.metrics.{KafkaMetricsGroup, KafkaYammerMetrics}
import org.apache.kafka.server.util.RequestAndCompletionHandler
import org.junit.jupiter.api.Assertions._
@ -63,10 +63,10 @@ class TransactionMarkerChannelManagerTest {
private val coordinatorEpoch2 = 1
private val txnTimeoutMs = 0
private val txnResult = TransactionResult.COMMIT
private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, producerEpoch, lastProducerEpoch,
txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L)
private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, producerEpoch, lastProducerEpoch,
txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L)
private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, RecordBatch.NO_PRODUCER_ID,
producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L, TransactionVersion.TV_2)
private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, RecordBatch.NO_PRODUCER_ID,
producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L, TransactionVersion.TV_2)
private val capturedErrorsCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit])
private val time = new MockTime

View File

@ -23,6 +23,7 @@ import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.{ApiKeys, Errors}
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse}
import org.apache.kafka.server.common.TransactionVersion
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test
import org.mockito.ArgumentMatchers
@ -42,8 +43,8 @@ class TransactionMarkerRequestCompletionHandlerTest {
private val coordinatorEpoch = 0
private val txnResult = TransactionResult.COMMIT
private val topicPartition = new TopicPartition("topic1", 0)
private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, lastProducerEpoch,
txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L)
private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID,
producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L, TransactionVersion.TV_2)
private val pendingCompleteTxnAndMarkers = asList(
PendingCompleteTxnAndMarkerEntry(
PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, txnMetadata.prepareComplete(42)),

View File

@ -19,9 +19,13 @@ package kafka.coordinator.transaction
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.record.RecordBatch
import org.apache.kafka.server.common.TransactionVersion
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.apache.kafka.server.util.MockTime
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import scala.collection.mutable
@ -38,13 +42,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None)
txnMetadata.completeTransitionTo(transitMetadata)
@ -60,13 +66,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None)
txnMetadata.completeTransitionTo(transitMetadata)
@ -82,13 +90,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareIncrementProducerEpoch(30000,
@ -101,14 +111,16 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Option(producerEpoch),
@ -127,14 +139,16 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller
val transitMetadata = txnMetadata.prepareProducerIdRotation(producerId + 1, 30000, time.milliseconds() - 1, recordLastEpoch = true)
@ -152,14 +166,16 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller; when transiting from Empty the start time would be updated to the update-time
var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1)
@ -188,17 +204,19 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds() - 1)
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(PrepareCommit, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId)
@ -214,17 +232,19 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
// let new time be smaller
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds() - 1)
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(PrepareAbort, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId)
@ -234,53 +254,65 @@ class TransactionMetadataTest {
assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp)
}
@Test
def testTolerateTimeShiftDuringCompleteCommit(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def testTolerateTimeShiftDuringCompleteCommit(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val producerEpoch: Short = 1
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = PrepareCommit,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = clientTransactionVersion)
// let new time be smaller
val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata)
val lastEpoch = if (clientTransactionVersion.supportsEpochBump()) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH
assertEquals(CompleteCommit, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch)
assertEquals(lastEpoch, txnMetadata.lastProducerEpoch)
assertEquals(producerEpoch, txnMetadata.producerEpoch)
assertEquals(1L, txnMetadata.txnStartTimestamp)
assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp)
}
@Test
def testTolerateTimeShiftDuringCompleteAbort(): Unit = {
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def testTolerateTimeShiftDuringCompleteAbort(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
val producerEpoch: Short = 1
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = PrepareAbort,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = 1L,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = clientTransactionVersion)
// let new time be smaller
val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata)
val lastEpoch = if (clientTransactionVersion.supportsEpochBump()) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH
assertEquals(CompleteAbort, txnMetadata.state)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch)
assertEquals(lastEpoch, txnMetadata.lastProducerEpoch)
assertEquals(producerEpoch, txnMetadata.producerEpoch)
assertEquals(1L, txnMetadata.txnStartTimestamp)
assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp)
@ -293,13 +325,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch()
@ -310,7 +344,7 @@ class TransactionMetadataTest {
// We should reset the pending state to make way for the abort transition.
txnMetadata.pendingState = None
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())
val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds())
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, transitMetadata.producerId)
}
@ -322,13 +356,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch())
}
@ -340,36 +376,108 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val newProducerId = 9893L
val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = true)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(newProducerId, txnMetadata.producerId)
assertEquals(producerId, txnMetadata.lastProducerId)
assertEquals(producerId, txnMetadata.previousProducerId)
assertEquals(0, txnMetadata.producerEpoch)
assertEquals(producerEpoch, txnMetadata.lastProducerEpoch)
}
@Test
def testEpochBumpOnEndTxn(): Unit = {
time.sleep(100)
val producerEpoch = 10.toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch)
assertEquals(TV_2, txnMetadata.clientTransactionVersion)
transitMetadata = txnMetadata.prepareComplete(time.milliseconds())
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch)
assertEquals(TV_2, txnMetadata.clientTransactionVersion)
}
@Test
def testEpochBumpOnEndTxnOverflow(): Unit = {
time.sleep(100)
val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Ongoing,
topicPartitions = mutable.Set.empty,
txnStartTimestamp = time.milliseconds(),
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
assertTrue(txnMetadata.isProducerEpochExhausted)
val newProducerId = 9893L
var transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, TV_2, newProducerId, time.milliseconds() - 1)
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(producerId, txnMetadata.producerId)
assertEquals(Short.MaxValue, txnMetadata.producerEpoch)
assertEquals(TV_2, txnMetadata.clientTransactionVersion)
transitMetadata = txnMetadata.prepareComplete(time.milliseconds())
txnMetadata.completeTransitionTo(transitMetadata)
assertEquals(newProducerId, txnMetadata.producerId)
assertEquals(0, txnMetadata.producerEpoch)
assertEquals(TV_2, txnMetadata.clientTransactionVersion)
}
@Test
def testRotateProducerIdInOngoingState(): Unit = {
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing))
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing, TV_0))
}
@Test
def testRotateProducerIdInPrepareAbortState(): Unit = {
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort))
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def testRotateProducerIdInPrepareAbortState(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort, clientTransactionVersion))
}
@Test
def testRotateProducerIdInPrepareCommitState(): Unit = {
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareCommit))
@ParameterizedTest
@ValueSource(shorts = Array(0, 2))
def testRotateProducerIdInPrepareCommitState(transactionVersion: Short): Unit = {
val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion)
assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareCommit, clientTransactionVersion))
}
@Test
@ -379,13 +487,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch))
txnMetadata.completeTransitionTo(transitMetadata)
@ -401,13 +511,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch))
txnMetadata.completeTransitionTo(transitMetadata)
@ -424,13 +536,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = RecordBatch.NO_PRODUCER_ID,
previousProducerId = RecordBatch.NO_PRODUCER_ID,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(lastProducerEpoch))
txnMetadata.completeTransitionTo(transitMetadata)
@ -447,13 +561,15 @@ class TransactionMetadataTest {
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = producerId,
previousProducerId = producerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = lastProducerEpoch,
txnTimeoutMs = 30000,
state = Empty,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = TV_0)
val result = txnMetadata.prepareIncrementProducerEpoch(30000, Some((lastProducerEpoch - 1).toShort),
time.milliseconds())
@ -503,19 +619,21 @@ class TransactionMetadataTest {
assertEquals(Set.empty, unmatchedStates)
}
private def testRotateProducerIdInOngoingState(state: TransactionState): Unit = {
private def testRotateProducerIdInOngoingState(state: TransactionState, clientTransactionVersion: TransactionVersion): Unit = {
val producerEpoch = (Short.MaxValue - 1).toShort
val txnMetadata = new TransactionMetadata(
transactionalId = transactionalId,
producerId = producerId,
lastProducerId = producerId,
previousProducerId = producerId,
nextProducerId = RecordBatch.NO_PRODUCER_ID,
producerEpoch = producerEpoch,
lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = 30000,
state = state,
topicPartitions = mutable.Set.empty,
txnLastUpdateTimestamp = time.milliseconds())
txnLastUpdateTimestamp = time.milliseconds(),
clientTransactionVersion = clientTransactionVersion)
val newProducerId = 9893L
txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = false)
}

View File

@ -35,6 +35,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
import org.apache.kafka.common.requests.TransactionResult
import org.apache.kafka.common.utils.MockTime
import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion}
import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
import org.apache.kafka.coordinator.transaction.generated.TransactionLogKey
import org.apache.kafka.server.util.MockScheduler
import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchDataInfo, FetchIsolation, LogConfig, LogOffsetMetadata}
@ -181,7 +182,7 @@ class TransactionStateManagerTest {
new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
val records = MemoryRecords.withRecords(startOffset, Compression.NONE,
new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true)))
new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)))
// We create a latch which is awaited while the log is loading. This ensures that the deletion
// is triggered before the loading returns
@ -225,19 +226,19 @@ class TransactionStateManagerTest {
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid1's transaction adds three more partitions
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0),
new TopicPartition("topic2", 1),
new TopicPartition("topic2", 2)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid1's transaction is preparing to commit
txnMetadata1.state = PrepareCommit
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
// pid2's transaction started with three partitions
txnMetadata2.state = Ongoing
@ -245,23 +246,23 @@ class TransactionStateManagerTest {
new TopicPartition("topic3", 1),
new TopicPartition("topic3", 2)))
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's transaction is preparing to abort
txnMetadata2.state = PrepareAbort
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's transaction has aborted
txnMetadata2.state = CompleteAbort
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
// pid2's epoch has advanced, with no ongoing transaction yet
txnMetadata2.state = Empty
txnMetadata2.topicPartitions.clear()
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2))
val startOffset = 15L // it should work for any start offset
val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*)
@ -796,14 +797,9 @@ class TransactionStateManagerTest {
// write the change. If the write fails (e.g. under min isr), the TransactionMetadata
// is left at it is. If the transactional id is never reused, the TransactionMetadata
// will be expired and it should succeed.
val txnMetadata = TransactionMetadata(
transactionalId = transactionalId,
producerId = 1,
producerEpoch = RecordBatch.NO_PRODUCER_EPOCH,
txnTimeoutMs = transactionTimeoutMs,
state = Empty,
timestamp = time.milliseconds()
)
val timestamp = time.milliseconds()
val txnMetadata = new TransactionMetadata(transactionalId, 1, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH,
RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0)
transactionManager.putTransactionStateIfNotExists(txnMetadata)
time.sleep(txnConfig.transactionalIdExpirationMs + 1)
@ -890,7 +886,7 @@ class TransactionStateManagerTest {
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
val startOffset = 0L
val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*)
@ -1053,7 +1049,7 @@ class TransactionStateManagerTest {
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
val startOffset = 0L
val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*)
@ -1081,7 +1077,9 @@ class TransactionStateManagerTest {
producerId: Long,
state: TransactionState = Empty,
txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = {
TransactionMetadata(transactionalId, producerId, 0.toShort, txnTimeout, state, time.milliseconds())
val timestamp = time.milliseconds()
new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, 0.toShort,
RecordBatch.NO_PRODUCER_EPOCH, txnTimeout, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0)
}
private def prepareTxnLog(topicPartition: TopicPartition,
@ -1159,7 +1157,7 @@ class TransactionStateManagerTest {
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 1),
new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
val startOffset = 15L
val records = MemoryRecords.withRecords(startOffset, Compression.NONE, txnRecords.toArray: _*)
@ -1178,7 +1176,7 @@ class TransactionStateManagerTest {
txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
new TopicPartition("topic1", 1)))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), true))
txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2))
val startOffset = 0L
val unknownKey = new TransactionLogKey()
@ -1199,7 +1197,7 @@ class TransactionStateManagerTest {
val txnMetadata = txnMetadataPool.get(transactionalId1)
assertEquals(txnMetadata1.transactionalId, txnMetadata.transactionalId)
assertEquals(txnMetadata1.producerId, txnMetadata.producerId)
assertEquals(txnMetadata1.lastProducerId, txnMetadata.lastProducerId)
assertEquals(txnMetadata1.previousProducerId, txnMetadata.previousProducerId)
assertEquals(txnMetadata1.producerEpoch, txnMetadata.producerEpoch)
assertEquals(txnMetadata1.lastProducerEpoch, txnMetadata.lastProducerEpoch)
assertEquals(txnMetadata1.txnTimeoutMs, txnMetadata.txnTimeoutMs)
@ -1210,7 +1208,7 @@ class TransactionStateManagerTest {
@ParameterizedTest
@EnumSource(classOf[TransactionVersion])
def testUsesFlexibleRecords(transactionVersion: TransactionVersion): Unit = {
def testTransactionVersionInTransactionManager(transactionVersion: TransactionVersion): Unit = {
val metadataCache = mock(classOf[MetadataCache])
when(metadataCache.features()).thenReturn {
new FinalizedFeatures(
@ -1223,7 +1221,6 @@ class TransactionStateManagerTest {
val transactionManager = new TransactionStateManager(0, scheduler,
replicaManager, metadataCache, txnConfig, time, metrics)
val expectFlexibleRecords = transactionVersion.featureLevel > 0
assertEquals(expectFlexibleRecords, transactionManager.usesFlexibleRecords())
assertEquals(transactionVersion, transactionManager.transactionVersionLevel())
}
}

View File

@ -85,7 +85,7 @@ import org.apache.kafka.security.authorizer.AclEntry
import org.apache.kafka.server.ClientMetricsManager
import org.apache.kafka.server.authorizer.{Action, AuthorizationResult, Authorizer}
import org.apache.kafka.server.common.MetadataVersion.{IBP_0_10_2_IV0, IBP_2_2_IV1}
import org.apache.kafka.server.common.{FeatureVersion, FinalizedFeatures, GroupVersion, KRaftVersion, MetadataVersion, RequestLocal}
import org.apache.kafka.server.common.{FeatureVersion, FinalizedFeatures, GroupVersion, KRaftVersion, MetadataVersion, RequestLocal, TransactionVersion}
import org.apache.kafka.server.config.{ConfigType, KRaftConfigs, ReplicationConfigs, ServerConfigs, ServerLogConfigs, ShareGroupConfig}
import org.apache.kafka.server.metrics.ClientMetricsTestUtils
import org.apache.kafka.server.share.{CachedSharePartition, ErroneousAndValidPartitionData}
@ -2572,7 +2572,7 @@ class KafkaApisTest extends Logging {
reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
val capturedResponse: ArgumentCaptor[EndTxnResponse] = ArgumentCaptor.forClass(classOf[EndTxnResponse])
val responseCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit])
val responseCallback: ArgumentCaptor[(Errors, Long, Short) => Unit] = ArgumentCaptor.forClass(classOf[(Errors, Long, Short) => Unit])
val transactionalId = "txnId"
val producerId = 15L
@ -2587,15 +2587,18 @@ class KafkaApisTest extends Logging {
).build(version.toShort)
val request = buildRequest(endTxnRequest)
val clientTransactionVersion = if (version > 4) TransactionVersion.TV_2 else TransactionVersion.TV_0
val requestLocal = RequestLocal.withThreadConfinedCaching
when(txnCoordinator.handleEndTransaction(
ArgumentMatchers.eq(transactionalId),
ArgumentMatchers.eq(producerId),
ArgumentMatchers.eq(epoch),
ArgumentMatchers.eq(TransactionResult.COMMIT),
ArgumentMatchers.eq(clientTransactionVersion),
responseCallback.capture(),
ArgumentMatchers.eq(requestLocal)
)).thenAnswer(_ => responseCallback.getValue.apply(Errors.PRODUCER_FENCED))
)).thenAnswer(_ => responseCallback.getValue.apply(Errors.PRODUCER_FENCED, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH))
val kafkaApis = createKafkaApis()
try {
kafkaApis.handleEndTxnRequest(request, requestLocal)

View File

@ -49,6 +49,10 @@ public enum TransactionVersion implements FeatureVersion {
return featureLevel;
}
public static TransactionVersion fromFeatureLevel(short version) {
return (TransactionVersion) Features.TRANSACTION_VERSION.fromFeatureLevel(version, true);
}
@Override
public String featureName() {
return FEATURE_NAME;
@ -63,4 +67,14 @@ public enum TransactionVersion implements FeatureVersion {
public Map<String, Short> dependencies() {
return dependencies;
}
// Transactions V1 enables log version 0 (flexible fields)
public short transactionLogValueVersion() {
return (short) (featureLevel >= 1 ? 1 : 0);
}
// Transactions V2 enables epoch bump on commit/abort.
public boolean supportsEpochBump() {
return featureLevel >= 2;
}
}

View File

@ -24,6 +24,10 @@
"fields": [
{ "name": "ProducerId", "type": "int64", "versions": "0+",
"about": "Producer id in use by the transactional id"},
{ "name": "PreviousProducerId", "type": "int64", "taggedVersions": "1+", "tag": 0, "default": -1,
"about": "Producer id used by the last committed transaction"},
{ "name": "NextProducerId", "type": "int64", "taggedVersions": "1+", "tag": 1, "default": -1,
"about": "Latest producer ID sent to the producer for the given transactional ID"},
{ "name": "ProducerEpoch", "type": "int16", "versions": "0+",
"about": "Epoch associated with the producer id"},
{ "name": "TransactionTimeoutMs", "type": "int32", "versions": "0+",
@ -37,6 +41,8 @@
{ "name": "TransactionLastUpdateTimestampMs", "type": "int64", "versions": "0+",
"about": "Time the transaction was last updated"},
{ "name": "TransactionStartTimestampMs", "type": "int64", "versions": "0+",
"about": "Time the transaction was started"}
"about": "Time the transaction was started"},
{ "name": "ClientTransactionVersion", "type": "int16", "default": 0, "taggedVersions": "1+", "tag": 2,
"about": "The transaction version used by the client"}
]
}