From 990cb5c06c708170b4d0ff8cb9d4827b3a30f31d Mon Sep 17 00:00:00 2001 From: PoAn Yang Date: Sat, 16 Aug 2025 02:10:52 +0800 Subject: [PATCH] KAFKA-18884 Move TransactionMetadata to transaction-coordinator module (#19699) 1. Move TransactionMetadata to transaction-coordinator module. 2. Rewrite TransactionMetadata in Java. 3. The `topicPartitions` field uses `HashSet` instead of `Set`, because it's mutable field. 4. In Scala, when calling `prepare*` methods, they can use current value as default input in `prepareTransitionTo`. However, in Java, it doesn't support function default input value. To avoid a lot of duplicated code or assign value to wrong field, we add a private class `TransitionData`. It can get current `TransactionMetadata` value as default value and `prepare*` methods just need to assign updated value. Reviewers: Justine Olshan , Artem Livshits , Chia-Ping Tsai --- .../transaction/TransactionCoordinator.scala | 129 ++-- .../transaction/TransactionLog.scala | 39 +- .../TransactionMarkerChannelManager.scala | 14 +- ...actionMarkerRequestCompletionHandler.scala | 4 +- .../transaction/TransactionMetadata.scala | 492 ------------- .../transaction/TransactionStateManager.scala | 38 +- .../main/scala/kafka/server/KafkaApis.scala | 6 +- ...ransactionCoordinatorConcurrencyTest.scala | 35 +- .../TransactionCoordinatorTest.scala | 156 +++-- .../transaction/TransactionLogTest.scala | 16 +- .../TransactionMarkerChannelManagerTest.scala | 67 +- ...onMarkerRequestCompletionHandlerTest.scala | 8 +- .../transaction/TransactionMetadataTest.scala | 651 ++++++++--------- .../TransactionStateManagerTest.scala | 98 +-- .../unit/kafka/server/KafkaApisTest.scala | 8 +- gradle/spotbugs-exclude.xml | 6 +- .../transaction/TransactionMetadata.java | 662 ++++++++++++++++++ .../transaction/TxnTransitMetadata.java | 6 +- 18 files changed, 1316 insertions(+), 1119 deletions(-) delete mode 100644 core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala create mode 100644 transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionMetadata.java diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala index e57cf18d0a1..1e348b19b3e 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala @@ -27,14 +27,15 @@ import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult} import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time} -import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionLogConfig, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata} +import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionLogConfig, TransactionMetadata, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata} import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.{RequestLocal, TransactionVersion} import org.apache.kafka.server.util.Scheduler +import java.util import java.util.Properties import java.util.concurrent.atomic.AtomicBoolean -import scala.jdk.CollectionConverters._ +import scala.jdk.OptionConverters._ object TransactionCoordinator { @@ -147,17 +148,18 @@ class TransactionCoordinator(txnConfig: TransactionConfig, val coordinatorEpochAndMetadata = txnManager.getTransactionState(transactionalId).flatMap { case None => try { - val createdMetadata = new TransactionMetadata(transactionalId = transactionalId, - producerId = producerIdManager.generateProducerId(), - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = resolvedTxnTimeoutMs, - state = TransactionState.EMPTY, - topicPartitions = collection.mutable.Set.empty[TopicPartition], - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TransactionVersion.TV_0) + val createdMetadata = new TransactionMetadata(transactionalId, + producerIdManager.generateProducerId(), + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_PRODUCER_EPOCH, + resolvedTxnTimeoutMs, + TransactionState.EMPTY, + util.Set.of(), + -1, + time.milliseconds(), + TransactionVersion.TV_0) txnManager.putTransactionStateIfNotExists(createdMetadata) } catch { case e: Exception => Left(Errors.forException(e)) @@ -171,10 +173,10 @@ class TransactionCoordinator(txnConfig: TransactionConfig, val coordinatorEpoch = existingEpochAndMetadata.coordinatorEpoch val txnMetadata = existingEpochAndMetadata.transactionMetadata - txnMetadata.inLock { + txnMetadata.inLock(() => prepareInitProducerIdTransit(transactionalId, resolvedTxnTimeoutMs, coordinatorEpoch, txnMetadata, expectedProducerIdAndEpoch) - } + ) } result match { @@ -256,17 +258,16 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT | TransactionState.EMPTY => val transitMetadataResult = // If the epoch is exhausted and the expected epoch (if provided) matches it, generate a new producer ID - if (txnMetadata.isProducerEpochExhausted && - expectedProducerIdAndEpoch.forall(_.epoch == txnMetadata.producerEpoch)) { - try { + try { + if (txnMetadata.isProducerEpochExhausted && + expectedProducerIdAndEpoch.forall(_.epoch == txnMetadata.producerEpoch)) Right(txnMetadata.prepareProducerIdRotation(producerIdManager.generateProducerId(), transactionTimeoutMs, time.milliseconds(), expectedProducerIdAndEpoch.isDefined)) - } catch { - case e: Exception => Left(Errors.forException(e)) - } - } else { - txnMetadata.prepareIncrementProducerEpoch(transactionTimeoutMs, expectedProducerIdAndEpoch.map(_.epoch), - time.milliseconds()) + else + Right(txnMetadata.prepareIncrementProducerEpoch(transactionTimeoutMs, expectedProducerIdAndEpoch.map(e => Short.box(e.epoch)).toJava, + time.milliseconds())) + } catch { + case e: Exception => Left(Errors.forException(e)) } transitMetadataResult match { @@ -326,12 +327,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig, transactionState.setErrorCode(Errors.TRANSACTIONAL_ID_NOT_FOUND.code) case Right(Some(coordinatorEpochAndMetadata)) => val txnMetadata = coordinatorEpochAndMetadata.transactionMetadata - txnMetadata.inLock { + txnMetadata.inLock(() => { if (txnMetadata.state == TransactionState.DEAD) { // The transaction state is being expired, so ignore it transactionState.setErrorCode(Errors.TRANSACTIONAL_ID_NOT_FOUND.code) } else { - txnMetadata.topicPartitions.foreach { topicPartition => + txnMetadata.topicPartitions.forEach(topicPartition => { var topicData = transactionState.topics.find(topicPartition.topic) if (topicData == null) { topicData = new DescribeTransactionsResponseData.TopicData() @@ -339,7 +340,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, transactionState.topics.add(topicData) } topicData.partitions.add(topicPartition.partition) - } + }) transactionState .setErrorCode(Errors.NONE.code) @@ -349,7 +350,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, .setTransactionTimeoutMs(txnMetadata.txnTimeoutMs) .setTransactionStartTimeMs(txnMetadata.txnStartTimestamp) } - } + }) } } } @@ -357,13 +358,15 @@ class TransactionCoordinator(txnConfig: TransactionConfig, def handleVerifyPartitionsInTransaction(transactionalId: String, producerId: Long, producerEpoch: Short, - partitions: collection.Set[TopicPartition], + partitions: util.Set[TopicPartition], responseCallback: VerifyPartitionsCallback): Unit = { if (transactionalId == null || transactionalId.isEmpty) { debug(s"Returning ${Errors.INVALID_REQUEST} error code to client for $transactionalId's AddPartitions request for verification") - responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitions.map(_ -> Errors.INVALID_REQUEST).toMap.asJava)) + val errors = new util.HashMap[TopicPartition, Errors]() + partitions.forEach(partition => errors.put(partition, Errors.INVALID_REQUEST)) + responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, errors)) } else { - val result: ApiResult[Map[TopicPartition, Errors]] = + val result: ApiResult[util.Map[TopicPartition, Errors]] = txnManager.getTransactionState(transactionalId).flatMap { case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING) @@ -373,7 +376,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, // Given the txnMetadata is valid, we check if the partitions are in the transaction. // Pending state is not checked since there is a final validation on the append to the log. // Partitions are added to metadata when the add partitions state is persisted, and removed when the end marker is persisted. - txnMetadata.inLock { + txnMetadata.inLock(() => { if (txnMetadata.producerId != producerId) { Left(Errors.INVALID_PRODUCER_ID_MAPPING) } else if (txnMetadata.producerEpoch != producerEpoch) { @@ -381,23 +384,27 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } else if (txnMetadata.state == TransactionState.PREPARE_COMMIT || txnMetadata.state == TransactionState.PREPARE_ABORT) { Left(Errors.CONCURRENT_TRANSACTIONS) } else { - Right(partitions.map { part => + val errors = new util.HashMap[TopicPartition, Errors]() + partitions.forEach(part => { if (txnMetadata.topicPartitions.contains(part)) - (part, Errors.NONE) + errors.put(part, Errors.NONE) else - (part, Errors.TRANSACTION_ABORTABLE) - }.toMap) + errors.put(part, Errors.TRANSACTION_ABORTABLE) + }) + Right(errors) } - } + }) } result match { case Left(err) => debug(s"Returning $err error code to client for $transactionalId's AddPartitions request for verification") - responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitions.map(_ -> err).toMap.asJava)) + val errors = new util.HashMap[TopicPartition, Errors]() + partitions.forEach(partition => errors.put(partition, err)) + responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, errors)) case Right(errors) => - responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, errors.asJava)) + responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, errors)) } } @@ -406,7 +413,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, def handleAddPartitionsToTransaction(transactionalId: String, producerId: Long, producerEpoch: Short, - partitions: collection.Set[TopicPartition], + partitions: util.Set[TopicPartition], responseCallback: AddPartitionsCallback, clientTransactionVersion: TransactionVersion, requestLocal: RequestLocal = RequestLocal.noCaching): Unit = { @@ -424,7 +431,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, val txnMetadata = epochAndMetadata.transactionMetadata // generate the new transaction metadata with added partitions - txnMetadata.inLock { + txnMetadata.inLock(() => { if (txnMetadata.pendingTransitionInProgress) { // return a retriable exception to let the client backoff and retry // This check is performed first so that the pending transition can complete before subsequent checks. @@ -437,13 +444,13 @@ class TransactionCoordinator(txnConfig: TransactionConfig, Left(Errors.PRODUCER_FENCED) } else if (txnMetadata.state == TransactionState.PREPARE_COMMIT || txnMetadata.state == TransactionState.PREPARE_ABORT) { Left(Errors.CONCURRENT_TRANSACTIONS) - } else if (txnMetadata.state == TransactionState.ONGOING && partitions.subsetOf(txnMetadata.topicPartitions)) { + } else if (txnMetadata.state == TransactionState.ONGOING && txnMetadata.topicPartitions.containsAll(partitions)) { // this is an optimization: if the partitions are already in the metadata reply OK immediately Left(Errors.NONE) } else { - Right(coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds(), clientTransactionVersion)) + Right(coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions, time.milliseconds(), clientTransactionVersion)) } - } + }) } result match { @@ -549,7 +556,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, val txnMetadata = epochAndTxnMetadata.transactionMetadata val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch - txnMetadata.inLock { + txnMetadata.inLock(() => { if (txnMetadata.producerId != producerId) Left(Errors.INVALID_PRODUCER_ID_MAPPING) // Strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch. @@ -564,13 +571,13 @@ class TransactionCoordinator(txnConfig: TransactionConfig, else TransactionState.PREPARE_ABORT - if (nextState == TransactionState.PREPARE_ABORT && txnMetadata.pendingState.contains(TransactionState.PREPARE_EPOCH_FENCE)) { + if (nextState == TransactionState.PREPARE_ABORT && txnMetadata.pendingState.filter(s => s == TransactionState.PREPARE_EPOCH_FENCE).isPresent) { // We should clear the pending state to make way for the transition to PrepareAbort and also bump // the epoch in the transaction metadata we are about to append. isEpochFence = true - txnMetadata.pendingState = None - txnMetadata.producerEpoch = producerEpoch - txnMetadata.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH + txnMetadata.pendingState(util.Optional.empty()) + txnMetadata.setProducerEpoch(producerEpoch) + txnMetadata.setLastProducerEpoch(RecordBatch.NO_PRODUCER_EPOCH) } Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, TransactionVersion.fromFeatureLevel(0), RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false)) @@ -602,7 +609,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, fatal(errorMsg) throw new IllegalStateException(errorMsg) } - } + }) } preAppendResult match { @@ -623,7 +630,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case Some(epochAndMetadata) => if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) { val txnMetadata = epochAndMetadata.transactionMetadata - txnMetadata.inLock { + txnMetadata.inLock(() => { if (txnMetadata.producerId != producerId) Left(Errors.INVALID_PRODUCER_ID_MAPPING) else if (txnMetadata.producerEpoch != producerEpoch) @@ -649,7 +656,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, fatal(errorMsg) throw new IllegalStateException(errorMsg) } - } + }) } else { debug(s"The transaction coordinator epoch has changed to ${epochAndMetadata.coordinatorEpoch} after $txnMarkerResult was " + s"successfully appended to the log for $transactionalId with old epoch $coordinatorEpoch") @@ -682,7 +689,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case Some(epochAndMetadata) => if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) { // This was attempted epoch fence that failed, so mark this state on the metadata - epochAndMetadata.transactionMetadata.hasFailedEpochFence = true + epochAndMetadata.transactionMetadata.hasFailedEpochFence(true) warn(s"The coordinator failed to write an epoch fence transition for producer $transactionalId to the transaction log " + s"with error $error. The epoch was increased to ${newMetadata.producerEpoch} but not returned to the client") } @@ -771,12 +778,12 @@ class TransactionCoordinator(txnConfig: TransactionConfig, val txnMetadata = epochAndTxnMetadata.transactionMetadata val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch - txnMetadata.inLock { + txnMetadata.inLock(() => { producerIdCopy = txnMetadata.producerId producerEpochCopy = txnMetadata.producerEpoch // PrepareEpochFence has slightly different epoch bumping logic so don't include it here. // Note that, it can only happen when the current state is Ongoing. - isEpochFence = txnMetadata.pendingState.contains(TransactionState.PREPARE_EPOCH_FENCE) + isEpochFence = txnMetadata.pendingState.filter(s => s == TransactionState.PREPARE_EPOCH_FENCE).isPresent // True if the client retried a request that had overflowed the epoch, and a new producer ID is stored in the txnMetadata val retryOnOverflow = !isEpochFence && txnMetadata.prevProducerId == producerId && producerEpoch == Short.MaxValue - 1 && txnMetadata.producerEpoch == 0 @@ -820,7 +827,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, if (nextState == TransactionState.PREPARE_ABORT && isEpochFence) { // We should clear the pending state to make way for the transition to PrepareAbort - txnMetadata.pendingState = None + txnMetadata.pendingState(util.Optional.empty()) // For TV2+, don't manually set the epoch - let prepareAbortOrCommit handle it naturally. } @@ -893,7 +900,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, throw new IllegalStateException(errorMsg) } - } + }) } preAppendResult match { @@ -918,7 +925,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case Some(epochAndMetadata) => if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) { val txnMetadata = epochAndMetadata.transactionMetadata - txnMetadata.inLock { + txnMetadata.inLock(() => { if (txnMetadata.producerId != producerId) Left(Errors.INVALID_PRODUCER_ID_MAPPING) else if (txnMetadata.producerEpoch != producerEpoch && txnMetadata.producerEpoch != producerEpoch + 1) @@ -945,7 +952,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, throw new IllegalStateException(errorMsg) } - } + }) } else { debug(s"The transaction coordinator epoch has changed to ${epochAndMetadata.coordinatorEpoch} after $txnMarkerResult was " + s"successfully appended to the log for $transactionalId with old epoch $coordinatorEpoch") @@ -1026,7 +1033,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, case Some(epochAndTxnMetadata) => val txnMetadata = epochAndTxnMetadata.transactionMetadata - val transitMetadataOpt = txnMetadata.inLock { + val transitMetadataOpt = txnMetadata.inLock(() => { if (txnMetadata.producerId != txnIdAndPidEpoch.producerId) { error(s"Found incorrect producerId when expiring transactionalId: ${txnIdAndPidEpoch.transactionalId}. " + s"Expected producerId: ${txnIdAndPidEpoch.producerId}. Found producerId: " + @@ -1039,7 +1046,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } else { Some(txnMetadata.prepareFenceProducerEpoch()) } - } + }) transitMetadataOpt.foreach { txnTransitMetadata => endTransaction(txnMetadata.transactionalId, diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala index a206f160d59..75baa98da15 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala @@ -21,11 +21,12 @@ import org.apache.kafka.common.compress.Compression import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.TopicPartition -import org.apache.kafka.coordinator.transaction.{TransactionState, TxnTransitMetadata} +import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState, TxnTransitMetadata} import org.apache.kafka.coordinator.transaction.generated.{CoordinatorRecordType, TransactionLogKey, TransactionLogValue} import org.apache.kafka.server.common.TransactionVersion -import scala.collection.mutable +import java.util + import scala.jdk.CollectionConverters._ /** @@ -115,26 +116,26 @@ object TransactionLog { if (version >= TransactionLogValue.LOWEST_SUPPORTED_VERSION && version <= TransactionLogValue.HIGHEST_SUPPORTED_VERSION) { val value = new TransactionLogValue(new ByteBufferAccessor(buffer), version) val transactionMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = value.producerId, - prevProducerId = value.previousProducerId, - nextProducerId = value.nextProducerId, - producerEpoch = value.producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = value.transactionTimeoutMs, - state = TransactionState.fromId(value.transactionStatus), - topicPartitions = mutable.Set.empty[TopicPartition], - txnStartTimestamp = value.transactionStartTimestampMs, - txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs, - clientTransactionVersion = TransactionVersion.fromFeatureLevel(value.clientTransactionVersion)) + transactionalId, + value.producerId, + value.previousProducerId, + value.nextProducerId, + value.producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + value.transactionTimeoutMs, + TransactionState.fromId(value.transactionStatus), + util.Set.of(), + value.transactionStartTimestampMs, + value.transactionLastUpdateTimestampMs, + TransactionVersion.fromFeatureLevel(value.clientTransactionVersion)) if (!transactionMetadata.state.equals(TransactionState.EMPTY)) - value.transactionPartitions.forEach(partitionsSchema => + value.transactionPartitions.forEach(partitionsSchema => { transactionMetadata.addPartitions(partitionsSchema.partitionIds - .asScala - .map(partitionId => new TopicPartition(partitionsSchema.topic, partitionId)) - .toSet) - ) + .stream + .map(partitionId => new TopicPartition(partitionsSchema.topic, partitionId.intValue())) + .toList) + }) Some(transactionMetadata) } else throw new IllegalStateException(s"Unknown version $version from the transaction log message value") } diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala index 1cc0550b462..6c395feb582 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala @@ -32,7 +32,7 @@ import org.apache.kafka.common.requests.{TransactionResult, WriteTxnMarkersReque import org.apache.kafka.common.security.JaasContext import org.apache.kafka.common.utils.{LogContext, Time} import org.apache.kafka.common.{Node, Reconfigurable, TopicPartition} -import org.apache.kafka.coordinator.transaction.TxnTransitMetadata +import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TxnTransitMetadata} import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.RequestLocal import org.apache.kafka.server.metrics.KafkaMetricsGroup @@ -326,16 +326,16 @@ class TransactionMarkerChannelManager( info(s"Replaced an existing pending complete txn $prev with $pendingCompleteTxn while adding markers to send.") } addTxnMarkersToBrokerQueue(txnMetadata.producerId, - txnMetadata.producerEpoch, txnResult, pendingCompleteTxn, txnMetadata.topicPartitions.toSet) + txnMetadata.producerEpoch, txnResult, pendingCompleteTxn, txnMetadata.topicPartitions.asScala.toSet) maybeWriteTxnCompletion(transactionalId) } def numTxnsWithPendingMarkers: Int = transactionsWithPendingMarkers.size private def hasPendingMarkersToWrite(txnMetadata: TransactionMetadata): Boolean = { - txnMetadata.inLock { - txnMetadata.topicPartitions.nonEmpty - } + txnMetadata.inLock(() => + !txnMetadata.topicPartitions.isEmpty + ) } def maybeWriteTxnCompletion(transactionalId: String): Unit = { @@ -422,9 +422,9 @@ class TransactionMarkerChannelManager( val txnMetadata = epochAndMetadata.transactionMetadata - txnMetadata.inLock { + txnMetadata.inLock(() => topicPartitions.foreach(txnMetadata.removePartition) - } + ) maybeWriteTxnCompletion(transactionalId) } diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala index d95dabab6c3..63990fda985 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala @@ -131,7 +131,7 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int, txnMarkerChannelManager.removeMarkersForTxn(pendingCompleteTxn) abortSending = true } else { - txnMetadata.inLock { + txnMetadata.inLock(() => { for ((topicPartition, error) <- errors.asScala) { error match { case Errors.NONE => @@ -178,7 +178,7 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int, throw new IllegalStateException(s"Unexpected error ${other.exceptionName} while sending txn marker for $transactionalId") } } - } + }) } if (!abortSending) { diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala deleted file mode 100644 index 73403a452e1..00000000000 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala +++ /dev/null @@ -1,492 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package kafka.coordinator.transaction - -import java.util.concurrent.locks.ReentrantLock -import kafka.utils.{CoreUtils, Logging, nonthreadsafe} -import org.apache.kafka.common.TopicPartition -import org.apache.kafka.common.protocol.Errors -import org.apache.kafka.common.record.RecordBatch -import org.apache.kafka.coordinator.transaction.{TransactionState, TxnTransitMetadata} -import org.apache.kafka.server.common.TransactionVersion - -import scala.collection.{immutable, mutable} -import scala.jdk.CollectionConverters._ - -private[transaction] object TransactionMetadata { - def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1 -} - -/** - * - * @param producerId producer id - * @param prevProducerId producer id for the last committed transaction with this transactional ID - * @param nextProducerId Latest producer ID sent to the producer for the given transactional ID - * @param producerEpoch current epoch of the producer - * @param lastProducerEpoch last epoch of the producer - * @param txnTimeoutMs timeout to be used to abort long running transactions - * @param state current state of the transaction - * @param topicPartitions current set of partitions that are part of this transaction - * @param txnStartTimestamp time the transaction was started, i.e., when first partition is added - * @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration - * @param clientTransactionVersion TransactionVersion used by the client when the state was transitioned - */ -@nonthreadsafe -private[transaction] class TransactionMetadata(val transactionalId: String, - var producerId: Long, - var prevProducerId: Long, - var nextProducerId: Long, - var producerEpoch: Short, - var lastProducerEpoch: Short, - var txnTimeoutMs: Int, - var state: TransactionState, - var topicPartitions: mutable.Set[TopicPartition], - @volatile var txnStartTimestamp: Long = -1, - @volatile var txnLastUpdateTimestamp: Long, - var clientTransactionVersion: TransactionVersion) extends Logging { - - // pending state is used to indicate the state that this transaction is going to - // transit to, and for blocking future attempts to transit it again if it is not legal; - // initialized as the same as the current state - var pendingState: Option[TransactionState] = None - - // Indicates that during a previous attempt to fence a producer, the bumped epoch may not have been - // successfully written to the log. If this is true, we will not bump the epoch again when fencing - var hasFailedEpochFence: Boolean = false - - private[transaction] val lock = new ReentrantLock - - def inLock[T](fun: => T): T = CoreUtils.inLock(lock)(fun) - - def addPartitions(partitions: collection.Set[TopicPartition]): Unit = { - topicPartitions ++= partitions - } - - def removePartition(topicPartition: TopicPartition): Unit = { - if (state != TransactionState.PREPARE_COMMIT && state != TransactionState.PREPARE_ABORT) - throw new IllegalStateException(s"Transaction metadata's current state is $state, and its pending state is $pendingState " + - s"while trying to remove partitions whose txn marker has been sent, this is not expected") - - topicPartitions -= topicPartition - } - - // this is visible for test only - def prepareNoTransit(): TxnTransitMetadata = { - // do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state - new TxnTransitMetadata(producerId, prevProducerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.clone().asJava, - txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion) - } - - def prepareFenceProducerEpoch(): TxnTransitMetadata = { - if (producerEpoch == Short.MaxValue) - throw new IllegalStateException(s"Cannot fence producer with epoch equal to Short.MaxValue since this would overflow") - - // If we've already failed to fence an epoch (because the write to the log failed), we don't increase it again. - // This is safe because we never return the epoch to client if we fail to fence the epoch - val bumpedEpoch = if (hasFailedEpochFence) producerEpoch else (producerEpoch + 1).toShort - - prepareTransitionTo( - state = TransactionState.PREPARE_EPOCH_FENCE, - producerEpoch = bumpedEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH - ) - } - - def prepareIncrementProducerEpoch(newTxnTimeoutMs: Int, - expectedProducerEpoch: Option[Short], - updateTimestamp: Long): Either[Errors, TxnTransitMetadata] = { - if (isProducerEpochExhausted) - throw new IllegalStateException(s"Cannot allocate any more producer epochs for producerId $producerId") - - val bumpedEpoch = (producerEpoch + 1).toShort - val epochBumpResult: Either[Errors, (Short, Short)] = expectedProducerEpoch match { - case None => - // If no expected epoch was provided by the producer, bump the current epoch and set the last epoch to -1 - // In the case of a new producer, producerEpoch will be -1 and bumpedEpoch will be 0 - Right(bumpedEpoch, RecordBatch.NO_PRODUCER_EPOCH) - - case Some(expectedEpoch) => - if (producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || expectedEpoch == producerEpoch) - // If the expected epoch matches the current epoch, or if there is no current epoch, the producer is attempting - // to continue after an error and no other producer has been initialized. Bump the current and last epochs. - // The no current epoch case means this is a new producer; producerEpoch will be -1 and bumpedEpoch will be 0 - Right(bumpedEpoch, producerEpoch) - else if (expectedEpoch == lastProducerEpoch) - // If the expected epoch matches the previous epoch, it is a retry of a successful call, so just return the - // current epoch without bumping. There is no danger of this producer being fenced, because a new producer - // calling InitProducerId would have caused the last epoch to be set to -1. - // Note that if the IBP is prior to 2.4.IV1, the lastProducerId and lastProducerEpoch will not be written to - // the transaction log, so a retry that spans a coordinator change will fail. We expect this to be a rare case. - Right(producerEpoch, lastProducerEpoch) - else { - // Otherwise, the producer has a fenced epoch and should receive an PRODUCER_FENCED error - info(s"Expected producer epoch $expectedEpoch does not match current " + - s"producer epoch $producerEpoch or previous producer epoch $lastProducerEpoch") - Left(Errors.PRODUCER_FENCED) - } - } - - epochBumpResult match { - case Right((nextEpoch, lastEpoch)) => Right(prepareTransitionTo( - state = TransactionState.EMPTY, - producerEpoch = nextEpoch, - lastProducerEpoch = lastEpoch, - txnTimeoutMs = newTxnTimeoutMs, - topicPartitions = mutable.Set.empty[TopicPartition], - txnStartTimestamp = -1, - txnLastUpdateTimestamp = updateTimestamp - )) - - case Left(err) => Left(err) - } - } - - def prepareProducerIdRotation(newProducerId: Long, - newTxnTimeoutMs: Int, - updateTimestamp: Long, - recordLastEpoch: Boolean): TxnTransitMetadata = { - if (hasPendingTransaction) - throw new IllegalStateException("Cannot rotate producer ids while a transaction is still pending") - - prepareTransitionTo( - state = TransactionState.EMPTY, - producerId = newProducerId, - producerEpoch = 0, - lastProducerEpoch = if (recordLastEpoch) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = newTxnTimeoutMs, - topicPartitions = mutable.Set.empty[TopicPartition], - txnStartTimestamp = -1, - txnLastUpdateTimestamp = updateTimestamp - ) - } - - def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition], updateTimestamp: Long, clientTransactionVersion: TransactionVersion): TxnTransitMetadata = { - val newTxnStartTimestamp = state match { - case TransactionState.EMPTY | TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT => updateTimestamp - case _ => txnStartTimestamp - } - - prepareTransitionTo( - state = TransactionState.ONGOING, - topicPartitions = (topicPartitions ++ addedTopicPartitions), - txnStartTimestamp = newTxnStartTimestamp, - txnLastUpdateTimestamp = updateTimestamp, - clientTransactionVersion = clientTransactionVersion - ) - } - - def prepareAbortOrCommit(newState: TransactionState, clientTransactionVersion: TransactionVersion, nextProducerId: Long, updateTimestamp: Long, noPartitionAdded: Boolean): TxnTransitMetadata = { - val (updatedProducerEpoch, updatedLastProducerEpoch) = if (clientTransactionVersion.supportsEpochBump()) { - // We already ensured that we do not overflow here. MAX_SHORT is the highest possible value. - ((producerEpoch + 1).toShort, producerEpoch) - } else { - (producerEpoch, lastProducerEpoch) - } - - // With transaction V2, it is allowed to abort the transaction without adding any partitions. Then, the transaction - // start time is uncertain but it is still required. So we can use the update time as the transaction start time. - val newTxnStartTimestamp = if (noPartitionAdded) updateTimestamp else txnStartTimestamp - prepareTransitionTo( - state = newState, - nextProducerId = nextProducerId, - producerEpoch = updatedProducerEpoch, - lastProducerEpoch = updatedLastProducerEpoch, - txnStartTimestamp = newTxnStartTimestamp, - txnLastUpdateTimestamp = updateTimestamp, - clientTransactionVersion = clientTransactionVersion - ) - } - - def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = { - val newState = if (state == TransactionState.PREPARE_COMMIT) TransactionState.COMPLETE_COMMIT else TransactionState.COMPLETE_ABORT - - // Since the state change was successfully written to the log, unset the flag for a failed epoch fence - hasFailedEpochFence = false - val (updatedProducerId, updatedProducerEpoch) = - // In the prepareComplete transition for the overflow case, the lastProducerEpoch is kept at MAX-1, - // which is the last epoch visible to the client. - // Internally, however, during the transition between prepareAbort/prepareCommit and prepareComplete, the producer epoch - // reaches MAX but the client only sees the transition as MAX-1 followed by 0. - // When an epoch overflow occurs, we set the producerId to nextProducerId and reset the epoch to 0, - // but lastProducerEpoch remains MAX-1 to maintain consistency with what the client last saw. - if (clientTransactionVersion.supportsEpochBump() && nextProducerId != RecordBatch.NO_PRODUCER_ID) { - (nextProducerId, 0.toShort) - } else { - (producerId, producerEpoch) - } - - prepareTransitionTo( - state = newState, - producerId = updatedProducerId, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = updatedProducerEpoch, - topicPartitions = mutable.Set.empty[TopicPartition], - txnLastUpdateTimestamp = updateTimestamp - ) - } - - def prepareDead(): TxnTransitMetadata = { - prepareTransitionTo( - state = TransactionState.DEAD, - topicPartitions = mutable.Set.empty[TopicPartition] - ) - } - - /** - * Check if the epochs have been exhausted for the current producerId. We do not allow the client to use an - * epoch equal to Short.MaxValue to ensure that the coordinator will always be able to fence an existing producer. - */ - def isProducerEpochExhausted: Boolean = TransactionMetadata.isEpochExhausted(producerEpoch) - - /** - * Check if this is a distributed two phase commit transaction. - * Such transactions have no timeout (identified by maximum value for timeout). - */ - def isDistributedTwoPhaseCommitTxn: Boolean = txnTimeoutMs == Int.MaxValue - - private def hasPendingTransaction: Boolean = { - state match { - case TransactionState.ONGOING | TransactionState.PREPARE_ABORT | TransactionState.PREPARE_COMMIT => true - case _ => false - } - } - - private def prepareTransitionTo(state: TransactionState, - producerId: Long = this.producerId, - nextProducerId: Long = this.nextProducerId, - producerEpoch: Short = this.producerEpoch, - lastProducerEpoch: Short = this.lastProducerEpoch, - txnTimeoutMs: Int = this.txnTimeoutMs, - topicPartitions: mutable.Set[TopicPartition] = this.topicPartitions, - txnStartTimestamp: Long = this.txnStartTimestamp, - txnLastUpdateTimestamp: Long = this.txnLastUpdateTimestamp, - clientTransactionVersion: TransactionVersion = this.clientTransactionVersion): TxnTransitMetadata = { - if (pendingState.isDefined) - throw new IllegalStateException(s"Preparing transaction state transition to $state " + - s"while it already a pending state ${pendingState.get}") - - if (producerId < 0) - throw new IllegalArgumentException(s"Illegal new producer id $producerId") - - // The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata - // is created for the first time and it could stay like this until transitioning - // to Dead. - if (state != TransactionState.DEAD && producerEpoch < 0) - throw new IllegalArgumentException(s"Illegal new producer epoch $producerEpoch") - - // check that the new state transition is valid and update the pending state if necessary - if (state.validPreviousStates.contains(this.state)) { - val transitMetadata = new TxnTransitMetadata(producerId, this.producerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, - topicPartitions.asJava, txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion) - debug(s"TransactionalId ${this.transactionalId} prepare transition from ${this.state} to $transitMetadata") - pendingState = Some(state) - transitMetadata - } else { - throw new IllegalStateException(s"Preparing transaction state transition to $state failed since the target state" + - s" $state is not a valid previous state of the current state ${this.state}") - } - } - - def completeTransitionTo(transitMetadata: TxnTransitMetadata): Unit = { - // metadata transition is valid only if all the following conditions are met: - // - // 1. the new state is already indicated in the pending state. - // 2. the epoch should be either the same value, the old value + 1, or 0 if we have a new producerId. - // 3. the last update time is no smaller than the old value. - // 4. the old partitions set is a subset of the new partitions set. - // - // plus, we should only try to update the metadata after the corresponding log entry has been successfully - // written and replicated (see TransactionStateManager#appendTransactionToLog) - // - // if valid, transition is done via overwriting the whole object to ensure synchronization - - val toState = pendingState.getOrElse { - fatal(s"$this's transition to $transitMetadata failed since pendingState is not defined: this should not happen") - - throw new IllegalStateException(s"TransactionalId $transactionalId " + - "completing transaction state transition while it does not have a pending state") - } - - if (toState != transitMetadata.txnState) { - throwStateTransitionFailure(transitMetadata) - } else { - toState match { - case TransactionState.EMPTY => // from initPid - if ((producerEpoch != transitMetadata.producerEpoch && !validProducerEpochBump(transitMetadata)) || - !transitMetadata.topicPartitions.isEmpty || - transitMetadata.txnStartTimestamp != -1) { - - throwStateTransitionFailure(transitMetadata) - } - - case TransactionState.ONGOING => // from addPartitions - if (!validProducerEpoch(transitMetadata) || - !topicPartitions.subsetOf(transitMetadata.topicPartitions.asScala) || - txnTimeoutMs != transitMetadata.txnTimeoutMs) { - - throwStateTransitionFailure(transitMetadata) - } - - case TransactionState.PREPARE_ABORT | TransactionState.PREPARE_COMMIT => // from endTxn - // In V2, we allow state transits from Empty, CompleteCommit and CompleteAbort to PrepareAbort. It is possible - // their updated start time is not equal to the current start time. - val allowedEmptyAbort = toState == TransactionState.PREPARE_ABORT && transitMetadata.clientTransactionVersion.supportsEpochBump() && - (state == TransactionState.EMPTY || state == TransactionState.COMPLETE_COMMIT || state == TransactionState.COMPLETE_ABORT) - val validTimestamp = txnStartTimestamp == transitMetadata.txnStartTimestamp || allowedEmptyAbort - if (!validProducerEpoch(transitMetadata) || - !topicPartitions.equals(transitMetadata.topicPartitions.asScala) || - txnTimeoutMs != transitMetadata.txnTimeoutMs || !validTimestamp) { - - throwStateTransitionFailure(transitMetadata) - } - - case TransactionState.COMPLETE_ABORT | TransactionState.COMPLETE_COMMIT => // from write markers - if (!validProducerEpoch(transitMetadata) || - txnTimeoutMs != transitMetadata.txnTimeoutMs || - transitMetadata.txnStartTimestamp == -1) { - throwStateTransitionFailure(transitMetadata) - } - - case TransactionState.PREPARE_EPOCH_FENCE => - // We should never get here, since once we prepare to fence the epoch, we immediately set the pending state - // to PrepareAbort, and then consequently to CompleteAbort after the markers are written.. So we should never - // ever try to complete a transition to PrepareEpochFence, as it is not a valid previous state for any other state, and hence - // can never be transitioned out of. - throwStateTransitionFailure(transitMetadata) - - - case TransactionState.DEAD => - // The transactionalId was being expired. The completion of the operation should result in removal of the - // the metadata from the cache, so we should never realistically transition to the dead state. - throw new IllegalStateException(s"TransactionalId $transactionalId is trying to complete a transition to " + - s"$toState. This means that the transactionalId was being expired, and the only acceptable completion of " + - s"this operation is to remove the transaction metadata from the cache, not to persist the $toState in the log.") - } - - debug(s"TransactionalId $transactionalId complete transition from $state to $transitMetadata") - producerId = transitMetadata.producerId - prevProducerId = transitMetadata.prevProducerId - nextProducerId = transitMetadata.nextProducerId - producerEpoch = transitMetadata.producerEpoch - lastProducerEpoch = transitMetadata.lastProducerEpoch - txnTimeoutMs = transitMetadata.txnTimeoutMs - topicPartitions = transitMetadata.topicPartitions.asScala - txnStartTimestamp = transitMetadata.txnStartTimestamp - txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp - clientTransactionVersion = transitMetadata.clientTransactionVersion - - pendingState = None - state = toState - } - } - - /** - * Validates the producer epoch and ID based on transaction state and version. - * - * Logic: - * * 1. **Overflow Case in Transactions V2:** - * * - During overflow (epoch reset to 0), we compare both `lastProducerEpoch` values since it - * * does not change during completion. - * * - For PrepareComplete, the producer ID has been updated. We ensure that the `prevProducerID` - * * in the transit metadata matches the current producer ID, confirming the change. - * * - * * 2. **Epoch Bump Case in Transactions V2:** - * * - For PrepareCommit or PrepareAbort, the producer epoch has been bumped. We ensure the `lastProducerEpoch` - * * in transit metadata matches the current producer epoch, confirming the bump. - * * - We also verify that the producer ID remains the same. - * * - * * 3. **Other Cases:** - * * - For other states and versions, check if the producer epoch and ID match the current values. - * - * @param transitMetadata The transaction transition metadata containing state, producer epoch, and ID. - * @return true if the producer epoch and ID are valid; false otherwise. - */ - private def validProducerEpoch(transitMetadata: TxnTransitMetadata): Boolean = { - val isAtLeastTransactionsV2 = transitMetadata.clientTransactionVersion.supportsEpochBump() - val txnState = transitMetadata.txnState - val transitProducerEpoch = transitMetadata.producerEpoch - val transitProducerId = transitMetadata.producerId - val transitLastProducerEpoch = transitMetadata.lastProducerEpoch - - (isAtLeastTransactionsV2, txnState, transitProducerEpoch) match { - case (true, TransactionState.COMPLETE_COMMIT | TransactionState.COMPLETE_ABORT, epoch) if epoch == 0.toShort => - transitLastProducerEpoch == lastProducerEpoch && - transitMetadata.prevProducerId == producerId - - case (true, TransactionState.PREPARE_COMMIT | TransactionState.PREPARE_ABORT, _) => - transitLastProducerEpoch == producerEpoch && - transitProducerId == producerId - - case _ => - transitProducerEpoch == producerEpoch && - transitProducerId == producerId - } - } - - private def validProducerEpochBump(transitMetadata: TxnTransitMetadata): Boolean = { - val transitEpoch = transitMetadata.producerEpoch - val transitProducerId = transitMetadata.producerId - transitEpoch == producerEpoch + 1 || (transitEpoch == 0 && transitProducerId != producerId) - } - - private def throwStateTransitionFailure(txnTransitMetadata: TxnTransitMetadata): Unit = { - fatal(s"${this.toString}'s transition to $txnTransitMetadata failed: this should not happen") - - throw new IllegalStateException(s"TransactionalId $transactionalId failed transition to state $txnTransitMetadata " + - "due to unexpected metadata") - } - - def pendingTransitionInProgress: Boolean = pendingState.isDefined - - override def toString: String = { - "TransactionMetadata(" + - s"transactionalId=$transactionalId, " + - s"producerId=$producerId, " + - s"prevProducerId=$prevProducerId, " + - s"nextProducerId=$nextProducerId, " + - s"producerEpoch=$producerEpoch, " + - s"lastProducerEpoch=$lastProducerEpoch, " + - s"txnTimeoutMs=$txnTimeoutMs, " + - s"state=$state, " + - s"pendingState=$pendingState, " + - s"topicPartitions=$topicPartitions, " + - s"txnStartTimestamp=$txnStartTimestamp, " + - s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp, " + - s"clientTransactionVersion=$clientTransactionVersion)" - } - - override def equals(that: Any): Boolean = that match { - case other: TransactionMetadata => - transactionalId == other.transactionalId && - producerId == other.producerId && - producerEpoch == other.producerEpoch && - lastProducerEpoch == other.lastProducerEpoch && - txnTimeoutMs == other.txnTimeoutMs && - state.equals(other.state) && - topicPartitions.equals(other.topicPartitions) && - txnStartTimestamp == other.txnStartTimestamp && - txnLastUpdateTimestamp == other.txnLastUpdateTimestamp && - clientTransactionVersion == other.clientTransactionVersion - case _ => false - } - - override def hashCode(): Int = { - val fields = Seq(transactionalId, producerId, producerEpoch, txnTimeoutMs, state, topicPartitions, - txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion) - fields.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) - } -} diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala index 153c331b3f9..82b960c5ba7 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala @@ -35,7 +35,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests.TransactionResult import org.apache.kafka.common.utils.{Time, Utils} import org.apache.kafka.common.{KafkaException, TopicIdPartition, TopicPartition} -import org.apache.kafka.coordinator.transaction.{TransactionLogConfig, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata} +import org.apache.kafka.coordinator.transaction.{TransactionLogConfig, TransactionMetadata, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata} import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.{RequestLocal, TransactionVersion} import org.apache.kafka.server.config.ServerConfigs @@ -46,6 +46,8 @@ import org.apache.kafka.storage.internals.log.AppendOrigin import com.google.re2j.{Pattern, PatternSyntaxException} import org.apache.kafka.common.errors.InvalidRegularExpression +import java.util.Optional + import scala.jdk.CollectionConverters._ import scala.collection.mutable @@ -176,7 +178,7 @@ class TransactionStateManager(brokerId: Int, val transactionalId = txnMetadata.transactionalId var fullBatch = false - txnMetadata.inLock { + txnMetadata.inLock(() => { if (txnMetadata.pendingState.isEmpty && shouldExpire(txnMetadata, currentTimeMs)) { if (recordsBuilder == null) { recordsBuilder = MemoryRecords.builder( @@ -199,7 +201,7 @@ class TransactionStateManager(brokerId: Int, fullBatch = true } } - } + }) if (fullBatch) { flushRecordsBuilder() @@ -263,9 +265,9 @@ class TransactionStateManager(brokerId: Int, expiredForPartition.foreach { idCoordinatorEpochAndMetadata => val transactionalId = idCoordinatorEpochAndMetadata.transactionalId val txnMetadata = txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId) - txnMetadata.inLock { + txnMetadata.inLock(() => { if (txnMetadataCacheEntry.coordinatorEpoch == idCoordinatorEpochAndMetadata.coordinatorEpoch - && txnMetadata.pendingState.contains(TransactionState.DEAD) + && txnMetadata.pendingState.filter(s => s == TransactionState.DEAD).isPresent && txnMetadata.producerEpoch == idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch && response.error == Errors.NONE) { txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId) @@ -276,9 +278,9 @@ class TransactionStateManager(brokerId: Int, s" expected producerEpoch: ${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," + s" coordinatorEpoch: ${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " + s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}") - txnMetadata.pendingState = None + txnMetadata.pendingState(Optional.empty()) } - } + }) } } } @@ -366,7 +368,7 @@ class TransactionStateManager(brokerId: Int, } else null transactionMetadataCache.foreachEntry { (_, cache) => cache.metadataPerTransactionalId.forEach { (_, txnMetadata) => - txnMetadata.inLock { + txnMetadata.inLock(() => { if (shouldInclude(txnMetadata, pattern)) { states.add(new ListTransactionsResponseData.TransactionState() .setTransactionalId(txnMetadata.transactionalId) @@ -374,7 +376,7 @@ class TransactionStateManager(brokerId: Int, .setTransactionState(txnMetadata.state.stateName) ) } - } + }) } } response.setErrorCode(Errors.NONE.code) @@ -565,7 +567,7 @@ class TransactionStateManager(brokerId: Int, val transactionsPendingForCompletion = new mutable.ListBuffer[TransactionalIdCoordinatorEpochAndTransitMetadata] loadedTransactions.forEach((transactionalId, txnMetadata) => { - txnMetadata.inLock { + txnMetadata.inLock(() => { // if state is PrepareCommit or PrepareAbort we need to complete the transaction txnMetadata.state match { case TransactionState.PREPARE_ABORT => @@ -577,7 +579,7 @@ class TransactionStateManager(brokerId: Int, case _ => // nothing needs to be done } - } + }) }) // we first remove the partition from loading partition then send out the markers for those pending to be @@ -713,7 +715,7 @@ class TransactionStateManager(brokerId: Int, case Right(Some(epochAndMetadata)) => val metadata = epochAndMetadata.transactionMetadata - metadata.inLock { + metadata.inLock(() => { if (epochAndMetadata.coordinatorEpoch != coordinatorEpoch) { // the cache may have been changed due to txn topic partition emigration and immigration, // in this case directly return NOT_COORDINATOR to client and let it to re-discover the transaction coordinator @@ -725,7 +727,7 @@ class TransactionStateManager(brokerId: Int, metadata.completeTransitionTo(newMetadata) debug(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch for $transactionalId succeeded") } - } + }) case Right(None) => // this transactional id no longer exists, maybe the corresponding partition has already been migrated out. @@ -740,7 +742,7 @@ class TransactionStateManager(brokerId: Int, getTransactionState(transactionalId) match { case Right(Some(epochAndTxnMetadata)) => val metadata = epochAndTxnMetadata.transactionMetadata - metadata.inLock { + metadata.inLock(() => { if (epochAndTxnMetadata.coordinatorEpoch == coordinatorEpoch) { if (retryOnError(responseError)) { info(s"TransactionalId ${metadata.transactionalId} append transaction log for $newMetadata transition failed due to $responseError, " + @@ -749,13 +751,13 @@ class TransactionStateManager(brokerId: Int, info(s"TransactionalId ${metadata.transactionalId} append transaction log for $newMetadata transition failed due to $responseError, " + s"resetting pending state from ${metadata.pendingState}, aborting state transition and returning $responseError in the callback") - metadata.pendingState = None + metadata.pendingState(Optional.empty()) } } else { info(s"TransactionalId ${metadata.transactionalId} append transaction log for $newMetadata transition failed due to $responseError, " + s"aborting state transition and returning the error in the callback since the coordinator epoch has changed from ${epochAndTxnMetadata.coordinatorEpoch} to $coordinatorEpoch") } - } + }) case Right(None) => // Do nothing here, since we want to return the original append error to the user. @@ -790,7 +792,7 @@ class TransactionStateManager(brokerId: Int, case Right(Some(epochAndMetadata)) => val metadata = epochAndMetadata.transactionMetadata - val append: Boolean = metadata.inLock { + val append: Boolean = metadata.inLock(() => { if (epochAndMetadata.coordinatorEpoch != coordinatorEpoch) { // the coordinator epoch has changed, reply to client immediately with NOT_COORDINATOR responseCallback(Errors.NOT_COORDINATOR) @@ -800,7 +802,7 @@ class TransactionStateManager(brokerId: Int, // under the same coordinator epoch, so directly append to txn log now true } - } + }) if (append) { replicaManager.appendRecords( timeout = newMetadata.txnTimeoutMs.toLong, diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index 45bd8b7e7fb..2862d090362 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -1970,7 +1970,7 @@ class KafkaApis(val requestChannel: RequestChannel, } else { val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() - val authorizedPartitions = mutable.Set[TopicPartition]() + val authorizedPartitions = new util.HashSet[TopicPartition]() // Only request versions less than 4 need write authorization since they come from clients. val authorizedTopics = @@ -1992,7 +1992,7 @@ class KafkaApis(val requestChannel: RequestChannel, // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded // the authorization check to indicate that they were not added to the transaction. val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ - authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) + authorizedPartitions.asScala.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) addResultAndMaybeSendResponse(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitionErrors.asJava)) } else { def sendResponseCallback(error: Errors): Unit = { @@ -2071,7 +2071,7 @@ class KafkaApis(val requestChannel: RequestChannel, txnCoordinator.handleAddPartitionsToTransaction(transactionalId, addOffsetsToTxnRequest.data.producerId, addOffsetsToTxnRequest.data.producerEpoch, - Set(offsetTopicPartition), + util.Set.of(offsetTopicPartition), sendResponseCallback, TransactionVersion.TV_0, // This request will always come from the client not using TV 2. requestLocal) diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala index 030295975a4..79957b01fb7 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala @@ -37,7 +37,7 @@ import org.apache.kafka.common.record.{FileRecords, MemoryRecords, RecordBatch, import org.apache.kafka.common.requests._ import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch} import org.apache.kafka.common.{Node, TopicPartition, Uuid} -import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionState} +import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionMetadata, TransactionState} import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion} import org.apache.kafka.server.storage.log.FetchIsolation @@ -63,7 +63,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren private val allOperations = Seq( new InitProducerIdOperation, - new AddPartitionsToTxnOperation(Set(new TopicPartition("topic", 0))), + new AddPartitionsToTxnOperation(util.Set.of(new TopicPartition("topic", 0))), new EndTxnOperation) private val allTransactions = mutable.Set[Transaction]() @@ -459,7 +459,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren val partitionId = txnStateManager.partitionFor(txn.transactionalId) val txnRecords = txnRecordsByPartition(partitionId) val initPidOp = new InitProducerIdOperation() - val addPartitionsOp = new AddPartitionsToTxnOperation(Set(new TopicPartition("topic", 0))) + val addPartitionsOp = new AddPartitionsToTxnOperation(util.Set.of(new TopicPartition("topic", 0))) initPidOp.run(txn) initPidOp.awaitAndVerify(txn) addPartitionsOp.run(txn) @@ -468,7 +468,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn")) txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2)) - txnMetadata.state = TransactionState.PREPARE_COMMIT + txnMetadata.state(TransactionState.PREPARE_COMMIT) txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TransactionVersion.TV_2)) prepareTxnLog(partitionId) @@ -506,17 +506,18 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren } private def prepareExhaustedEpochTxnMetadata(txn: Transaction): TransactionMetadata = { - new TransactionMetadata(transactionalId = txn.transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = (Short.MaxValue - 1).toShort, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 60000, - state = TransactionState.EMPTY, - topicPartitions = collection.mutable.Set.empty[TopicPartition], - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TransactionVersion.TV_0) + new TransactionMetadata(txn.transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + (Short.MaxValue - 1).toShort, + RecordBatch.NO_PRODUCER_EPOCH, + 60000, + TransactionState.EMPTY, + new util.HashSet[TopicPartition](), + -1, + time.milliseconds(), + TransactionVersion.TV_0) } abstract class TxnOperation[R] extends Operation { @@ -548,7 +549,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren } } - class AddPartitionsToTxnOperation(partitions: Set[TopicPartition]) extends TxnOperation[Errors] { + class AddPartitionsToTxnOperation(partitions: util.Set[TopicPartition]) extends TxnOperation[Errors] { override def run(txn: Transaction): Unit = { transactionMetadata(txn).foreach { txnMetadata => transactionCoordinator.handleAddPartitionsToTransaction(txn.transactionalId, @@ -629,7 +630,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren override def run(): Unit = { transactions.foreach { txn => transactionMetadata(txn).foreach { txnMetadata => - txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs + txnMetadata.txnLastUpdateTimestamp(time.milliseconds() - txnConfig.transactionalIdExpirationMs) } } txnStateManager.enableTransactionalIdExpiration() diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala index 26675fca747..d9f6e115fbb 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala @@ -22,7 +22,7 @@ import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, TransactionResult} import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch} -import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata} +import org.apache.kafka.coordinator.transaction.{ProducerIdManager, TransactionMetadata, TransactionState, TransactionStateManagerConfig, TxnTransitMetadata} import org.apache.kafka.server.common.TransactionVersion import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.apache.kafka.server.util.MockScheduler @@ -33,9 +33,9 @@ import org.junit.jupiter.params.provider.{CsvSource, ValueSource} import org.mockito.ArgumentMatchers.{any, anyBoolean, anyInt} import org.mockito.Mockito._ import org.mockito.{ArgumentCaptor, ArgumentMatchers} -import org.mockito.Mockito.doAnswer -import scala.collection.mutable +import java.util + import scala.jdk.CollectionConverters._ class TransactionCoordinatorTest { @@ -57,7 +57,8 @@ class TransactionCoordinatorTest { private val txnTimeoutMs = 1 private val producerId2 = 11L - private val partitions = mutable.Set[TopicPartition](new TopicPartition("topic1", 0)) + private val partitions = new util.HashSet[TopicPartition]() + partitions.add(new TopicPartition("topic1", 0)) private val scheduler = new MockScheduler(time) val coordinator = new TransactionCoordinator( @@ -198,7 +199,7 @@ class TransactionCoordinatorTest { initPidGenericMocks(transactionalId) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.EMPTY, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_0) + (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.EMPTY, util.Set.of, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -230,10 +231,10 @@ class TransactionCoordinatorTest { initPidGenericMocks(transactionalId) val txnMetadata1 = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2) + (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, util.Set.of, time.milliseconds(), time.milliseconds(), TV_2) // We start with txnMetadata1 so we can transform the metadata to TransactionState.PREPARE_COMMIT. val txnMetadata2 = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, (Short.MaxValue - 1).toShort, - (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, mutable.Set.empty, time.milliseconds(), time.milliseconds(), TV_2) + (Short.MaxValue - 2).toShort, txnTimeoutMs, TransactionState.ONGOING, util.Set.of, time.milliseconds(), time.milliseconds(), TV_2) val transitMetadata = txnMetadata2.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, producerId2, time.milliseconds(), false) txnMetadata2.completeTransitionTo(transitMetadata) @@ -376,8 +377,8 @@ class TransactionCoordinatorTest { // Pending state does not matter, we will just check if the partitions are in the txnMetadata. val ongoingTxnMetadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, mutable.Set.empty, 0, 0, TV_0) - ongoingTxnMetadata.pendingState = Some(TransactionState.COMPLETE_COMMIT) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, util.Set.of, 0, 0, TV_0) + ongoingTxnMetadata.pendingState(util.Optional.of(TransactionState.COMPLETE_COMMIT)) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(new CoordinatorEpochAndTxnMetadata(coordinatorEpoch, ongoingTxnMetadata)))) @@ -402,7 +403,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0, TV_2))))) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, util.Set.of, 0, 0, TV_2))))) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_2) assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) @@ -414,7 +415,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, - 10, 9, 0, TransactionState.PREPARE_COMMIT, mutable.Set.empty, 0, 0, TV_2))))) + 10, 9, 0, TransactionState.PREPARE_COMMIT, util.Set.of, 0, 0, TV_2))))) coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback, TV_2) assertEquals(Errors.PRODUCER_FENCED, error) @@ -445,7 +446,7 @@ class TransactionCoordinatorTest { def validateSuccessfulAddPartitions(previousState: TransactionState, transactionVersion: Short): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, txnTimeoutMs, previousState, util.Set.of, time.milliseconds(), time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -505,8 +506,9 @@ class TransactionCoordinatorTest { .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.EMPTY, partitions, 0, 0, TV_0))))) - - val extraPartitions = partitions ++ Set(new TopicPartition("topic2", 0)) + + val extraPartitions = new util.HashSet[TopicPartition](partitions) + extraPartitions.add(new TopicPartition("topic2", 0)) coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, extraPartitions, verifyPartitionsInTxnCallback) assertEquals(Errors.TRANSACTION_ABORTABLE, errors(new TopicPartition("topic2", 0))) @@ -533,7 +535,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, 10, 10, RecordBatch.NO_PRODUCER_ID, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, TransactionState.ONGOING, util.Set.of, 0, time.milliseconds(), TV_0))))) coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) @@ -547,7 +549,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, - (producerEpoch - 1).toShort, 1, TransactionState.ONGOING, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_0))))) + (producerEpoch - 1).toShort, 1, TransactionState.ONGOING, util.Set.of, 0, time.milliseconds(), TV_0))))) coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.PRODUCER_FENCED, error) @@ -561,7 +563,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, - (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion))))) val epoch = if (isRetry) producerEpoch - 1 else producerEpoch coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) @@ -588,7 +590,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, - (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion))))) val epoch = if (isRetry) producerEpoch - 1 else producerEpoch coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) @@ -605,7 +607,7 @@ class TransactionCoordinatorTest { def testEndTxnWhenStatusIsCompleteAbortAndResultIsAbortInV1(isRetry: Boolean): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -624,7 +626,7 @@ class TransactionCoordinatorTest { def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbortInV2(isRetry: Boolean): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -661,7 +663,7 @@ class TransactionCoordinatorTest { def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(transactionVersion: Short): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_ABORT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -674,7 +676,7 @@ class TransactionCoordinatorTest { def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort,1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort,1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -688,7 +690,7 @@ class TransactionCoordinatorTest { def testEndTxnRequestWhenStatusIsCompleteCommitAndResultIsAbortInV1(isRetry: Boolean): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -707,7 +709,7 @@ class TransactionCoordinatorTest { def testEndTxnRequestWhenStatusIsCompleteCommitAndResultIsAbortInV2(isRetry: Boolean): Unit = { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion) + producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -738,7 +740,7 @@ class TransactionCoordinatorTest { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, (producerEpoch - 1).toShort, 1, TransactionState.PREPARE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion))))) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) @@ -751,7 +753,7 @@ class TransactionCoordinatorTest { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(transactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_ABORT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_ABORT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion))))) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch(clientTransactionVersion), TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_TXN_STATE, error) @@ -763,7 +765,7 @@ class TransactionCoordinatorTest { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, util.Set.of, 0, time.milliseconds(), clientTransactionVersion))))) coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.INVALID_TXN_STATE, error) @@ -776,7 +778,7 @@ class TransactionCoordinatorTest { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, util.Set.of, 0, time.milliseconds(), clientTransactionVersion))))) val epoch = if (isRetry) producerEpoch - 1 else producerEpoch coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.ABORT, clientTransactionVersion, endTxnCallback) @@ -805,7 +807,7 @@ class TransactionCoordinatorTest { val clientTransactionVersion = TransactionVersion.fromFeatureLevel(2) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.EMPTY, util.Set.of, 0, time.milliseconds(), clientTransactionVersion))))) val epoch = if (isRetry) producerEpoch - 1 else producerEpoch coordinator.handleEndTransaction(transactionalId, producerId, epoch.toShort, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) @@ -821,7 +823,7 @@ class TransactionCoordinatorTest { def shouldReturnInvalidTxnRequestOnEndTxnV2IfNotEndTxnV2Retry(): Unit = { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.PREPARE_COMMIT, util.Set.of, 0, time.milliseconds(), TV_2))))) // If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return PRODUCER_FENCED. coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) @@ -830,7 +832,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + RecordBatch.NO_PRODUCER_ID, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), TV_2))))) // If producerEpoch is the same, this is not a retry of the EndTxnRequest, but the next EndTxnRequest. Return INVALID_TXN_STATE. coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, TV_2, endTxnCallback) @@ -842,7 +844,7 @@ class TransactionCoordinatorTest { def shouldReturnOkOnEndTxnV2IfEndTxnV2RetryEpochOverflow(): Unit = { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, - producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, util.Set.of, 0, time.milliseconds(), TV_2))))) // Return CONCURRENT_TRANSACTIONS while transaction is still completing coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback) @@ -851,7 +853,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId2, producerId, - RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2))))) + RecordBatch.NO_PRODUCER_ID, 0, RecordBatch.NO_PRODUCER_EPOCH, 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), TV_2))))) coordinator.handleEndTransaction(transactionalId, producerId, (Short.MaxValue - 1).toShort, TransactionResult.COMMIT, TV_2, endTxnCallback) assertEquals(Errors.NONE, error) @@ -864,7 +866,7 @@ class TransactionCoordinatorTest { @Test def shouldReturnConcurrentTxnOnAddPartitionsIfEndTxnV2EpochOverflowAndNotComplete(): Unit = { val prepareWithPending = new TransactionMetadata(transactionalId, producerId, producerId, - producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), TV_2) + producerId2, Short.MaxValue, (Short.MaxValue - 1).toShort, 1, TransactionState.PREPARE_COMMIT, util.Set.of, 0, time.milliseconds(), TV_2) val txnTransitMetadata = prepareWithPending.prepareComplete(time.milliseconds()) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) @@ -990,7 +992,7 @@ class TransactionCoordinatorTest { when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, metadataEpoch, 1, - 1, TransactionState.COMPLETE_COMMIT, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds(), clientTransactionVersion))))) + 1, TransactionState.COMPLETE_COMMIT, util.Set.of, 0, time.milliseconds(), clientTransactionVersion))))) coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, clientTransactionVersion, endTxnCallback) assertEquals(Errors.PRODUCER_FENCED, error) @@ -1132,10 +1134,10 @@ class TransactionCoordinatorTest { any()) ).thenAnswer(_ => { capturedErrorsCallback.getValue.apply(Errors.NOT_ENOUGH_REPLICAS) - txnMetadata.pendingState = None + txnMetadata.pendingState(util.Optional.empty()) }).thenAnswer(_ => { capturedErrorsCallback.getValue.apply(Errors.NOT_ENOUGH_REPLICAS) - txnMetadata.pendingState = None + txnMetadata.pendingState(util.Optional.empty()) }).thenAnswer(_ => { capturedErrorsCallback.getValue.apply(Errors.NONE) @@ -1226,7 +1228,7 @@ class TransactionCoordinatorTest { RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, - partitions.clone.asJava, + partitions, time.milliseconds(), time.milliseconds(), TV_0)), @@ -1259,7 +1261,7 @@ class TransactionCoordinatorTest { RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, - partitions.clone.asJava, + partitions, time.milliseconds(), time.milliseconds(), TV_0)), @@ -1334,18 +1336,18 @@ class TransactionCoordinatorTest { // Create transaction metadata at the epoch boundary that would cause overflow IFF double-incremented val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = epochAtMaxBoundary, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = txnTimeoutMs, - state = TransactionState.ONGOING, - topicPartitions = partitions, - txnStartTimestamp = now, - txnLastUpdateTimestamp = now, - clientTransactionVersion = TV_2 + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + epochAtMaxBoundary, + RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs, + TransactionState.ONGOING, + partitions, + now, + now, + TV_2 ) assertTrue(txnMetadata.isProducerEpochExhausted) @@ -1472,7 +1474,7 @@ class TransactionCoordinatorTest { any()) ).thenAnswer(_ => { capturedErrorsCallback.getValue.apply(Errors.NONE) - txnMetadata.pendingState = None + txnMetadata.pendingState(util.Optional.empty()) }) // Re-initialization should succeed and bump the producer epoch @@ -1520,9 +1522,9 @@ class TransactionCoordinatorTest { any()) ).thenAnswer(_ => { capturedErrorsCallback.getValue.apply(Errors.NONE) - txnMetadata.pendingState = None - txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch - txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch + txnMetadata.pendingState(util.Optional.empty()) + txnMetadata.setProducerEpoch(capturedTxnTransitMetadata.getValue.producerEpoch) + txnMetadata.setLastProducerEpoch(capturedTxnTransitMetadata.getValue.lastProducerEpoch) }) // With producer epoch at 10, new producer calls InitProducerId and should get epoch 11 @@ -1571,11 +1573,11 @@ class TransactionCoordinatorTest { any()) ).thenAnswer(_ => { capturedErrorsCallback.getValue.apply(Errors.NONE) - txnMetadata.pendingState = None - txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId - txnMetadata.prevProducerId = capturedTxnTransitMetadata.getValue.prevProducerId - txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch - txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch + txnMetadata.pendingState(util.Optional.empty()) + txnMetadata.setProducerId(capturedTxnTransitMetadata.getValue.producerId) + txnMetadata.setPrevProducerId(capturedTxnTransitMetadata.getValue.prevProducerId) + txnMetadata.setProducerEpoch(capturedTxnTransitMetadata.getValue.producerEpoch) + txnMetadata.setLastProducerEpoch(capturedTxnTransitMetadata.getValue.lastProducerEpoch) }) // Bump epoch and cause producer ID to be rotated @@ -1624,11 +1626,11 @@ class TransactionCoordinatorTest { any()) ).thenAnswer(_ => { capturedErrorsCallback.getValue.apply(Errors.NONE) - txnMetadata.pendingState = None - txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId - txnMetadata.prevProducerId = capturedTxnTransitMetadata.getValue.prevProducerId - txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch - txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch + txnMetadata.pendingState(util.Optional.empty()) + txnMetadata.setProducerId(capturedTxnTransitMetadata.getValue.producerId) + txnMetadata.setPrevProducerId(capturedTxnTransitMetadata.getValue.prevProducerId) + txnMetadata.setProducerEpoch(capturedTxnTransitMetadata.getValue.producerEpoch) + txnMetadata.setLastProducerEpoch(capturedTxnTransitMetadata.getValue.lastProducerEpoch) }) // Bump epoch and cause producer ID to be rotated @@ -1674,7 +1676,7 @@ class TransactionCoordinatorTest { // Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0. val expectedTransition = new TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, (producerEpoch + 1).toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions.clone.asJava, now, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0) when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) @@ -1764,7 +1766,7 @@ class TransactionCoordinatorTest { // Transaction timeouts use FenceProducerEpoch so clientTransactionVersion is 0. val bumpedEpoch = (producerEpoch + 1).toShort val expectedTransition = new TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, bumpedEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions.clone.asJava, now, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.PREPARE_ABORT, partitions, now, now + TransactionStateManagerConfig.TRANSACTIONS_ABORT_TIMED_OUT_TRANSACTION_CLEANUP_INTERVAL_MS_DEFAULT, TV_0) when(transactionManager.transactionVersionLevel()).thenReturn(TV_0) @@ -1832,7 +1834,7 @@ class TransactionCoordinatorTest { coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false) val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.DEAD, mutable.Set.empty, time.milliseconds(), + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.DEAD, util.Set.of, time.milliseconds(), time.milliseconds(), TV_0) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) @@ -1872,9 +1874,11 @@ class TransactionCoordinatorTest { assertEquals(txnTimeoutMs, result.transactionTimeoutMs) assertEquals(time.milliseconds(), result.transactionStartTimeMs) - val addedPartitions = result.topics.asScala.flatMap { topicData => - topicData.partitions.asScala.map(partition => new TopicPartition(topicData.topic, partition)) - }.toSet + val addedPartitions = result.topics.stream.flatMap(topicData => + topicData.partitions.stream + .map(partition => new TopicPartition(topicData.topic, partition)) + ) + .collect(util.stream.Collectors.toSet()); assertEquals(partitions, addedPartitions) verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) @@ -1886,7 +1890,7 @@ class TransactionCoordinatorTest { // Since the clientTransactionVersion doesn't matter, use 2 since the states are TransactionState.PREPARE_COMMIT and TransactionState.PREPARE_ABORT. val metadata = new TransactionMetadata(transactionalId, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, - 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0, TV_2) + 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, util.Set.of[TopicPartition](new TopicPartition("topic", 1)), 0, 0, TV_2) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))) @@ -1905,7 +1909,7 @@ class TransactionCoordinatorTest { .thenReturn(true) val metadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, - producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds(), clientTransactionVersion) + producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, util.Set.of, time.milliseconds(), time.milliseconds(), clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))) @@ -1939,7 +1943,7 @@ class TransactionCoordinatorTest { producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, TransactionState.ONGOING, partitions, now, now, TV_0) val transition = new TxnTransitMetadata(producerId, producerId, RecordBatch.NO_PRODUCER_EPOCH, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions.clone.asJava, now, now, clientTransactionVersion) + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, transactionState, partitions, now, now, clientTransactionVersion) when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata)))) @@ -2007,7 +2011,7 @@ class TransactionCoordinatorTest { // Simulate the real TransactionStateManager behavior: reset pendingState on failure // since handleInitProducerId doesn't provide a custom retryOnError function - txnMetadata.pendingState = None + txnMetadata.pendingState(util.Optional.empty()) // For TV2, hasFailedEpochFence is NOT set to true, allowing epoch bumps on retry // The epoch remains at its original value (1) since completeTransitionTo was never called @@ -2062,7 +2066,7 @@ class TransactionCoordinatorTest { // Simulate the completion of transaction markers and the second write // This would normally happen asynchronously after markers are sent txnMetadata.completeTransitionTo(newMetadata) // This transitions to COMPLETE_ABORT - txnMetadata.pendingState = None + txnMetadata.pendingState(util.Optional.empty()) null }).when(transactionMarkerChannelManager).addTxnMarkersToSend( diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala index dd378757130..46116be6d7c 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala @@ -22,7 +22,7 @@ import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type} import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, SimpleRecord} -import org.apache.kafka.coordinator.transaction.{TransactionState, TxnTransitMetadata} +import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState, TxnTransitMetadata} import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue} import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue, fail} @@ -38,7 +38,7 @@ class TransactionLogTest { val producerEpoch: Short = 0 val transactionTimeoutMs: Int = 1000 - val topicPartitions: Set[TopicPartition] = Set[TopicPartition](new TopicPartition("topic1", 0), + val topicPartitions = util.Set.of(new TopicPartition("topic1", 0), new TopicPartition("topic1", 1), new TopicPartition("topic2", 0), new TopicPartition("topic2", 1), @@ -50,7 +50,7 @@ class TransactionLogTest { val producerId = 23423L val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, util.Set.of, 0, 0, TV_0) txnMetadata.addPartitions(topicPartitions) assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)) @@ -75,7 +75,7 @@ class TransactionLogTest { // generate transaction log messages val txnRecords = pidMappings.map { case (transactionalId, producerId) => val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), collection.mutable.Set.empty[TopicPartition], 0, 0, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), util.Set.of, 0, 0, TV_0) if (!txnMetadata.state.equals(TransactionState.EMPTY)) txnMetadata.addPartitions(topicPartitions) @@ -101,7 +101,7 @@ class TransactionLogTest { assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state) if (txnMetadata.state.equals(TransactionState.EMPTY)) - assertEquals(Set.empty[TopicPartition], txnMetadata.topicPartitions) + assertEquals(util.Set.of, txnMetadata.topicPartitions) else assertEquals(topicPartitions, txnMetadata.topicPartitions) @@ -114,14 +114,14 @@ class TransactionLogTest { @Test def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = { - val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, util.Set.of, 500, 500, TV_0) + val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, new util.HashSet(), 500, 500, TV_0) val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0)) assertEquals(0, txnLogValueBuffer.getShort) } @Test def testSerializeTransactionLogValueToFlexibleVersion(): Unit = { - val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, util.Set.of, 500, 500, TV_2) + val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, new util.HashSet(), 500, 500, TV_2) val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2)) assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort) } @@ -229,7 +229,7 @@ class TransactionLogTest { assertEquals(100, txnMetadata.producerEpoch) assertEquals(1000L, txnMetadata.txnTimeoutMs) assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata.state) - assertEquals(Set(new TopicPartition("topic", 1)), txnMetadata.topicPartitions) + assertEquals(util.Set.of(new TopicPartition("topic", 1)), txnMetadata.topicPartitions) assertEquals(2000L, txnMetadata.txnLastUpdateTimestamp) assertEquals(3000L, txnMetadata.txnStartTimestamp) } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala index ec12afe6489..7699d643a3e 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala @@ -27,7 +27,7 @@ import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} import org.apache.kafka.common.utils.MockTime import org.apache.kafka.common.{Node, TopicPartition} -import org.apache.kafka.coordinator.transaction.TransactionState +import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState} import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.{MetadataVersion, TransactionVersion} import org.apache.kafka.server.metrics.{KafkaMetricsGroup, KafkaYammerMetrics} @@ -41,7 +41,6 @@ import org.mockito.{ArgumentCaptor, ArgumentMatchers} import org.mockito.Mockito.{clearInvocations, mock, mockConstruction, times, verify, verifyNoMoreInteractions, when} import scala.jdk.CollectionConverters._ -import scala.collection.mutable import scala.util.Try class TransactionMarkerChannelManagerTest { @@ -67,9 +66,9 @@ class TransactionMarkerChannelManagerTest { private val txnTimeoutMs = 0 private val txnResult = TransactionResult.COMMIT private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, RecordBatch.NO_PRODUCER_ID, - producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L, TransactionVersion.TV_2) + producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, util.Set.of(partition1, partition2), 0L, 0L, TransactionVersion.TV_2) private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, RecordBatch.NO_PRODUCER_ID, - producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](partition1), 0L, 0L, TransactionVersion.TV_2) + producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, util.Set.of(partition1), 0L, 0L, TransactionVersion.TV_2) private val capturedErrorsCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit]) private val time = new MockTime @@ -145,33 +144,33 @@ class TransactionMarkerChannelManagerTest { var addMarkerFuture: Future[Try[Unit]] = null val executor = Executors.newFixedThreadPool(1) - txnMetadata2.lock.lock() try { - addMarkerFuture = executor.submit((() => { - Try(channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, + txnMetadata2.inLock(() => { + addMarkerFuture = executor.submit((() => { + Try(channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, expectedTransition)) - }): Callable[Try[Unit]]) + }): Callable[Try[Unit]]) - val header = new RequestHeader(ApiKeys.WRITE_TXN_MARKERS, 0, "client", 1) - val response = new WriteTxnMarkersResponse( - util.Map.of(producerId2: java.lang.Long, util.Map.of(partition1, Errors.NONE))) - val clientResponse = new ClientResponse(header, null, null, - time.milliseconds(), time.milliseconds(), false, null, null, - response) + val header = new RequestHeader(ApiKeys.WRITE_TXN_MARKERS, 0, "client", 1) + val response = new WriteTxnMarkersResponse( + util.Map.of(producerId2: java.lang.Long, util.Map.of(partition1, Errors.NONE))) + val clientResponse = new ClientResponse(header, null, null, + time.milliseconds(), time.milliseconds(), false, null, null, + response) - TestUtils.waitUntilTrue(() => { - val requests = channelManager.generateRequests().asScala - if (requests.nonEmpty) { - assertEquals(1, requests.size) - val request = requests.head - request.handler.onComplete(clientResponse) - true - } else { - false - } - }, "Timed out waiting for expected WriteTxnMarkers request") + TestUtils.waitUntilTrue(() => { + val requests = channelManager.generateRequests().asScala + if (requests.nonEmpty) { + assertEquals(1, requests.size) + val request = requests.head + request.handler.onComplete(clientResponse) + true + } else { + false + } + }, "Timed out waiting for expected WriteTxnMarkers request") + }) } finally { - txnMetadata2.lock.unlock() executor.shutdown() } @@ -478,7 +477,7 @@ class TransactionMarkerChannelManagerTest { assertEquals(0, channelManager.numTxnsWithPendingMarkers) assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) - assertEquals(None, txnMetadata2.pendingState) + assertEquals(Optional.empty(), txnMetadata2.pendingState) assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata2.state) } @@ -507,7 +506,7 @@ class TransactionMarkerChannelManagerTest { any(), any())) .thenAnswer(_ => { - txnMetadata2.pendingState = None + txnMetadata2.pendingState(util.Optional.empty()) capturedErrorsCallback.getValue.apply(Errors.NOT_COORDINATOR) }) @@ -531,7 +530,7 @@ class TransactionMarkerChannelManagerTest { assertEquals(0, channelManager.numTxnsWithPendingMarkers) assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) - assertEquals(None, txnMetadata2.pendingState) + assertEquals(Optional.empty(), txnMetadata2.pendingState) assertEquals(TransactionState.PREPARE_COMMIT, txnMetadata2.state) } @@ -592,7 +591,7 @@ class TransactionMarkerChannelManagerTest { assertEquals(0, channelManager.numTxnsWithPendingMarkers) assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) - assertEquals(None, txnMetadata2.pendingState) + assertEquals(Optional.empty(), txnMetadata2.pendingState) assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata2.state) } @@ -632,11 +631,11 @@ class TransactionMarkerChannelManagerTest { txnMetadata: TransactionMetadata ): Unit = { if (isTransactionV2Enabled) { - txnMetadata.clientTransactionVersion = TransactionVersion.TV_2 - txnMetadata.producerEpoch = (producerEpoch + 1).toShort - txnMetadata.lastProducerEpoch = producerEpoch + txnMetadata.clientTransactionVersion(TransactionVersion.TV_2) + txnMetadata.setProducerEpoch((producerEpoch + 1).toShort) + txnMetadata.setLastProducerEpoch(producerEpoch) } else { - txnMetadata.clientTransactionVersion = TransactionVersion.TV_1 + txnMetadata.clientTransactionVersion(TransactionVersion.TV_1) } } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala index 56dc0ec266c..e955a9009ce 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala @@ -22,15 +22,13 @@ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.{ApiKeys, Errors} import org.apache.kafka.common.record.RecordBatch import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} -import org.apache.kafka.coordinator.transaction.TransactionState +import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState} import org.apache.kafka.server.common.TransactionVersion import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test import org.mockito.ArgumentMatchers import org.mockito.Mockito.{mock, verify, when} -import scala.collection.mutable - class TransactionMarkerRequestCompletionHandlerTest { private val brokerId = 0 @@ -44,7 +42,7 @@ class TransactionMarkerRequestCompletionHandlerTest { private val txnResult = TransactionResult.COMMIT private val topicPartition = new TopicPartition("topic1", 0) private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, RecordBatch.NO_PRODUCER_ID, - producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, mutable.Set[TopicPartition](topicPartition), 0L, 0L, TransactionVersion.TV_2) + producerEpoch, lastProducerEpoch, txnTimeoutMs, TransactionState.PREPARE_COMMIT, util.Set.of(topicPartition), 0L, 0L, TransactionVersion.TV_2) private val pendingCompleteTxnAndMarkers = util.List.of( PendingCompleteTxnAndMarkerEntry( PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, txnMetadata.prepareComplete(42)), @@ -194,7 +192,7 @@ class TransactionMarkerRequestCompletionHandlerTest { handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1), null, null, 0, 0, false, null, null, response)) - assertEquals(txnMetadata.topicPartitions, mutable.Set[TopicPartition](topicPartition)) + assertEquals(txnMetadata.topicPartitions, util.Set.of(topicPartition)) verify(markerChannelManager).addTxnMarkersToBrokerQueue(producerId, producerEpoch, txnResult, pendingCompleteTxnAndMarkers.get(0).pendingCompleteTxn, Set[TopicPartition](topicPartition)) diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala index 453bc13f0bd..87a18b18dc0 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala @@ -19,7 +19,7 @@ package kafka.coordinator.transaction import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.RecordBatch -import org.apache.kafka.coordinator.transaction.{TransactionState, TxnTransitMetadata} +import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState, TxnTransitMetadata} import org.apache.kafka.server.common.TransactionVersion import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} import org.apache.kafka.server.util.MockTime @@ -28,6 +28,7 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ValueSource +import java.util import java.util.Optional import scala.collection.mutable @@ -44,19 +45,20 @@ class TransactionMetadataTest { val producerEpoch = RecordBatch.NO_PRODUCER_EPOCH val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.EMPTY, + util.Set.of, + -1, + time.milliseconds(), + TV_0) - val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None) + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.empty()) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(0, txnMetadata.producerEpoch) @@ -68,19 +70,20 @@ class TransactionMetadataTest { val producerEpoch = 735.toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.EMPTY, + util.Set.of, + -1, + time.milliseconds(), + TV_0) - val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None) + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.empty()) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) @@ -92,21 +95,22 @@ class TransactionMetadataTest { val producerEpoch = (Short.MaxValue - 1).toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.EMPTY, + util.Set.of, + -1, + time.milliseconds(), + TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareIncrementProducerEpoch(30000, - None, time.milliseconds())) + Optional.empty, time.milliseconds())) } @Test @@ -114,20 +118,20 @@ class TransactionMetadataTest { val producerEpoch = 735.toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = -1, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_2) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.EMPTY, + util.Set.of, + -1, + time.milliseconds(), + TV_2) - val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, noPartitionAdded = true) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) @@ -139,20 +143,20 @@ class TransactionMetadataTest { val producerEpoch = 735.toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.COMPLETE_ABORT, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = time.milliseconds() - 1, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_2) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.COMPLETE_ABORT, + util.Set.of, + time.milliseconds() - 1, + time.milliseconds(), + TV_2) - val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, noPartitionAdded = true) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) @@ -164,20 +168,20 @@ class TransactionMetadataTest { val producerEpoch = 735.toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.COMPLETE_COMMIT, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = time.milliseconds() - 1, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_2) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.COMPLETE_COMMIT, + util.Set.of, + time.milliseconds() - 1, + time.milliseconds(), + TV_2) - val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, noPartitionAdded = true) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() + 1, true) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) @@ -188,21 +192,21 @@ class TransactionMetadataTest { def testTolerateUpdateTimeShiftDuringEpochBump(): Unit = { val producerEpoch: Short = 1 val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.EMPTY, + util.Set.of, + 1L, + time.milliseconds(), + TV_0) // let new time be smaller - val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Option(producerEpoch), + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.of(producerEpoch), Some(time.milliseconds() - 1)) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) @@ -216,21 +220,21 @@ class TransactionMetadataTest { def testTolerateUpdateTimeResetDuringProducerIdRotation(): Unit = { val producerEpoch: Short = 1 val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.EMPTY, + util.Set.of, + 1L, + time.milliseconds(), + TV_0) // let new time be smaller - val transitMetadata = txnMetadata.prepareProducerIdRotation(producerId + 1, 30000, time.milliseconds() - 1, recordLastEpoch = true) + val transitMetadata = txnMetadata.prepareProducerIdRotation(producerId + 1, 30000, time.milliseconds() - 1, true) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId + 1, txnMetadata.producerId) assertEquals(producerEpoch, txnMetadata.lastProducerEpoch) @@ -243,23 +247,23 @@ class TransactionMetadataTest { def testTolerateTimeShiftDuringAddPartitions(): Unit = { val producerEpoch: Short = 1 val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = time.milliseconds(), - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.EMPTY, + util.Set.of, + time.milliseconds(), + time.milliseconds(), + TV_0) // let new time be smaller; when transiting from TransactionState.EMPTY the start time would be updated to the update-time - var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1, TV_0) + var transitMetadata = txnMetadata.prepareAddPartitions(util.Set.of(new TopicPartition("topic1", 0)), time.milliseconds() - 1, TV_0) txnMetadata.completeTransitionTo(transitMetadata) - assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0)), txnMetadata.topicPartitions) + assertEquals(util.Set.of(new TopicPartition("topic1", 0)), txnMetadata.topicPartitions) assertEquals(producerId, txnMetadata.producerId) assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch) @@ -267,9 +271,9 @@ class TransactionMetadataTest { assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) // add another partition, check that in TransactionState.ONGOING state the start timestamp would not change to update time - transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds() - 2, TV_0) + transitMetadata = txnMetadata.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds() - 2, TV_0) txnMetadata.completeTransitionTo(transitMetadata) - assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)), txnMetadata.topicPartitions) + assertEquals(util.Set.of(new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)), txnMetadata.topicPartitions) assertEquals(producerId, txnMetadata.producerId) assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) assertEquals(producerEpoch, txnMetadata.producerEpoch) @@ -281,21 +285,21 @@ class TransactionMetadataTest { def testTolerateTimeShiftDuringPrepareCommit(): Unit = { val producerEpoch: Short = 1 val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.ONGOING, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.ONGOING, + util.Set.of, + 1L, + time.milliseconds(), + TV_0) // let new time be smaller - val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, noPartitionAdded = false) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(TransactionState.PREPARE_COMMIT, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) @@ -309,21 +313,21 @@ class TransactionMetadataTest { def testTolerateTimeShiftDuringPrepareAbort(): Unit = { val producerEpoch: Short = 1 val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.ONGOING, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.ONGOING, + util.Set.of, + 1L, + time.milliseconds(), + TV_0) // let new time be smaller - val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, noPartitionAdded = false) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(TransactionState.PREPARE_ABORT, txnMetadata.state) assertEquals(producerId, txnMetadata.producerId) @@ -340,18 +344,18 @@ class TransactionMetadataTest { val producerEpoch: Short = 1 val lastProducerEpoch: Short = 0 val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = lastProducerEpoch, - txnTimeoutMs = 30000, - state = TransactionState.PREPARE_COMMIT, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = clientTransactionVersion + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + lastProducerEpoch, + 30000, + TransactionState.PREPARE_COMMIT, + util.Set.of(), + 1L, + time.milliseconds(), + clientTransactionVersion ) // let new time be smaller @@ -373,18 +377,18 @@ class TransactionMetadataTest { val producerEpoch: Short = 1 val lastProducerEpoch: Short = 0 val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = lastProducerEpoch, - txnTimeoutMs = 30000, - state = TransactionState.PREPARE_ABORT, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = 1L, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = clientTransactionVersion + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + lastProducerEpoch, + 30000, + TransactionState.PREPARE_ABORT, + util.Set.of, + 1L, + time.milliseconds(), + clientTransactionVersion ) // let new time be smaller @@ -404,28 +408,29 @@ class TransactionMetadataTest { val producerEpoch = (Short.MaxValue - 1).toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.ONGOING, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.ONGOING, + util.Set.of, + -1, + time.milliseconds(), + TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch() assertEquals(Short.MaxValue, fencingTransitMetadata.producerEpoch) assertEquals(RecordBatch.NO_PRODUCER_EPOCH, fencingTransitMetadata.lastProducerEpoch) - assertEquals(Some(TransactionState.PREPARE_EPOCH_FENCE), txnMetadata.pendingState) + assertEquals(Optional.of(TransactionState.PREPARE_EPOCH_FENCE), txnMetadata.pendingState) // We should reset the pending state to make way for the abort transition. - txnMetadata.pendingState = None + txnMetadata.pendingState(Optional.empty()) - val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), noPartitionAdded = false) + val transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_ABORT, TV_0, RecordBatch.NO_PRODUCER_ID, time.milliseconds(), false) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, transitMetadata.producerId) } @@ -435,17 +440,18 @@ class TransactionMetadataTest { val producerEpoch = (Short.MaxValue - 1).toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.COMPLETE_COMMIT, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.COMPLETE_COMMIT, + util.Set.of, + -1, + time.milliseconds(), + TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch()) @@ -456,17 +462,18 @@ class TransactionMetadataTest { val producerEpoch = (Short.MaxValue - 1).toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.COMPLETE_ABORT, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.COMPLETE_ABORT, + util.Set.of, + -1, + time.milliseconds(), + TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch()) @@ -477,17 +484,18 @@ class TransactionMetadataTest { val producerEpoch = Short.MaxValue val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.ONGOING, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.ONGOING, + util.Set.of, + -1, + time.milliseconds(), + TV_0) assertTrue(txnMetadata.isProducerEpochExhausted) assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch()) } @@ -497,20 +505,21 @@ class TransactionMetadataTest { val producerEpoch = (Short.MaxValue - 1).toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.EMPTY, + util.Set.of, + -1, + time.milliseconds(), + TV_0) val newProducerId = 9893L - val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = true) + val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), true) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(newProducerId, txnMetadata.producerId) assertEquals(producerId, txnMetadata.prevProducerId) @@ -524,20 +533,20 @@ class TransactionMetadataTest { val producerEpoch = 10.toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.ONGOING, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = time.milliseconds(), - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_2) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.ONGOING, + util.Set.of, + time.milliseconds(), + time.milliseconds(), + TV_2) - var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, noPartitionAdded = false) + var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, RecordBatch.NO_PRODUCER_ID, time.milliseconds() - 1, false) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch) @@ -556,22 +565,22 @@ class TransactionMetadataTest { val producerEpoch = (Short.MaxValue - 1).toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.ONGOING, - topicPartitions = mutable.Set.empty, - txnStartTimestamp = time.milliseconds(), - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_2) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.ONGOING, + util.Set.of, + time.milliseconds(), + time.milliseconds(), + TV_2) assertTrue(txnMetadata.isProducerEpochExhausted) val newProducerId = 9893L - var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, newProducerId, time.milliseconds() - 1, noPartitionAdded = false) + var transitMetadata = txnMetadata.prepareAbortOrCommit(TransactionState.PREPARE_COMMIT, TV_2, newProducerId, time.milliseconds() - 1, false) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(Short.MaxValue, txnMetadata.producerEpoch) @@ -610,19 +619,20 @@ class TransactionMetadataTest { val producerEpoch = 735.toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.EMPTY, + util.Set.of, + -1, + time.milliseconds(), + TV_0) - val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch)) + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.of(producerEpoch)) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(0, txnMetadata.producerEpoch) @@ -634,19 +644,20 @@ class TransactionMetadataTest { val producerEpoch = 735.toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + TransactionState.EMPTY, + util.Set.of, + -1, + time.milliseconds(), + TV_0) - val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch)) + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.of(producerEpoch)) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) @@ -659,19 +670,20 @@ class TransactionMetadataTest { val lastProducerEpoch = (producerEpoch - 1).toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = RecordBatch.NO_PRODUCER_ID, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = lastProducerEpoch, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + lastProducerEpoch, + 30000, + TransactionState.EMPTY, + util.Set.of, + -1, + time.milliseconds(), + TV_0) - val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(lastProducerEpoch)) + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Optional.of(lastProducerEpoch)) txnMetadata.completeTransitionTo(transitMetadata) assertEquals(producerId, txnMetadata.producerId) assertEquals(producerEpoch, txnMetadata.producerEpoch) @@ -684,21 +696,23 @@ class TransactionMetadataTest { val lastProducerEpoch = (producerEpoch - 1).toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = producerId, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = lastProducerEpoch, - txnTimeoutMs = 30000, - state = TransactionState.EMPTY, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = TV_0) + transactionalId, + producerId, + producerId, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + lastProducerEpoch, + 30000, + TransactionState.EMPTY, + util.Set.of, + -1, + time.milliseconds(), + TV_0) - val result = txnMetadata.prepareIncrementProducerEpoch(30000, Some((lastProducerEpoch - 1).toShort), - time.milliseconds()) - assertEquals(Left(Errors.PRODUCER_FENCED), result) + assertThrows(Errors.PRODUCER_FENCED.exception().getClass, () => + txnMetadata.prepareIncrementProducerEpoch(30000, Optional.of((lastProducerEpoch - 1).toShort), + time.milliseconds()) + ) } @Test @@ -748,27 +762,26 @@ class TransactionMetadataTest { val producerEpoch = (Short.MaxValue - 1).toShort val txnMetadata = new TransactionMetadata( - transactionalId = transactionalId, - producerId = producerId, - prevProducerId = producerId, - nextProducerId = RecordBatch.NO_PRODUCER_ID, - producerEpoch = producerEpoch, - lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, - txnTimeoutMs = 30000, - state = state, - topicPartitions = mutable.Set.empty, - txnLastUpdateTimestamp = time.milliseconds(), - clientTransactionVersion = clientTransactionVersion) + transactionalId, + producerId, + producerId, + RecordBatch.NO_PRODUCER_ID, + producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, + 30000, + state, + util.Set.of, + -1, + time.milliseconds(), + clientTransactionVersion) val newProducerId = 9893L - txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = false) + txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), false) } private def prepareSuccessfulIncrementProducerEpoch(txnMetadata: TransactionMetadata, - expectedProducerEpoch: Option[Short], + expectedProducerEpoch: Optional[java.lang.Short], now: Option[Long] = None): TxnTransitMetadata = { - val result = txnMetadata.prepareIncrementProducerEpoch(30000, expectedProducerEpoch, - now.getOrElse(time.milliseconds())) - result.getOrElse(throw new AssertionError(s"prepareIncrementProducerEpoch failed with $result")) + txnMetadata.prepareIncrementProducerEpoch(30000, expectedProducerEpoch, now.getOrElse(time.milliseconds())) } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala index 16b3224495a..41ee3f7f4cc 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala @@ -18,7 +18,6 @@ package kafka.coordinator.transaction import java.lang.management.ManagementFactory import java.nio.ByteBuffer -import java.util import java.util.concurrent.{ConcurrentHashMap, CountDownLatch} import javax.management.ObjectName import kafka.server.ReplicaManager @@ -33,7 +32,7 @@ import org.apache.kafka.common.record._ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests.TransactionResult import org.apache.kafka.common.utils.MockTime -import org.apache.kafka.coordinator.transaction.{TransactionState, TxnTransitMetadata} +import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState, TxnTransitMetadata} import org.apache.kafka.metadata.MetadataCache import org.apache.kafka.server.common.{FinalizedFeatures, MetadataVersion, RequestLocal, TransactionVersion} import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} @@ -50,6 +49,7 @@ import org.mockito.{ArgumentCaptor, ArgumentMatchers} import org.mockito.ArgumentMatchers.{any, anyInt, anyLong, anyShort} import org.mockito.Mockito.{atLeastOnce, mock, reset, times, verify, when} +import java.util import scala.collection.{Map, mutable} import scala.jdk.CollectionConverters._ @@ -181,8 +181,8 @@ class TransactionStateManagerTest { ).thenReturn(new FetchDataInfo(new LogOffsetMetadata(startOffset), fileRecordsMock)) when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some(endOffset)) - txnMetadata1.state = TransactionState.PREPARE_COMMIT - txnMetadata1.addPartitions(Set[TopicPartition]( + txnMetadata1.state(TransactionState.PREPARE_COMMIT) + txnMetadata1.addPartitions(util.Set.of( new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) val records = MemoryRecords.withRecords(startOffset, Compression.NONE, @@ -240,8 +240,8 @@ class TransactionStateManagerTest { ).thenReturn(new FetchDataInfo(new LogOffsetMetadata(startOffset), fileRecordsMock)) when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some(endOffset)) - txnMetadata1.state = TransactionState.PREPARE_COMMIT - txnMetadata1.addPartitions(Set[TopicPartition]( + txnMetadata1.state(TransactionState.PREPARE_COMMIT) + txnMetadata1.addPartitions(util.Set.of( new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) val records = MemoryRecords.withRecords(startOffset, Compression.NONE, @@ -285,44 +285,44 @@ class TransactionStateManagerTest { // generate transaction log messages for two pids traces: // pid1's transaction started with two partitions - txnMetadata1.state = TransactionState.ONGOING - txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + txnMetadata1.state(TransactionState.ONGOING) + txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid1's transaction adds three more partitions - txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0), + txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic2", 0), new TopicPartition("topic2", 1), new TopicPartition("topic2", 2))) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid1's transaction is preparing to commit - txnMetadata1.state = TransactionState.PREPARE_COMMIT + txnMetadata1.state(TransactionState.PREPARE_COMMIT) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) // pid2's transaction started with three partitions - txnMetadata2.state = TransactionState.ONGOING - txnMetadata2.addPartitions(Set[TopicPartition](new TopicPartition("topic3", 0), + txnMetadata2.state(TransactionState.ONGOING) + txnMetadata2.addPartitions(util.Set.of(new TopicPartition("topic3", 0), new TopicPartition("topic3", 1), new TopicPartition("topic3", 2))) txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's transaction is preparing to abort - txnMetadata2.state = TransactionState.PREPARE_ABORT + txnMetadata2.state(TransactionState.PREPARE_ABORT) txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's transaction has aborted - txnMetadata2.state = TransactionState.COMPLETE_ABORT + txnMetadata2.state(TransactionState.COMPLETE_ABORT) txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) // pid2's epoch has advanced, with no ongoing transaction yet - txnMetadata2.state = TransactionState.EMPTY + txnMetadata2.state(TransactionState.EMPTY) txnMetadata2.topicPartitions.clear() txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit(), TV_2)) @@ -391,7 +391,7 @@ class TransactionStateManagerTest { expectedError = Errors.NONE // update the metadata to ongoing with two partitions - val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + val newMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic1", 0), new TopicPartition("topic1", 1)), time.milliseconds(), TV_0) // append the new metadata into log @@ -407,7 +407,7 @@ class TransactionStateManagerTest { transactionManager.putTransactionStateIfNotExists(txnMetadata1) expectedError = Errors.COORDINATOR_NOT_AVAILABLE - var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) + var failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION) val requestLocal = RequestLocal.withThreadConfinedCaching @@ -415,19 +415,19 @@ class TransactionStateManagerTest { assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) assertTrue(txnMetadata1.pendingState.isEmpty) - failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) + failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS) transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) assertTrue(txnMetadata1.pendingState.isEmpty) - failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) + failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND) transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) assertTrue(txnMetadata1.pendingState.isEmpty) - failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) + failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) prepareForTxnMessageAppend(Errors.REQUEST_TIMED_OUT) transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) @@ -440,7 +440,7 @@ class TransactionStateManagerTest { transactionManager.putTransactionStateIfNotExists(txnMetadata1) expectedError = Errors.NOT_COORDINATOR - var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) + var failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) prepareForTxnMessageAppend(Errors.NOT_LEADER_OR_FOLLOWER) val requestLocal = RequestLocal.withThreadConfinedCaching @@ -448,7 +448,7 @@ class TransactionStateManagerTest { assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) assertTrue(txnMetadata1.pendingState.isEmpty) - failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) + failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) prepareForTxnMessageAppend(Errors.NONE) transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch) transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) @@ -471,7 +471,7 @@ class TransactionStateManagerTest { transactionManager.putTransactionStateIfNotExists(txnMetadata1) expectedError = Errors.COORDINATOR_LOAD_IN_PROGRESS - val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) + val failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) prepareForTxnMessageAppend(Errors.NONE) transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch) @@ -485,7 +485,7 @@ class TransactionStateManagerTest { transactionManager.putTransactionStateIfNotExists(txnMetadata1) expectedError = Errors.UNKNOWN_SERVER_ERROR - var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) + var failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) prepareForTxnMessageAppend(Errors.MESSAGE_TOO_LARGE) val requestLocal = RequestLocal.withThreadConfinedCaching @@ -493,7 +493,7 @@ class TransactionStateManagerTest { assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) assertTrue(txnMetadata1.pendingState.isEmpty) - failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) + failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) prepareForTxnMessageAppend(Errors.RECORD_LIST_TOO_LARGE) transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) @@ -506,12 +506,12 @@ class TransactionStateManagerTest { transactionManager.putTransactionStateIfNotExists(txnMetadata1) expectedError = Errors.COORDINATOR_NOT_AVAILABLE - val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) + val failedMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic2", 0)), time.milliseconds(), TV_0) prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION) transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, _ => true, RequestLocal.withThreadConfinedCaching) assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) - assertEquals(Some(TransactionState.ONGOING), txnMetadata1.pendingState) + assertEquals(util.Optional.of(TransactionState.ONGOING), txnMetadata1.pendingState) } @Test @@ -524,11 +524,11 @@ class TransactionStateManagerTest { prepareForTxnMessageAppend(Errors.NONE) expectedError = Errors.NOT_COORDINATOR - val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + val newMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic1", 0), new TopicPartition("topic1", 1)), time.milliseconds(), TV_0) // modify the cache while trying to append the new metadata - txnMetadata1.producerEpoch = (txnMetadata1.producerEpoch + 1).toShort + txnMetadata1.setProducerEpoch((txnMetadata1.producerEpoch + 1).toShort) // append the new metadata into log transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, newMetadata, assertCallback, requestLocal = RequestLocal.withThreadConfinedCaching) @@ -543,11 +543,11 @@ class TransactionStateManagerTest { prepareForTxnMessageAppend(Errors.NONE) expectedError = Errors.INVALID_PRODUCER_EPOCH - val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + val newMetadata = txnMetadata1.prepareAddPartitions(util.Set.of(new TopicPartition("topic1", 0), new TopicPartition("topic1", 1)), time.milliseconds(), TV_0) // modify the cache while trying to append the new metadata - txnMetadata1.pendingState = None + txnMetadata1.pendingState(util.Optional.empty()) // append the new metadata into log assertThrows(classOf[IllegalStateException], () => transactionManager.appendTransactionToLog(transactionalId1, @@ -876,7 +876,7 @@ class TransactionStateManagerTest { // will be expired and it should succeed. val timestamp = time.milliseconds() val txnMetadata = new TransactionMetadata(transactionalId, 1, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, - RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, util.Set.of, timestamp, timestamp, TV_0) transactionManager.putTransactionStateIfNotExists(txnMetadata) time.sleep(txnConfig.transactionalIdExpirationMs + 1) @@ -934,7 +934,7 @@ class TransactionStateManagerTest { val txnlId = s"id_$i" val producerId = i val txnMetadata = transactionMetadata(txnlId, producerId) - txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs + txnMetadata.txnLastUpdateTimestamp(time.milliseconds() - txnConfig.transactionalIdExpirationMs) transactionManager.putTransactionStateIfNotExists(txnMetadata) allTransactionalIds += txnlId } @@ -962,8 +962,8 @@ class TransactionStateManagerTest { @Test def testSuccessfulReimmigration(): Unit = { - txnMetadata1.state = TransactionState.PREPARE_COMMIT - txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + txnMetadata1.state(TransactionState.PREPARE_COMMIT) + txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) @@ -1029,10 +1029,10 @@ class TransactionStateManagerTest { @Test def testLoadTransactionMetadataContainingSegmentEndingWithEmptyBatch(): Unit = { // Simulate a case where a log contains two segments and the first segment ending with an empty batch. - txnMetadata1.state = TransactionState.PREPARE_COMMIT - txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0))) - txnMetadata2.state = TransactionState.ONGOING - txnMetadata2.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0))) + txnMetadata1.state(TransactionState.PREPARE_COMMIT) + txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic1", 0))) + txnMetadata2.state(TransactionState.ONGOING) + txnMetadata2.addPartitions(util.Set.of(new TopicPartition("topic2", 0))) // Create the first segment which contains two batches. // The first batch has one transactional record @@ -1158,11 +1158,11 @@ class TransactionStateManagerTest { loadTransactionsForPartitions(partitionIds) expectLogConfig(partitionIds, ServerLogConfigs.MAX_MESSAGE_BYTES_DEFAULT) - txnMetadata1.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs - txnMetadata1.state = txnState + txnMetadata1.txnLastUpdateTimestamp(time.milliseconds() - txnConfig.transactionalIdExpirationMs) + txnMetadata1.state(txnState) transactionManager.putTransactionStateIfNotExists(txnMetadata1) - txnMetadata2.txnLastUpdateTimestamp = time.milliseconds() + txnMetadata2.txnLastUpdateTimestamp(time.milliseconds()) transactionManager.putTransactionStateIfNotExists(txnMetadata2) val appendedRecords = mutable.Map.empty[TopicIdPartition, mutable.Buffer[MemoryRecords]] @@ -1188,8 +1188,8 @@ class TransactionStateManagerTest { } private def verifyWritesTxnMarkersInPrepareState(state: TransactionState): Unit = { - txnMetadata1.state = state - txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + txnMetadata1.state(state) + txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) @@ -1222,7 +1222,7 @@ class TransactionStateManagerTest { txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = { val timestamp = time.milliseconds() new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, 0.toShort, - RecordBatch.NO_PRODUCER_EPOCH, txnTimeout, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp, TV_0) + RecordBatch.NO_PRODUCER_EPOCH, txnTimeout, state, util.Set.of, timestamp, timestamp, TV_0) } private def prepareTxnLog(topicPartition: TopicPartition, @@ -1294,8 +1294,8 @@ class TransactionStateManagerTest { assertEquals(Double.NaN, partitionLoadTime("partition-load-time-avg"), 0) assertTrue(reporter.containsMbean(mBeanName)) - txnMetadata1.state = TransactionState.ONGOING - txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 1), + txnMetadata1.state(TransactionState.ONGOING) + txnMetadata1.addPartitions(util.List.of(new TopicPartition("topic1", 1), new TopicPartition("topic1", 1))) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) @@ -1313,8 +1313,8 @@ class TransactionStateManagerTest { @Test def testIgnoreUnknownRecordType(): Unit = { - txnMetadata1.state = TransactionState.PREPARE_COMMIT - txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + txnMetadata1.state(TransactionState.PREPARE_COMMIT) + txnMetadata1.addPartitions(util.Set.of(new TopicPartition("topic1", 0), new TopicPartition("topic1", 1))) txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)) diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala index 5676e454887..a7a949df55d 100644 --- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala @@ -1892,7 +1892,7 @@ class KafkaApisTest extends Logging { ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(producerId), ArgumentMatchers.eq(epoch), - ArgumentMatchers.eq(Set(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, partition))), + ArgumentMatchers.eq(util.Set.of(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, partition))), responseCallback.capture(), ArgumentMatchers.eq(TransactionVersion.TV_0), ArgumentMatchers.eq(requestLocal) @@ -1951,7 +1951,7 @@ class KafkaApisTest extends Logging { ArgumentMatchers.eq(transactionalId), ArgumentMatchers.eq(producerId), ArgumentMatchers.eq(epoch), - ArgumentMatchers.eq(Set(topicPartition)), + ArgumentMatchers.eq(util.Set.of(topicPartition)), responseCallback.capture(), ArgumentMatchers.eq(TransactionVersion.TV_0), ArgumentMatchers.eq(requestLocal) @@ -2157,7 +2157,7 @@ class KafkaApisTest extends Logging { ArgumentMatchers.eq(transactionalId1), ArgumentMatchers.eq(producerId), ArgumentMatchers.eq(epoch), - ArgumentMatchers.eq(Set(tp0)), + ArgumentMatchers.eq(util.Set.of(tp0)), responseCallback.capture(), any[TransactionVersion], ArgumentMatchers.eq(requestLocal) @@ -2167,7 +2167,7 @@ class KafkaApisTest extends Logging { ArgumentMatchers.eq(transactionalId2), ArgumentMatchers.eq(producerId), ArgumentMatchers.eq(epoch), - ArgumentMatchers.eq(Set(tp1)), + ArgumentMatchers.eq(util.Set.of(tp1)), verifyPartitionsCallback.capture(), )).thenAnswer(_ => verifyPartitionsCallback.getValue.apply(AddPartitionsToTxnResponse.resultForTransaction(transactionalId2, util.Map.of(tp1, Errors.PRODUCER_FENCED)))) kafkaApis = createKafkaApis() diff --git a/gradle/spotbugs-exclude.xml b/gradle/spotbugs-exclude.xml index 22a08ebd051..48fc1f3722c 100644 --- a/gradle/spotbugs-exclude.xml +++ b/gradle/spotbugs-exclude.xml @@ -573,7 +573,7 @@ For a detailed description of spotbugs bug categories, see https://spotbugs.read - + @@ -606,7 +606,7 @@ For a detailed description of spotbugs bug categories, see https://spotbugs.read - + @@ -670,7 +670,7 @@ For a detailed description of spotbugs bug categories, see https://spotbugs.read - + diff --git a/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionMetadata.java b/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionMetadata.java new file mode 100644 index 00000000000..96a92dd01c9 --- /dev/null +++ b/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TransactionMetadata.java @@ -0,0 +1,662 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.coordinator.transaction; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.LogLevelConfig; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.server.common.TransactionVersion; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.MarkerFactory; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; + +public class TransactionMetadata { + private static final Logger LOGGER = LoggerFactory.getLogger(TransactionMetadata.class); + private final String transactionalId; + private long producerId; + private long prevProducerId; + private long nextProducerId; + private short producerEpoch; + private short lastProducerEpoch; + private int txnTimeoutMs; + private TransactionState state; + // The topicPartitions is mutable, so using HashSet, instead of Set. + private HashSet topicPartitions; + private volatile long txnStartTimestamp; + private volatile long txnLastUpdateTimestamp; + private TransactionVersion clientTransactionVersion; + + // pending state is used to indicate the state that this transaction is going to + // transit to, and for blocking future attempts to transit it again if it is not legal; + // initialized as the same as the current state + private Optional pendingState; + + // Indicates that during a previous attempt to fence a producer, the bumped epoch may not have been + // successfully written to the log. If this is true, we will not bump the epoch again when fencing + private boolean hasFailedEpochFence; + + private final ReentrantLock lock; + + public static boolean isEpochExhausted(short producerEpoch) { + return producerEpoch >= Short.MAX_VALUE - 1; + } + + /** + * @param transactionalId transactional id + * @param producerId producer id + * @param prevProducerId producer id for the last committed transaction with this transactional ID + * @param nextProducerId Latest producer ID sent to the producer for the given transactional ID + * @param producerEpoch current epoch of the producer + * @param lastProducerEpoch last epoch of the producer + * @param txnTimeoutMs timeout to be used to abort long running transactions + * @param state current state of the transaction + * @param topicPartitions current set of partitions that are part of this transaction + * @param txnStartTimestamp time the transaction was started, i.e., when first partition is added + * @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. To be used for expiration + * @param clientTransactionVersion TransactionVersion used by the client when the state was transitioned + */ + public TransactionMetadata(String transactionalId, + long producerId, + long prevProducerId, + long nextProducerId, + short producerEpoch, + short lastProducerEpoch, + int txnTimeoutMs, + TransactionState state, + Set topicPartitions, + long txnStartTimestamp, + long txnLastUpdateTimestamp, + TransactionVersion clientTransactionVersion) { + this.transactionalId = transactionalId; + this.producerId = producerId; + this.prevProducerId = prevProducerId; + this.nextProducerId = nextProducerId; + this.producerEpoch = producerEpoch; + this.lastProducerEpoch = lastProducerEpoch; + this.txnTimeoutMs = txnTimeoutMs; + this.state = state; + this.topicPartitions = new HashSet<>(topicPartitions); + this.txnStartTimestamp = txnStartTimestamp; + this.txnLastUpdateTimestamp = txnLastUpdateTimestamp; + this.clientTransactionVersion = clientTransactionVersion; + this.pendingState = Optional.empty(); + this.hasFailedEpochFence = false; + this.lock = new ReentrantLock(); + } + + public T inLock(Supplier function) { + lock.lock(); + try { + return function.get(); + } finally { + lock.unlock(); + } + } + + public void addPartitions(Collection partitions) { + topicPartitions.addAll(partitions); + } + + public void removePartition(TopicPartition topicPartition) { + if (state != TransactionState.PREPARE_COMMIT && state != TransactionState.PREPARE_ABORT) + throw new IllegalStateException("Transaction metadata's current state is " + state + ", and its pending state is " + + pendingState + " while trying to remove partitions whose txn marker has been sent, this is not expected"); + + topicPartitions.remove(topicPartition); + } + + // this is visible for test only + public TxnTransitMetadata prepareNoTransit() { + // do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state + return new TxnTransitMetadata(producerId, prevProducerId, nextProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, + state, new HashSet<>(topicPartitions), txnStartTimestamp, txnLastUpdateTimestamp, clientTransactionVersion); + } + + public TxnTransitMetadata prepareFenceProducerEpoch() { + if (producerEpoch == Short.MAX_VALUE) + throw new IllegalStateException("Cannot fence producer with epoch equal to Short.MaxValue since this would overflow"); + + // If we've already failed to fence an epoch (because the write to the log failed), we don't increase it again. + // This is safe because we never return the epoch to client if we fail to fence the epoch + short bumpedEpoch = hasFailedEpochFence ? producerEpoch : (short) (producerEpoch + 1); + + TransitionData data = new TransitionData(TransactionState.PREPARE_EPOCH_FENCE); + data.producerEpoch = bumpedEpoch; + return prepareTransitionTo(data); + } + + public TxnTransitMetadata prepareIncrementProducerEpoch( + int newTxnTimeoutMs, + Optional expectedProducerEpoch, + long updateTimestamp) { + if (isProducerEpochExhausted()) + throw new IllegalStateException("Cannot allocate any more producer epochs for producerId " + producerId); + + TransitionData data = new TransitionData(TransactionState.EMPTY); + short bumpedEpoch = (short) (producerEpoch + 1); + if (expectedProducerEpoch.isEmpty()) { + // If no expected epoch was provided by the producer, bump the current epoch and set the last epoch to -1 + // In the case of a new producer, producerEpoch will be -1 and bumpedEpoch will be 0 + data.producerEpoch = bumpedEpoch; + data.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH; + } else if (producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || expectedProducerEpoch.get() == producerEpoch) { + // If the expected epoch matches the current epoch, or if there is no current epoch, the producer is attempting + // to continue after an error and no other producer has been initialized. Bump the current and last epochs. + // The no current epoch case means this is a new producer; producerEpoch will be -1 and bumpedEpoch will be 0 + data.producerEpoch = bumpedEpoch; + data.lastProducerEpoch = producerEpoch; + } else if (expectedProducerEpoch.get() == lastProducerEpoch) { + // If the expected epoch matches the previous epoch, it is a retry of a successful call, so just return the + // current epoch without bumping. There is no danger of this producer being fenced, because a new producer + // calling InitProducerId would have caused the last epoch to be set to -1. + // Note that if the IBP is prior to 2.4.IV1, the lastProducerId and lastProducerEpoch will not be written to + // the transaction log, so a retry that spans a coordinator change will fail. We expect this to be a rare case. + data.producerEpoch = producerEpoch; + data.lastProducerEpoch = lastProducerEpoch; + } else { + // Otherwise, the producer has a fenced epoch and should receive an PRODUCER_FENCED error + LOGGER.info("Expected producer epoch {} does not match current producer epoch {} or previous producer epoch {}", + expectedProducerEpoch.get(), producerEpoch, lastProducerEpoch); + throw Errors.PRODUCER_FENCED.exception(); + } + + data.txnTimeoutMs = newTxnTimeoutMs; + data.topicPartitions = new HashSet<>(); + data.txnStartTimestamp = -1L; + data.txnLastUpdateTimestamp = updateTimestamp; + return prepareTransitionTo(data); + } + + public TxnTransitMetadata prepareProducerIdRotation(long newProducerId, + int newTxnTimeoutMs, + long updateTimestamp, + boolean recordLastEpoch) { + if (hasPendingTransaction()) + throw new IllegalStateException("Cannot rotate producer ids while a transaction is still pending"); + + TransitionData data = new TransitionData(TransactionState.EMPTY); + data.producerId = newProducerId; + data.producerEpoch = 0; + data.lastProducerEpoch = recordLastEpoch ? producerEpoch : RecordBatch.NO_PRODUCER_EPOCH; + data.txnTimeoutMs = newTxnTimeoutMs; + data.topicPartitions = new HashSet<>(); + data.txnStartTimestamp = -1L; + data.txnLastUpdateTimestamp = updateTimestamp; + return prepareTransitionTo(data); + } + + public TxnTransitMetadata prepareAddPartitions(Set addedTopicPartitions, + long updateTimestamp, + TransactionVersion clientTransactionVersion) { + long newTxnStartTimestamp; + if (state == TransactionState.EMPTY || state == TransactionState.COMPLETE_ABORT || state == TransactionState.COMPLETE_COMMIT) { + newTxnStartTimestamp = updateTimestamp; + } else { + newTxnStartTimestamp = txnStartTimestamp; + } + + HashSet newTopicPartitions = new HashSet<>(topicPartitions); + newTopicPartitions.addAll(addedTopicPartitions); + + TransitionData data = new TransitionData(TransactionState.ONGOING); + data.topicPartitions = newTopicPartitions; + data.txnStartTimestamp = newTxnStartTimestamp; + data.txnLastUpdateTimestamp = updateTimestamp; + data.clientTransactionVersion = clientTransactionVersion; + return prepareTransitionTo(data); + } + + public TxnTransitMetadata prepareAbortOrCommit(TransactionState newState, + TransactionVersion clientTransactionVersion, + long nextProducerId, + long updateTimestamp, + boolean noPartitionAdded) { + TransitionData data = new TransitionData(newState); + if (clientTransactionVersion.supportsEpochBump()) { + // We already ensured that we do not overflow here. MAX_SHORT is the highest possible value. + data.producerEpoch = (short) (producerEpoch + 1); + data.lastProducerEpoch = producerEpoch; + } else { + data.producerEpoch = producerEpoch; + data.lastProducerEpoch = lastProducerEpoch; + } + + // With transaction V2, it is allowed to abort the transaction without adding any partitions. Then, the transaction + // start time is uncertain but it is still required. So we can use the update time as the transaction start time. + data.txnStartTimestamp = noPartitionAdded ? updateTimestamp : txnStartTimestamp; + data.nextProducerId = nextProducerId; + data.txnLastUpdateTimestamp = updateTimestamp; + data.clientTransactionVersion = clientTransactionVersion; + return prepareTransitionTo(data); + } + + public TxnTransitMetadata prepareComplete(long updateTimestamp) { + // Since the state change was successfully written to the log, unset the flag for a failed epoch fence + hasFailedEpochFence = false; + + TransitionData data = new TransitionData(state == TransactionState.PREPARE_COMMIT ? + TransactionState.COMPLETE_COMMIT : TransactionState.COMPLETE_ABORT); + // In the prepareComplete transition for the overflow case, the lastProducerEpoch is kept at MAX-1, + // which is the last epoch visible to the client. + // Internally, however, during the transition between prepareAbort/prepareCommit and prepareComplete, the producer epoch + // reaches MAX but the client only sees the transition as MAX-1 followed by 0. + // When an epoch overflow occurs, we set the producerId to nextProducerId and reset the epoch to 0, + // but lastProducerEpoch remains MAX-1 to maintain consistency with what the client last saw. + if (clientTransactionVersion.supportsEpochBump() && nextProducerId != RecordBatch.NO_PRODUCER_ID) { + data.producerId = nextProducerId; + data.producerEpoch = 0; + } else { + data.producerId = producerId; + data.producerEpoch = producerEpoch; + } + data.nextProducerId = RecordBatch.NO_PRODUCER_ID; + data.topicPartitions = new HashSet<>(); + data.txnLastUpdateTimestamp = updateTimestamp; + return prepareTransitionTo(data); + } + + public TxnTransitMetadata prepareDead() { + TransitionData data = new TransitionData(TransactionState.DEAD); + data.topicPartitions = new HashSet<>(); + return prepareTransitionTo(data); + } + + /** + * Check if the epochs have been exhausted for the current producerId. We do not allow the client to use an + * epoch equal to Short.MaxValue to ensure that the coordinator will always be able to fence an existing producer. + */ + public boolean isProducerEpochExhausted() { + return isEpochExhausted(producerEpoch); + } + + /** + * Check if this is a distributed two phase commit transaction. + * Such transactions have no timeout (identified by maximum value for timeout). + */ + public boolean isDistributedTwoPhaseCommitTxn() { + return txnTimeoutMs == Integer.MAX_VALUE; + } + + private boolean hasPendingTransaction() { + return state == TransactionState.ONGOING || + state == TransactionState.PREPARE_ABORT || + state == TransactionState.PREPARE_COMMIT; + } + + private TxnTransitMetadata prepareTransitionTo(TransitionData data) { + if (pendingState.isPresent()) + throw new IllegalStateException("Preparing transaction state transition to " + state + + " while it already a pending state " + pendingState.get()); + + if (data.producerId < 0) + throw new IllegalArgumentException("Illegal new producer id " + data.producerId); + + // The epoch is initialized to NO_PRODUCER_EPOCH when the TransactionMetadata + // is created for the first time and it could stay like this until transitioning + // to Dead. + if (data.state != TransactionState.DEAD && data.producerEpoch < 0) + throw new IllegalArgumentException("Illegal new producer epoch " + data.producerEpoch); + + // check that the new state transition is valid and update the pending state if necessary + if (data.state.validPreviousStates().contains(this.state)) { + TxnTransitMetadata transitMetadata = new TxnTransitMetadata( + data.producerId, this.producerId, data.nextProducerId, data.producerEpoch, data.lastProducerEpoch, + data.txnTimeoutMs, data.state, data.topicPartitions, + data.txnStartTimestamp, data.txnLastUpdateTimestamp, data.clientTransactionVersion + ); + + LOGGER.debug("TransactionalId {} prepare transition from {} to {}", transactionalId, this.state, data.state); + pendingState = Optional.of(data.state); + return transitMetadata; + } + throw new IllegalStateException("Preparing transaction state transition to " + data.state + " failed since the target state " + + data.state + " is not a valid previous state of the current state " + this.state); + } + + @SuppressWarnings("CyclomaticComplexity") + public void completeTransitionTo(TxnTransitMetadata transitMetadata) { + // metadata transition is valid only if all the following conditions are met: + // + // 1. the new state is already indicated in the pending state. + // 2. the epoch should be either the same value, the old value + 1, or 0 if we have a new producerId. + // 3. the last update time is no smaller than the old value. + // 4. the old partitions set is a subset of the new partitions set. + // + // plus, we should only try to update the metadata after the corresponding log entry has been successfully + // written and replicated (see TransactionStateManager#appendTransactionToLog) + // + // if valid, transition is done via overwriting the whole object to ensure synchronization + + TransactionState toState = pendingState.orElseThrow(() -> { + LOGGER.error(MarkerFactory.getMarker(LogLevelConfig.FATAL_LOG_LEVEL), + "{}'s transition to {} failed since pendingState is not defined: this should not happen", this, transitMetadata); + return new IllegalStateException("TransactionalId " + transactionalId + + " completing transaction state transition while it does not have a pending state"); + }); + + if (!toState.equals(transitMetadata.txnState())) throwStateTransitionFailure(transitMetadata); + + switch (toState) { + case EMPTY: // from initPid + if ((producerEpoch != transitMetadata.producerEpoch() && !validProducerEpochBump(transitMetadata)) || + !transitMetadata.topicPartitions().isEmpty() || + transitMetadata.txnStartTimestamp() != -1) { + throwStateTransitionFailure(transitMetadata); + } + break; + + case ONGOING: // from addPartitions + if (!validProducerEpoch(transitMetadata) || + !transitMetadata.topicPartitions().containsAll(topicPartitions) || + txnTimeoutMs != transitMetadata.txnTimeoutMs()) { + throwStateTransitionFailure(transitMetadata); + } + break; + + case PREPARE_ABORT: // from endTxn + case PREPARE_COMMIT: + // In V2, we allow state transits from Empty, CompleteCommit and CompleteAbort to PrepareAbort. It is possible + // their updated start time is not equal to the current start time. + boolean allowedEmptyAbort = toState == TransactionState.PREPARE_ABORT && transitMetadata.clientTransactionVersion().supportsEpochBump() && + (state == TransactionState.EMPTY || state == TransactionState.COMPLETE_COMMIT || state == TransactionState.COMPLETE_ABORT); + boolean validTimestamp = txnStartTimestamp == transitMetadata.txnStartTimestamp() || allowedEmptyAbort; + + if (!validProducerEpoch(transitMetadata) || + !topicPartitions.equals(transitMetadata.topicPartitions()) || + txnTimeoutMs != transitMetadata.txnTimeoutMs() || + !validTimestamp) { + throwStateTransitionFailure(transitMetadata); + } + break; + + case COMPLETE_ABORT: // from write markers + case COMPLETE_COMMIT: + if (!validProducerEpoch(transitMetadata) || + txnTimeoutMs != transitMetadata.txnTimeoutMs() || + transitMetadata.txnStartTimestamp() == -1) { + throwStateTransitionFailure(transitMetadata); + } + break; + + case PREPARE_EPOCH_FENCE: + // We should never get here, since once we prepare to fence the epoch, we immediately set the pending state + // to PrepareAbort, and then consequently to CompleteAbort after the markers are written.. So we should never + // ever try to complete a transition to PrepareEpochFence, as it is not a valid previous state for any other state, and hence + // can never be transitioned out of. + throwStateTransitionFailure(transitMetadata); + break; + + case DEAD: + // The transactionalId was being expired. The completion of the operation should result in removal of the + // the metadata from the cache, so we should never realistically transition to the dead state. + throw new IllegalStateException("TransactionalId " + transactionalId + " is trying to complete a transition to " + + toState + ". This means that the transactionalId was being expired, and the only acceptable completion of " + + "this operation is to remove the transaction metadata from the cache, not to persist the " + toState + " in the log."); + + default: + break; + } + + LOGGER.debug("TransactionalId {} complete transition from {} to {}", transactionalId, state, transitMetadata); + producerId = transitMetadata.producerId(); + prevProducerId = transitMetadata.prevProducerId(); + nextProducerId = transitMetadata.nextProducerId(); + producerEpoch = transitMetadata.producerEpoch(); + lastProducerEpoch = transitMetadata.lastProducerEpoch(); + txnTimeoutMs = transitMetadata.txnTimeoutMs(); + topicPartitions = transitMetadata.topicPartitions(); + txnStartTimestamp = transitMetadata.txnStartTimestamp(); + txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp(); + clientTransactionVersion = transitMetadata.clientTransactionVersion(); + + pendingState = Optional.empty(); + state = toState; + } + + /** + * Validates the producer epoch and ID based on transaction state and version. + *

+ * Logic: + * * 1. **Overflow Case in Transactions V2:** + * * - During overflow (epoch reset to 0), we compare both `lastProducerEpoch` values since it + * * does not change during completion. + * * - For PrepareComplete, the producer ID has been updated. We ensure that the `prevProducerID` + * * in the transit metadata matches the current producer ID, confirming the change. + * * + * * 2. **Epoch Bump Case in Transactions V2:** + * * - For PrepareCommit or PrepareAbort, the producer epoch has been bumped. We ensure the `lastProducerEpoch` + * * in transit metadata matches the current producer epoch, confirming the bump. + * * - We also verify that the producer ID remains the same. + * * + * * 3. **Other Cases:** + * * - For other states and versions, check if the producer epoch and ID match the current values. + * + * @param transitMetadata The transaction transition metadata containing state, producer epoch, and ID. + * @return true if the producer epoch and ID are valid; false otherwise. + */ + private boolean validProducerEpoch(TxnTransitMetadata transitMetadata) { + boolean isAtLeastTransactionsV2 = transitMetadata.clientTransactionVersion().supportsEpochBump(); + TransactionState txnState = transitMetadata.txnState(); + short transitProducerEpoch = transitMetadata.producerEpoch(); + long transitProducerId = transitMetadata.producerId(); + short transitLastProducerEpoch = transitMetadata.lastProducerEpoch(); + + if (isAtLeastTransactionsV2 && + (txnState == TransactionState.COMPLETE_COMMIT || txnState == TransactionState.COMPLETE_ABORT) && + transitProducerEpoch == 0) { + return transitLastProducerEpoch == lastProducerEpoch && transitMetadata.prevProducerId() == producerId; + } + + if (isAtLeastTransactionsV2 && + (txnState == TransactionState.PREPARE_COMMIT || txnState == TransactionState.PREPARE_ABORT)) { + return transitLastProducerEpoch == producerEpoch && transitProducerId == producerId; + } + return transitProducerEpoch == producerEpoch && transitProducerId == producerId; + } + + private boolean validProducerEpochBump(TxnTransitMetadata transitMetadata) { + short transitEpoch = transitMetadata.producerEpoch(); + long transitProducerId = transitMetadata.producerId(); + return transitEpoch == (short) (producerEpoch + 1) || (transitEpoch == 0 && transitProducerId != producerId); + } + + private void throwStateTransitionFailure(TxnTransitMetadata txnTransitMetadata) { + LOGGER.error(MarkerFactory.getMarker(LogLevelConfig.FATAL_LOG_LEVEL), + "{}'s transition to {} failed: this should not happen", this, txnTransitMetadata); + + throw new IllegalStateException("TransactionalId " + transactionalId + " failed transition to state " + txnTransitMetadata + + " due to unexpected metadata"); + } + + public boolean pendingTransitionInProgress() { + return pendingState.isPresent(); + } + + public String transactionalId() { + return transactionalId; + } + + public void setProducerId(long producerId) { + this.producerId = producerId; + } + public long producerId() { + return producerId; + } + + public void setPrevProducerId(long prevProducerId) { + this.prevProducerId = prevProducerId; + } + public long prevProducerId() { + return prevProducerId; + } + + public void setProducerEpoch(short producerEpoch) { + this.producerEpoch = producerEpoch; + } + + public short producerEpoch() { + return producerEpoch; + } + + public void setLastProducerEpoch(short lastProducerEpoch) { + this.lastProducerEpoch = lastProducerEpoch; + } + + public short lastProducerEpoch() { + return lastProducerEpoch; + } + + public int txnTimeoutMs() { + return txnTimeoutMs; + } + + public void state(TransactionState state) { + this.state = state; + } + + public TransactionState state() { + return state; + } + + public Set topicPartitions() { + return topicPartitions; + } + + public long txnStartTimestamp() { + return txnStartTimestamp; + } + + public void txnLastUpdateTimestamp(long txnLastUpdateTimestamp) { + this.txnLastUpdateTimestamp = txnLastUpdateTimestamp; + } + + public long txnLastUpdateTimestamp() { + return txnLastUpdateTimestamp; + } + + public void clientTransactionVersion(TransactionVersion clientTransactionVersion) { + this.clientTransactionVersion = clientTransactionVersion; + } + + public TransactionVersion clientTransactionVersion() { + return clientTransactionVersion; + } + + public void pendingState(Optional pendingState) { + this.pendingState = pendingState; + } + + public Optional pendingState() { + return pendingState; + } + + public void hasFailedEpochFence(boolean hasFailedEpochFence) { + this.hasFailedEpochFence = hasFailedEpochFence; + } + + public boolean hasFailedEpochFence() { + return hasFailedEpochFence; + } + + @Override + public String toString() { + return "TransactionMetadata(" + + "transactionalId=" + transactionalId + + ", producerId=" + producerId + + ", prevProducerId=" + prevProducerId + + ", nextProducerId=" + nextProducerId + + ", producerEpoch=" + producerEpoch + + ", lastProducerEpoch=" + lastProducerEpoch + + ", txnTimeoutMs=" + txnTimeoutMs + + ", state=" + state + + ", pendingState=" + pendingState + + ", topicPartitions=" + topicPartitions + + ", txnStartTimestamp=" + txnStartTimestamp + + ", txnLastUpdateTimestamp=" + txnLastUpdateTimestamp + + ", clientTransactionVersion=" + clientTransactionVersion + + ")"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + + TransactionMetadata other = (TransactionMetadata) obj; + return transactionalId.equals(other.transactionalId) && + producerId == other.producerId && + prevProducerId == other.prevProducerId && + nextProducerId == other.nextProducerId && + producerEpoch == other.producerEpoch && + lastProducerEpoch == other.lastProducerEpoch && + txnTimeoutMs == other.txnTimeoutMs && + state.equals(other.state) && + topicPartitions.equals(other.topicPartitions) && + txnStartTimestamp == other.txnStartTimestamp && + txnLastUpdateTimestamp == other.txnLastUpdateTimestamp && + clientTransactionVersion.equals(other.clientTransactionVersion); + } + + @Override + public int hashCode() { + return Objects.hash( + transactionalId, + producerId, + prevProducerId, + nextProducerId, + producerEpoch, + lastProducerEpoch, + txnTimeoutMs, + state, + topicPartitions, + txnStartTimestamp, + txnLastUpdateTimestamp, + clientTransactionVersion + ); + } + + /** + * This class is used to hold the data that is needed to transition the transaction metadata to a new state. + * The data is copied from the current transaction metadata to avoid a lot of duplicated code in the prepare methods. + */ + private class TransitionData { + final TransactionState state; + long producerId = TransactionMetadata.this.producerId; + long nextProducerId = TransactionMetadata.this.nextProducerId; + short producerEpoch = TransactionMetadata.this.producerEpoch; + short lastProducerEpoch = TransactionMetadata.this.lastProducerEpoch; + int txnTimeoutMs = TransactionMetadata.this.txnTimeoutMs; + HashSet topicPartitions = TransactionMetadata.this.topicPartitions; + long txnStartTimestamp = TransactionMetadata.this.txnStartTimestamp; + long txnLastUpdateTimestamp = TransactionMetadata.this.txnLastUpdateTimestamp; + TransactionVersion clientTransactionVersion = TransactionMetadata.this.clientTransactionVersion; + + private TransitionData(TransactionState state) { + this.state = state; + } + } +} \ No newline at end of file diff --git a/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TxnTransitMetadata.java b/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TxnTransitMetadata.java index 59a05227d7d..452c168687e 100644 --- a/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TxnTransitMetadata.java +++ b/transaction-coordinator/src/main/java/org/apache/kafka/coordinator/transaction/TxnTransitMetadata.java @@ -19,7 +19,7 @@ package org.apache.kafka.coordinator.transaction; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.server.common.TransactionVersion; -import java.util.Set; +import java.util.HashSet; /** * Immutable object representing the target transition of the transaction metadata @@ -32,7 +32,9 @@ public record TxnTransitMetadata( short lastProducerEpoch, int txnTimeoutMs, TransactionState txnState, - Set topicPartitions, + // The TransactionMetadata#topicPartitions field is mutable. + // To avoid deep copy when assigning value from TxnTransitMetadata to TransactionMetadata, use HashSet here. + HashSet topicPartitions, long txnStartTimestamp, long txnLastUpdateTimestamp, TransactionVersion clientTransactionVersion