KAFKA-14884: Include check transaction is still ongoing right before append (take 2) (#13787)

Introduced an extra mapping to track verification state.

When verifying, there is a race condition: the add-partitions verification response can report that the partition is part of the ongoing transaction, but an abort marker may be written before we get to the append. Therefore, we track any given transaction we are verifying with an object unique to that transaction.

We check this unique state upon the first append to the log. After that, we can rely on currentTxnFirstOffset. We remove the verification state when a transactional data record or marker is appended to the log.

We will also clean up lingering verification state entries via the producer state entry expiration mechanism. We do not update the timestamp when retrying verification for a transaction, so each entry must be verified within producer.id.expiration.ms.
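
To illustrate the guard lifecycle described above, here is a minimal, self-contained Java sketch. The class and method names (VerificationGuardSketch, appendTransactional, completeTransaction) are hypothetical; it mirrors the idea only, not the broker's actual ProducerStateManager/UnifiedLog code.

import java.util.HashMap;
import java.util.Map;

public class VerificationGuardSketch {
    // producerId -> unique guard object created when verification starts
    private final Map<Long, Object> verificationGuards = new HashMap<>();
    // producerId -> first offset of the ongoing transaction (phase 2 state)
    private final Map<Long, Long> ongoingTxnFirstOffset = new HashMap<>();

    // Produce path: called before asking the transaction coordinator to verify.
    public synchronized Object maybeStartTransactionVerification(long producerId) {
        if (ongoingTxnFirstOffset.containsKey(producerId))
            return null; // transaction already ongoing, nothing to verify
        return verificationGuards.computeIfAbsent(producerId, id -> new Object());
    }

    // Append path: a first transactional append must present the exact guard
    // handed out before verification, otherwise it is rejected.
    public synchronized void appendTransactional(long producerId, long offset, Object guard) {
        boolean ongoing = ongoingTxnFirstOffset.containsKey(producerId);
        Object expected = verificationGuards.get(producerId);
        if (!ongoing && (guard == null || guard != expected))
            throw new IllegalStateException("Record was not part of an ongoing transaction");
        ongoingTxnFirstOffset.putIfAbsent(producerId, offset);
        // Once the transaction has a first offset, the guard is no longer needed.
        verificationGuards.remove(producerId);
    }

    // A commit/abort marker ends the transaction and clears both pieces of state,
    // so a stale guard captured before the marker cannot authorize a later append.
    public synchronized void completeTransaction(long producerId) {
        ongoingTxnFirstOffset.remove(producerId);
        verificationGuards.remove(producerId);
    }

    public static void main(String[] args) {
        VerificationGuardSketch log = new VerificationGuardSketch();
        Object guard = log.maybeStartTransactionVerification(1L);
        log.completeTransaction(1L); // abort marker lands before the delayed append arrives
        try {
            log.appendTransactional(1L, 0L, guard); // stale guard from before the abort
        } catch (IllegalStateException e) {
            System.out.println("Stale append rejected: " + e.getMessage());
        }
    }
}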

There were a few other fixes:
- Moved the transaction manager's handling of a failed batch into the block where the batch's future is completed exceptionally, to avoid processing the batch twice (this caused issues in unit tests); see the sketch after this list.
- Handle interrupted exceptions encountered on the callback thread.
- Throw an error if we try to set verification state and leaderLogIfLocal is None.
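
As a condensed sketch of the first fix, this is roughly how the reordered failBatch reads after the change. It is based on the Sender hunk further down and is not the verbatim method; the fields sensors, transactionManager, log and the helper maybeRemoveAndDeallocateBatch are the ones visible in that hunk.

private void failBatch(
    ProducerBatch batch,
    RuntimeException topLevelException,
    Function<Integer, RuntimeException> recordExceptions,
    boolean adjustSequenceNumbers
) {
    this.sensors.recordErrors(batch.topicPartition.topic(), batch.recordCount);

    // completeExceptionally returns true only for the call that actually completes the
    // batch's future, so the transaction manager sees each failed batch at most once.
    if (batch.completeExceptionally(topLevelException, recordExceptions)) {
        if (transactionManager != null) {
            try {
                // This call can throw on a rare invalid state transition; catch it so it
                // does not interfere with the rest of the cleanup.
                transactionManager.handleFailedBatch(batch, topLevelException, adjustSequenceNumbers);
            } catch (Exception e) {
                log.debug("Encountered error when transaction manager was handling a failed batch", e);
            }
        }
        maybeRemoveAndDeallocateBatch(batch);
    }
}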

Reviewers: David Jacot <djacot@confluent.io>, Artem Livshits <alivshits@confluent.io>, Jason Gustafson <jason@confluent.io>
Author: Justine Olshan, committed 2023-07-14 15:18:11 -07:00 via GitHub
Commit: ea0bb00126 (parent d9253fed5c)
20 changed files with 722 additions and 157 deletions


@ -793,19 +793,18 @@ public class Sender implements Runnable {
Function<Integer, RuntimeException> recordExceptions,
boolean adjustSequenceNumbers
) {
this.sensors.recordErrors(batch.topicPartition.topic(), batch.recordCount);
if (batch.completeExceptionally(topLevelException, recordExceptions)) {
if (transactionManager != null) {
try {
// This call can throw an exception in the rare case that there's an invalid state transition
// attempted. Catch these so as not to interfere with the rest of the logic.
transactionManager.handleFailedBatch(batch, topLevelException, adjustSequenceNumbers);
} catch (Exception e) {
log.debug("Encountered error when handling a failed batch", e);
log.debug("Encountered error when transaction manager was handling a failed batch", e);
}
}
this.sensors.recordErrors(batch.topicPartition.topic(), batch.recordCount);
if (batch.completeExceptionally(topLevelException, recordExceptions)) {
maybeRemoveAndDeallocateBatch(batch);
}
}


@ -1030,6 +1030,11 @@ public class TransactionManager {
return isTransactional() && currentState == State.READY;
}
// visible for testing
synchronized boolean isInitializing() {
return isTransactional() && currentState == State.INITIALIZING;
}
void handleCoordinatorReady() {
NodeApiVersions nodeApiVersions = transactionCoordinator != null ?
apiVersions.get(transactionCoordinator.idString()) :


@ -189,6 +189,10 @@ public class MockClient implements KafkaClient {
@Override
public void disconnect(String node) {
disconnect(node, false);
}
public void disconnect(String node, boolean allowLateResponses) {
long now = time.milliseconds();
Iterator<ClientRequest> iter = requests.iterator();
while (iter.hasNext()) {
@ -197,6 +201,7 @@ public class MockClient implements KafkaClient {
short version = request.requestBuilder().latestAllowedVersion();
responses.add(new ClientResponse(request.makeHeader(version), request.callback(), request.destination(),
request.createdTimeMs(), now, true, null, null, null));
if (!allowLateResponses)
iter.remove();
}
}


@ -3057,6 +3057,56 @@ public class SenderTest {
assertEquals(RETRY_BACKOFF_MS, time.milliseconds() - request2);
}
@Test
public void testReceiveFailedBatchTwiceWithTransactions() throws Exception {
ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0);
apiVersions.update("0", NodeApiVersions.create(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3));
TransactionManager txnManager = new TransactionManager(logContext, "testFailTwice", 60000, 100, apiVersions);
setupWithTransactionState(txnManager);
doInitTransactions(txnManager, producerIdAndEpoch);
txnManager.beginTransaction();
txnManager.maybeAddPartition(tp0);
client.prepareResponse(buildAddPartitionsToTxnResponseData(0, Collections.singletonMap(tp0, Errors.NONE)));
sender.runOnce();
// Send first ProduceRequest
Future<RecordMetadata> request1 = appendToAccumulator(tp0);
sender.runOnce(); // send request
Node node = metadata.fetch().nodes().get(0);
time.sleep(2000L);
client.disconnect(node.idString(), true);
client.backoff(node, 10);
sender.runOnce(); // now expire the batch.
assertFutureFailure(request1, TimeoutException.class);
time.sleep(20);
sendIdempotentProducerResponse(0, tp0, Errors.INVALID_RECORD, -1);
sender.runOnce(); // receive late response
// Loop once and confirm that the transaction manager does not enter a fatal error state
sender.runOnce();
assertTrue(txnManager.hasAbortableError());
TransactionalRequestResult result = txnManager.beginAbort();
sender.runOnce();
respondToEndTxn(Errors.NONE);
sender.runOnce();
assertTrue(txnManager::isInitializing);
prepareInitProducerResponse(Errors.NONE, producerIdAndEpoch.producerId, producerIdAndEpoch.epoch);
sender.runOnce();
assertTrue(txnManager::isReady);
assertTrue(result.isSuccessful());
result.await();
txnManager.beginTransaction();
}
private void verifyErrorMessage(ProduceResponse response, String expectedMessage) throws Exception {
Future<RecordMetadata> future = appendToAccumulator(tp0, 0L, "key", "value");
sender.runOnce(); // connect


@ -576,8 +576,12 @@ class Partition(val topicPartition: TopicPartition,
}
}
def hasOngoingTransaction(producerId: Long): Boolean = {
leaderLogIfLocal.exists(leaderLog => leaderLog.hasOngoingTransaction(producerId))
// Returns a verification guard object if we need to verify. This starts or continues the verification process. Otherwise return null.
def maybeStartTransactionVerification(producerId: Long): Object = {
leaderLogIfLocal match {
case Some(log) => log.maybeStartTransactionVerification(producerId)
case None => throw new NotLeaderOrFollowerException();
}
}
// Return true if the future replica exists and it has caught up with the current replica for this partition
@ -1279,7 +1283,7 @@ class Partition(val topicPartition: TopicPartition,
}
def appendRecordsToLeader(records: MemoryRecords, origin: AppendOrigin, requiredAcks: Int,
requestLocal: RequestLocal): LogAppendInfo = {
requestLocal: RequestLocal, verificationGuard: Object = null): LogAppendInfo = {
val (info, leaderHWIncremented) = inReadLock(leaderIsrUpdateLock) {
leaderLogIfLocal match {
case Some(leaderLog) =>
@ -1293,7 +1297,7 @@ class Partition(val topicPartition: TopicPartition,
}
val info = leaderLog.appendAsLeader(records, leaderEpoch = this.leaderEpoch, origin,
interBrokerProtocolVersion, requestLocal)
interBrokerProtocolVersion, requestLocal, verificationGuard)
// we may need to increment high watermark since ISR could be down to 1
(info, maybeIncrementLeaderHW(leaderLog))


@ -577,6 +577,28 @@ class UnifiedLog(@volatile var logStartOffset: Long,
result
}
/**
* Maybe create and return the verification guard object for the given producer ID if the transaction is not yet ongoing.
* Creation starts the verification process. Otherwise return null.
*/
def maybeStartTransactionVerification(producerId: Long): Object = lock synchronized {
if (hasOngoingTransaction(producerId))
null
else
getOrMaybeCreateVerificationGuard(producerId, true)
}
/**
* Maybe create the VerificationStateEntry for the given producer ID -- if an entry is present, return its verification guard, otherwise, return null.
*/
def getOrMaybeCreateVerificationGuard(producerId: Long, createIfAbsent: Boolean = false): Object = lock synchronized {
val entry = producerStateManager.verificationStateEntry(producerId, createIfAbsent)
if (entry != null) entry.verificationGuard else null
}
/**
* Return true if the given producer ID has a transaction ongoing.
*/
def hasOngoingTransaction(producerId: Long): Boolean = lock synchronized {
val entry = producerStateManager.activeProducers.get(producerId)
entry != null && entry.currentTxnFirstOffset.isPresent
@ -662,9 +684,10 @@ class UnifiedLog(@volatile var logStartOffset: Long,
leaderEpoch: Int,
origin: AppendOrigin = AppendOrigin.CLIENT,
interBrokerProtocolVersion: MetadataVersion = MetadataVersion.latest,
requestLocal: RequestLocal = RequestLocal.NoCaching): LogAppendInfo = {
requestLocal: RequestLocal = RequestLocal.NoCaching,
verificationGuard: Object = null): LogAppendInfo = {
val validateAndAssignOffsets = origin != AppendOrigin.RAFT_LEADER
append(records, origin, interBrokerProtocolVersion, validateAndAssignOffsets, leaderEpoch, Some(requestLocal), ignoreRecordSize = false)
append(records, origin, interBrokerProtocolVersion, validateAndAssignOffsets, leaderEpoch, Some(requestLocal), verificationGuard, ignoreRecordSize = false)
}
/**
@ -681,6 +704,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
validateAndAssignOffsets = false,
leaderEpoch = -1,
requestLocal = None,
verificationGuard = null,
// disable to check the validation of record size since the record is already accepted by leader.
ignoreRecordSize = true)
}
@ -709,6 +733,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
validateAndAssignOffsets: Boolean,
leaderEpoch: Int,
requestLocal: Option[RequestLocal],
verificationGuard: Object,
ignoreRecordSize: Boolean): LogAppendInfo = {
// We want to ensure the partition metadata file is written to the log dir before any log data is written to disk.
// This will ensure that any log data can be recovered with the correct topic ID in the case of failure.
@ -833,7 +858,7 @@ class UnifiedLog(@volatile var logStartOffset: Long,
// now that we have valid records, offsets assigned, and timestamps updated, we need to
// validate the idempotent/transactional state of the producers and collect some metadata
val (updatedProducers, completedTxns, maybeDuplicate) = analyzeAndValidateProducerState(
logOffsetMetadata, validRecords, origin)
logOffsetMetadata, validRecords, origin, verificationGuard)
maybeDuplicate match {
case Some(duplicate) =>
@ -961,7 +986,8 @@ class UnifiedLog(@volatile var logStartOffset: Long,
private def analyzeAndValidateProducerState(appendOffsetMetadata: LogOffsetMetadata,
records: MemoryRecords,
origin: AppendOrigin):
origin: AppendOrigin,
requestVerificationGuard: Object):
(mutable.Map[Long, ProducerAppendInfo], List[CompletedTxn], Option[BatchMetadata]) = {
val updatedProducers = mutable.Map.empty[Long, ProducerAppendInfo]
val completedTxns = ListBuffer.empty[CompletedTxn]
@ -978,6 +1004,25 @@ class UnifiedLog(@volatile var logStartOffset: Long,
if (duplicateBatch.isPresent) {
return (updatedProducers, completedTxns.toList, Some(duplicateBatch.get()))
}
// Verify that if the record is transactional & the append origin is client, that we either have an ongoing transaction or verified transaction state.
// This guarantees that transactional records are never written to the log outside of the transaction coordinator's knowledge of an open transaction on
// the partition. If we do not have an ongoing transaction or correct guard, return an error and do not append.
// There are two phases -- the first append to the log and subsequent appends.
//
// 1. First append: Verification starts with creating a verification guard object, sending a verification request to the transaction coordinator, and
// given a "verified" response, continuing the append path. (A non-verified response throws an error.) We create the unique verification guard for the transaction
// to ensure there is no race between the transaction coordinator response and an abort marker getting written to the log. We need a unique guard because we could
// have a sequence of events where we start a transaction verification, have the transaction coordinator send a verified response, write an abort marker,
// start a new transaction not aware of the partition, and receive the stale verification (ABA problem). With a unique verification guard object, this sequence would not
// result in appending to the log and would return an error. The guard is removed after the first append to the transaction and from then, we can rely on phase 2.
//
// 2. Subsequent appends: Once we write to the transaction, the in-memory state currentTxnFirstOffset is populated. This field remains until the
// transaction is completed or aborted. We can guarantee the transaction coordinator knows about the transaction given step 1 and that the transaction is still
// ongoing. If the transaction is expected to be ongoing, we will not set a verification guard. If the transaction is aborted, hasOngoingTransaction is false and
// requestVerificationGuard is null, so we will throw an error. A subsequent produce request (retry) should create verification state and return to phase 1.
if (batch.isTransactional && !hasOngoingTransaction(batch.producerId) && batchMissingRequiredVerification(batch, requestVerificationGuard))
throw new InvalidRecordException("Record was not part of an ongoing transaction")
}
// We cache offset metadata for the start of each transaction. This allows us to
@ -996,6 +1041,10 @@ class UnifiedLog(@volatile var logStartOffset: Long,
(updatedProducers, completedTxns.toList, None)
}
private def batchMissingRequiredVerification(batch: MutableRecordBatch, requestVerificationGuard: Object): Boolean = {
producerStateManager.producerStateManagerConfig().transactionVerificationEnabled() && (requestVerificationGuard != getOrMaybeCreateVerificationGuard(batch.producerId) || requestVerificationGuard == null)
}
/**
* Validate the following:
* <ol>
@ -1872,7 +1921,11 @@ object UnifiedLog extends Logging {
origin: AppendOrigin): Option[CompletedTxn] = {
val producerId = batch.producerId
val appendInfo = producers.getOrElseUpdate(producerId, producerStateManager.prepareUpdate(producerId, origin))
appendInfo.append(batch, firstOffsetMetadata.asJava).asScala
val completedTxn = appendInfo.append(batch, firstOffsetMetadata.asJava).asScala
// Whether we wrote a control marker or a data batch, we can remove verification guard since either the transaction is complete or we have a first offset.
if (batch.isTransactional)
producerStateManager.clearVerificationStateEntry(producerId)
completedTxn
}
/**


@ -706,26 +706,16 @@ class ReplicaManager(val config: KafkaConfig,
if (isValidRequiredAcks(requiredAcks)) {
val sTime = time.milliseconds
val transactionalProducerIds = mutable.HashSet[Long]()
val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition) =
val verificationGuards: mutable.Map[TopicPartition, Object] = mutable.Map[TopicPartition, Object]()
val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition, errorsPerPartition) =
if (transactionStatePartition.isEmpty || !config.transactionPartitionVerificationEnable)
(entriesPerPartition, Map.empty)
(entriesPerPartition, Map.empty[TopicPartition, MemoryRecords], Map.empty[TopicPartition, Errors])
else {
entriesPerPartition.partition { case (topicPartition, records) =>
// Produce requests (only requests that require verification) should only have one batch per partition in "batches" but check all just to be safe.
val transactionalBatches = records.batches.asScala.filter(batch => batch.hasProducerId && batch.isTransactional)
transactionalBatches.foreach(batch => transactionalProducerIds.add(batch.producerId))
if (transactionalBatches.nonEmpty) {
getPartitionOrException(topicPartition).hasOngoingTransaction(transactionalBatches.head.producerId)
} else {
// If there is no producer ID in the batches, no need to verify.
true
}
}
}
// We should have exactly one producer ID for transactional records
if (transactionalProducerIds.size > 1) {
throw new InvalidPidMappingException("Transactional records contained more than one producer ID")
val verifiedEntries = mutable.Map[TopicPartition, MemoryRecords]()
val unverifiedEntries = mutable.Map[TopicPartition, MemoryRecords]()
val errorEntries = mutable.Map[TopicPartition, Errors]()
partitionEntriesForVerification(verificationGuards, entriesPerPartition, verifiedEntries, unverifiedEntries, errorEntries)
(verifiedEntries.toMap, unverifiedEntries.toMap, errorEntries.toMap)
}
def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = {
@ -738,7 +728,7 @@ class ReplicaManager(val config: KafkaConfig,
}
val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
origin, verifiedEntries, requiredAcks, requestLocal)
origin, verifiedEntries, requiredAcks, requestLocal, verificationGuards.toMap)
debug("Produce to local log in %d ms".format(time.milliseconds - sTime))
val unverifiedResults = unverifiedEntries.map { case (topicPartition, error) =>
@ -750,7 +740,14 @@ class ReplicaManager(val config: KafkaConfig,
)
}
val allResults = localProduceResults ++ unverifiedResults
val errorResults = errorsPerPartition.map { case (topicPartition, error) =>
topicPartition -> LogAppendResult(
LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
Some(error.exception())
)
}
val allResults = localProduceResults ++ unverifiedResults ++ errorResults
val produceStatus = allResults.map { case (topicPartition, result) =>
topicPartition -> ProducePartitionStatus(
@ -851,6 +848,40 @@ class ReplicaManager(val config: KafkaConfig,
}
}
private def partitionEntriesForVerification(verificationGuards: mutable.Map[TopicPartition, Object],
entriesPerPartition: Map[TopicPartition, MemoryRecords],
verifiedEntries: mutable.Map[TopicPartition, MemoryRecords],
unverifiedEntries: mutable.Map[TopicPartition, MemoryRecords],
errorEntries: mutable.Map[TopicPartition, Errors]): Unit = {
val transactionalProducerIds = mutable.HashSet[Long]()
entriesPerPartition.foreach { case (topicPartition, records) =>
try {
// Produce requests (only requests that require verification) should only have one batch per partition in "batches" but check all just to be safe.
val transactionalBatches = records.batches.asScala.filter(batch => batch.hasProducerId && batch.isTransactional)
transactionalBatches.foreach(batch => transactionalProducerIds.add(batch.producerId))
if (transactionalBatches.nonEmpty) {
// We return verification guard if the partition needs to be verified. If no state is present, no need to verify.
val verificationGuard = getPartitionOrException(topicPartition).maybeStartTransactionVerification(records.firstBatch.producerId)
if (verificationGuard != null) {
verificationGuards.put(topicPartition, verificationGuard)
unverifiedEntries.put(topicPartition, records)
} else
verifiedEntries.put(topicPartition, records)
} else {
// If there is no producer ID or transactional records in the batches, no need to verify.
verifiedEntries.put(topicPartition, records)
}
} catch {
case e: Exception => errorEntries.put(topicPartition, Errors.forException(e))
}
}
// We should have exactly one producer ID for transactional records
if (transactionalProducerIds.size > 1) {
throw new InvalidPidMappingException("Transactional records contained more than one producer ID")
}
}
/**
* Delete records on leader replicas of the partition, and wait for delete records operation be propagated to other replicas;
* the callback function will be triggered either when timeout or logStartOffset of all live replicas have reached the specified offset
@ -1107,7 +1138,8 @@ class ReplicaManager(val config: KafkaConfig,
origin: AppendOrigin,
entriesPerPartition: Map[TopicPartition, MemoryRecords],
requiredAcks: Short,
requestLocal: RequestLocal): Map[TopicPartition, LogAppendResult] = {
requestLocal: RequestLocal,
verificationGuards: Map[TopicPartition, Object]): Map[TopicPartition, LogAppendResult] = {
val traceEnabled = isTraceEnabled
def processFailedRecord(topicPartition: TopicPartition, t: Throwable) = {
val logStartOffset = onlinePartition(topicPartition).map(_.logStartOffset).getOrElse(-1L)
@ -1133,7 +1165,7 @@ class ReplicaManager(val config: KafkaConfig,
} else {
try {
val partition = getPartitionOrException(topicPartition)
val info = partition.appendRecordsToLeader(records, origin, requiredAcks, requestLocal)
val info = partition.appendRecordsToLeader(records, origin, requiredAcks, requestLocal, verificationGuards.getOrElse(topicPartition, null))
val numAppendedMessages = info.numMessages
// update stats for successfully appended bytes and messages as bytesInRate and messageInRate


@ -74,7 +74,7 @@ class AbstractPartitionTest {
logDir1 = TestUtils.randomPartitionLogDir(tmpDir)
logDir2 = TestUtils.randomPartitionLogDir(tmpDir)
logManager = TestUtils.createLogManager(Seq(logDir1, logDir2), logConfig, configRepository,
new CleanerConfig(false), time, interBrokerProtocolVersion)
new CleanerConfig(false), time, interBrokerProtocolVersion, transactionVerificationEnabled = true)
logManager.startup(Set.empty)
alterPartitionManager = TestUtils.createAlterIsrManager()


@ -449,8 +449,8 @@ class PartitionLockTest extends Logging {
keepPartitionMetadataFile = true) {
override def appendAsLeader(records: MemoryRecords, leaderEpoch: Int, origin: AppendOrigin,
interBrokerProtocolVersion: MetadataVersion, requestLocal: RequestLocal): LogAppendInfo = {
val appendInfo = super.appendAsLeader(records, leaderEpoch, origin, interBrokerProtocolVersion, requestLocal)
interBrokerProtocolVersion: MetadataVersion, requestLocal: RequestLocal, verificationGuard: Object): LogAppendInfo = {
val appendInfo = super.appendAsLeader(records, leaderEpoch, origin, interBrokerProtocolVersion, requestLocal, verificationGuard)
appendSemaphore.acquire()
appendInfo
}


@ -32,7 +32,7 @@ import org.apache.kafka.common.record.FileRecords.TimestampAndOffset
import org.apache.kafka.common.record._
import org.apache.kafka.common.requests.{AlterPartitionResponse, FetchRequest, ListOffsetsRequest, RequestHeader}
import org.apache.kafka.common.utils.SystemTime
import org.apache.kafka.common.{IsolationLevel, TopicPartition, Uuid}
import org.apache.kafka.common.{InvalidRecordException, IsolationLevel, TopicPartition, Uuid}
import org.apache.kafka.metadata.LeaderRecoveryState
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test
@ -996,8 +996,10 @@ class PartitionTest extends AbstractPartitionTest {
new SimpleRecord("k1".getBytes, "v1".getBytes),
new SimpleRecord("k2".getBytes, "v2".getBytes),
new SimpleRecord("k3".getBytes, "v3".getBytes)),
baseOffset = 0L)
partition.appendRecordsToLeader(records, origin = AppendOrigin.CLIENT, requiredAcks = 0, RequestLocal.withThreadConfinedCaching)
baseOffset = 0L,
producerId = 2L)
val verificationGuard = partition.maybeStartTransactionVerification(2L)
partition.appendRecordsToLeader(records, origin = AppendOrigin.CLIENT, requiredAcks = 0, RequestLocal.withThreadConfinedCaching, verificationGuard)
def fetchOffset(isolationLevel: Option[IsolationLevel], timestamp: Long): TimestampAndOffset = {
val res = partition.fetchOffsetForTimestamp(timestamp,
@ -3349,7 +3351,7 @@ class PartitionTest extends AbstractPartitionTest {
}
@Test
def testHasOngoingTransaction(): Unit = {
def testMaybeStartTransactionVerification(): Unit = {
val controllerEpoch = 0
val leaderEpoch = 5
val replicas = List[Integer](brokerId, brokerId + 1).asJava
@ -3367,7 +3369,6 @@ class PartitionTest extends AbstractPartitionTest {
.setReplicas(replicas)
.setIsNew(true), offsetCheckpoints, None), "Expected become leader transition to succeed")
assertEquals(leaderEpoch, partition.getLeaderEpoch)
assertFalse(partition.hasOngoingTransaction(producerId))
val idempotentRecords = createIdempotentRecords(List(
new SimpleRecord("k1".getBytes, "v1".getBytes),
@ -3376,17 +3377,35 @@ class PartitionTest extends AbstractPartitionTest {
baseOffset = 0L,
producerId = producerId)
partition.appendRecordsToLeader(idempotentRecords, origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching)
assertFalse(partition.hasOngoingTransaction(producerId))
val transactionRecords = createTransactionalRecords(List(
def transactionRecords() = createTransactionalRecords(List(
new SimpleRecord("k1".getBytes, "v1".getBytes),
new SimpleRecord("k2".getBytes, "v2".getBytes),
new SimpleRecord("k3".getBytes, "v3".getBytes)),
baseOffset = 0L,
baseSequence = 3,
producerId = producerId)
partition.appendRecordsToLeader(transactionRecords, origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching)
assertTrue(partition.hasOngoingTransaction(producerId))
// When verification guard is not there, we should not be able to append.
assertThrows(classOf[InvalidRecordException], () => partition.appendRecordsToLeader(transactionRecords(), origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching))
// Before appendRecordsToLeader is called, ReplicaManager will call maybeStartTransactionVerification. We should get a non-null verification object.
val verificationGuard = partition.maybeStartTransactionVerification(producerId)
assertNotNull(verificationGuard)
// With the wrong verification guard, append should fail.
assertThrows(classOf[InvalidRecordException], () => partition.appendRecordsToLeader(transactionRecords(),
origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching, Optional.of(new Object)))
// We should return the same verification object when we still need to verify. Append should proceed.
val verificationGuard2 = partition.maybeStartTransactionVerification(producerId)
assertEquals(verificationGuard, verificationGuard2)
partition.appendRecordsToLeader(transactionRecords(), origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching, verificationGuard)
// We should no longer need a verification object. Future appends without verification guard will also succeed.
val verificationGuard3 = partition.maybeStartTransactionVerification(producerId)
assertNull(verificationGuard3)
partition.appendRecordsToLeader(transactionRecords(), origin = AppendOrigin.CLIENT, requiredAcks = 1, RequestLocal.withThreadConfinedCaching)
}
private def makeLeader(


@ -1692,7 +1692,7 @@ class GroupMetadataManagerTest {
when(partition.appendRecordsToLeader(any[MemoryRecords],
origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
groupMetadataManager.cleanupGroupMetadata()
assertEquals(Some(group), groupMetadataManager.getGroup(groupId))
@ -1740,7 +1740,7 @@ class GroupMetadataManagerTest {
mockGetPartition()
when(partition.appendRecordsToLeader(recordsCapture.capture(),
origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
groupMetadataManager.cleanupGroupMetadata()
val records = recordsCapture.getValue.records.asScala.toList
@ -1783,7 +1783,7 @@ class GroupMetadataManagerTest {
mockGetPartition()
when(partition.appendRecordsToLeader(recordsCapture.capture(),
origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
groupMetadataManager.cleanupGroupMetadata()
val records = recordsCapture.getValue.records.asScala.toList
@ -1851,7 +1851,7 @@ class GroupMetadataManagerTest {
when(partition.appendRecordsToLeader(recordsCapture.capture(),
origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
groupMetadataManager.cleanupGroupMetadata()
// verify the tombstones are correct and only for the expired offsets
@ -1959,7 +1959,7 @@ class GroupMetadataManagerTest {
// expect the offset tombstone
when(partition.appendRecordsToLeader(any[MemoryRecords],
origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
groupMetadataManager.cleanupGroupMetadata()
// group is empty now, only one offset should expire
@ -1984,7 +1984,7 @@ class GroupMetadataManagerTest {
// expect the offset tombstone
when(partition.appendRecordsToLeader(any[MemoryRecords],
origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
groupMetadataManager.cleanupGroupMetadata()
// one more offset should expire
@ -2041,7 +2041,7 @@ class GroupMetadataManagerTest {
// expect the offset tombstone
when(partition.appendRecordsToLeader(any[MemoryRecords],
origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
groupMetadataManager.cleanupGroupMetadata()
// group and all its offsets should be gone now
@ -2131,7 +2131,7 @@ class GroupMetadataManagerTest {
// expect the offset tombstone
when(partition.appendRecordsToLeader(any[MemoryRecords],
origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
groupMetadataManager.cleanupGroupMetadata()
// group and all its offsets should be gone now
@ -2283,13 +2283,13 @@ class GroupMetadataManagerTest {
// expect the offset tombstone
when(partition.appendRecordsToLeader(any[MemoryRecords],
origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
any(), any())).thenReturn(LogAppendInfo.UNKNOWN_LOG_APPEND_INFO)
groupMetadataManager.cleanupGroupMetadata()
verify(partition).appendRecordsToLeader(any[MemoryRecords],
origin = ArgumentMatchers.eq(AppendOrigin.COORDINATOR), requiredAcks = anyInt(),
any())
any(), any())
verify(replicaManager, times(2)).onlinePartition(groupTopicPartition)
assertEquals(Some(group), groupMetadataManager.getGroup(groupId))


@ -29,7 +29,7 @@ import org.apache.kafka.common.errors._
import org.apache.kafka.common.internals.Topic
import org.apache.kafka.common.record._
import org.apache.kafka.common.utils.{MockTime, Utils}
import org.apache.kafka.storage.internals.log.{AppendOrigin, CompletedTxn, LogFileUtils, LogOffsetMetadata, ProducerAppendInfo, ProducerStateEntry, ProducerStateManager, ProducerStateManagerConfig, TxnMetadata}
import org.apache.kafka.storage.internals.log.{AppendOrigin, CompletedTxn, LogFileUtils, LogOffsetMetadata, ProducerAppendInfo, ProducerStateEntry, ProducerStateManager, ProducerStateManagerConfig, TxnMetadata, VerificationStateEntry}
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
import org.mockito.Mockito.{mock, when}
@ -1087,6 +1087,41 @@ class ProducerStateManagerTest {
assertTrue(!manager.latestSnapshotOffset.isPresent)
}
@Test
def testEntryForVerification(): Unit = {
val originalEntry = stateManager.verificationStateEntry(producerId, true)
val originalEntryVerificationGuard = originalEntry.verificationGuard()
def verifyEntry(producerId: Long, newEntry: VerificationStateEntry): Unit = {
val entry = stateManager.verificationStateEntry(producerId, false)
assertEquals(originalEntryVerificationGuard, entry.verificationGuard)
assertEquals(entry.verificationGuard, newEntry.verificationGuard)
}
// If we already have an entry, reuse it.
val updatedEntry = stateManager.verificationStateEntry(producerId, true)
verifyEntry(producerId, updatedEntry)
// Add the transactional data and clear the entry.
append(stateManager, producerId, 0, 0, offset = 0, isTransactional = true)
stateManager.clearVerificationStateEntry(producerId)
assertNull(stateManager.verificationStateEntry(producerId, false))
}
@Test
def testVerificationStateEntryExpiration(): Unit = {
val originalEntry = stateManager.verificationStateEntry(producerId, true)
// Before timeout we do not remove. Note: Accessing the verification entry does not update the time.
time.sleep(producerStateManagerConfig.producerIdExpirationMs / 2)
stateManager.removeExpiredProducers(time.milliseconds())
assertEquals(originalEntry, stateManager.verificationStateEntry(producerId, false))
time.sleep((producerStateManagerConfig.producerIdExpirationMs / 2) + 1)
stateManager.removeExpiredProducers(time.milliseconds())
assertNull(stateManager.verificationStateEntry(producerId, false))
}
private def testLoadFromCorruptSnapshot(makeFileCorrupt: FileChannel => Unit): Unit = {
val epoch = 0.toShort
val producerId = 1L


@ -3667,6 +3667,118 @@ class UnifiedLogTest {
listener.verify(expectedHighWatermark = 4)
}
@Test
def testTransactionIsOngoingAndVerificationGuard(): Unit = {
val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, true)
val producerId = 23L
val producerEpoch = 1.toShort
val sequence = 3
val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
val log = createLog(logDir, logConfig, producerStateManagerConfig = producerStateManagerConfig)
assertFalse(log.hasOngoingTransaction(producerId))
assertNull(log.getOrMaybeCreateVerificationGuard(producerId))
val idempotentRecords = MemoryRecords.withIdempotentRecords(
CompressionType.NONE,
producerId,
producerEpoch,
sequence,
new SimpleRecord("1".getBytes),
new SimpleRecord("2".getBytes)
)
val verificationGuard = log.maybeStartTransactionVerification(producerId)
assertNotNull(verificationGuard)
log.appendAsLeader(idempotentRecords, leaderEpoch = 0)
assertFalse(log.hasOngoingTransaction(producerId))
// Since we wrote idempotent records, we keep verification guard.
assertEquals(verificationGuard, log.getOrMaybeCreateVerificationGuard(producerId))
val transactionalRecords = MemoryRecords.withTransactionalRecords(
CompressionType.NONE,
producerId,
producerEpoch,
sequence + 2,
new SimpleRecord("1".getBytes),
new SimpleRecord("2".getBytes)
)
log.appendAsLeader(transactionalRecords, leaderEpoch = 0, verificationGuard = verificationGuard)
assertTrue(log.hasOngoingTransaction(producerId))
// Verification guard should be cleared now.
assertNull(log.getOrMaybeCreateVerificationGuard(producerId))
// A subsequent maybeStartTransactionVerification will be empty since we are already verified.
assertNull(log.maybeStartTransactionVerification(producerId))
val endTransactionMarkerRecord = MemoryRecords.withEndTransactionMarker(
producerId,
producerEpoch,
new EndTransactionMarker(ControlRecordType.COMMIT, 0)
)
log.appendAsLeader(endTransactionMarkerRecord, origin = AppendOrigin.COORDINATOR, leaderEpoch = 0)
assertFalse(log.hasOngoingTransaction(producerId))
assertNull(log.getOrMaybeCreateVerificationGuard(producerId))
// A new maybeStartTransactionVerification will not be empty, as we need to verify the next transaction.
val newVerificationGuard = log.maybeStartTransactionVerification(producerId)
assertNotNull(newVerificationGuard)
assertNotEquals(verificationGuard, newVerificationGuard)
}
@Test
def testEmptyTransactionStillClearsVerificationGuard(): Unit = {
val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, true)
val producerId = 23L
val producerEpoch = 1.toShort
val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
val log = createLog(logDir, logConfig, producerStateManagerConfig = producerStateManagerConfig)
val verificationGuard = log.maybeStartTransactionVerification(producerId)
assertNotNull(verificationGuard)
val endTransactionMarkerRecord = MemoryRecords.withEndTransactionMarker(
producerId,
producerEpoch,
new EndTransactionMarker(ControlRecordType.COMMIT, 0)
)
log.appendAsLeader(endTransactionMarkerRecord, origin = AppendOrigin.COORDINATOR, leaderEpoch = 0)
assertFalse(log.hasOngoingTransaction(producerId))
assertNull(log.getOrMaybeCreateVerificationGuard(producerId))
}
@Test
def testAllowNonZeroSequenceOnFirstAppendNonZeroEpoch(): Unit = {
val producerStateManagerConfig = new ProducerStateManagerConfig(86400000, true)
val producerId = 23L
val producerEpoch = 1.toShort
val sequence = 3
val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5)
val log = createLog(logDir, logConfig, producerStateManagerConfig = producerStateManagerConfig)
assertFalse(log.hasOngoingTransaction(producerId))
assertNull(log.getOrMaybeCreateVerificationGuard(producerId))
val transactionalRecords = MemoryRecords.withTransactionalRecords(
CompressionType.NONE,
producerId,
producerEpoch,
sequence,
new SimpleRecord("1".getBytes),
new SimpleRecord("2".getBytes)
)
val verificationGuard = log.maybeStartTransactionVerification(producerId)
// Append should not throw error.
log.appendAsLeader(transactionalRecords, leaderEpoch = 0, verificationGuard = verificationGuard)
}
private def appendTransactionalToBuffer(buffer: ByteBuffer,
producerId: Long,
producerEpoch: Short,


@ -87,6 +87,7 @@ class ReplicaManagerTest {
private val topicId = Uuid.randomUuid()
private val topicIds = scala.Predef.Map("test-topic" -> topicId)
private val topicNames = scala.Predef.Map(topicId -> "test-topic")
private val transactionalId = "txn"
private val time = new MockTime
private val scheduler = new MockScheduler(time)
private val metrics = new Metrics
@ -94,6 +95,7 @@ class ReplicaManagerTest {
private var config: KafkaConfig = _
private var quotaManager: QuotaManagers = _
private var mockRemoteLogManager: RemoteLogManager = _
private var addPartitionsToTxnManager: AddPartitionsToTxnManager = _
// Constants defined for readability
private val zkVersion = 0
@ -108,6 +110,14 @@ class ReplicaManagerTest {
alterPartitionManager = mock(classOf[AlterPartitionManager])
quotaManager = QuotaFactory.instantiate(config, metrics, time, "")
mockRemoteLogManager = mock(classOf[RemoteLogManager])
addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
// Anytime we try to verify, just automatically run the callback as though the transaction was verified.
when(addPartitionsToTxnManager.addTxnData(any(), any(), any())).thenAnswer {
invocationOnMock =>
val callback = invocationOnMock.getArgument(2, classOf[AddPartitionsToTxnManager.AppendCallback])
callback(Map.empty[TopicPartition, Errors].toMap)
}
}
@AfterEach
@ -596,7 +606,7 @@ class ReplicaManagerTest {
// Simulate producer id expiration.
// We use -1 because the timestamp in this test is set to -1, so when
// the expiration check subtracts timestamp, we get max value.
partition0.removeExpiredProducers(Long.MaxValue - 1);
partition0.removeExpiredProducers(Long.MaxValue - 1)
assertEquals(1, replicaManagerMetricValue())
} finally {
replicaManager.shutdown(checkpointHW = false)
@ -643,7 +653,7 @@ class ReplicaManagerTest {
val sequence = 9
val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence,
new SimpleRecord(time.milliseconds(), s"message $sequence".getBytes))
appendRecords(replicaManager, new TopicPartition(topic, 0), records).onFire { response =>
appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
assertEquals(Errors.NONE, response.error)
}
assertLateTransactionCount(Some(0))
@ -707,7 +717,7 @@ class ReplicaManagerTest {
for (sequence <- 0 until numRecords) {
val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence,
new SimpleRecord(s"message $sequence".getBytes))
appendRecords(replicaManager, new TopicPartition(topic, 0), records).onFire { response =>
appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
assertEquals(Errors.NONE, response.error)
}
}
@ -828,7 +838,7 @@ class ReplicaManagerTest {
for (sequence <- 0 until numRecords) {
val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence,
new SimpleRecord(s"message $sequence".getBytes))
appendRecords(replicaManager, new TopicPartition(topic, 0), records).onFire { response =>
appendRecords(replicaManager, new TopicPartition(topic, 0), records, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
assertEquals(Errors.NONE, response.error)
}
}
@ -2130,54 +2140,35 @@ class ReplicaManagerTest {
}
@Test
def testVerificationForTransactionalPartitions(): Unit = {
val tp = new TopicPartition(topic, 0)
val transactionalId = "txn1"
def testVerificationForTransactionalPartitionsOnly(): Unit = {
val tp0 = new TopicPartition(topic, 0)
val tp1 = new TopicPartition(topic, 1)
val producerId = 24L
val producerEpoch = 0.toShort
val sequence = 0
val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
val metadataCache = mock(classOf[MetadataCache])
val node = new Node(0, "host1", 0)
val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
val replicaManager = new ReplicaManager(
metrics = metrics,
config = config,
time = time,
scheduler = new MockScheduler(time),
logManager = mockLogMgr,
quotaManagers = quotaManager,
metadataCache = metadataCache,
logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
alterPartitionManager = alterPartitionManager,
addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0, tp1), node)
try {
val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), tp, Seq(0, 1), LeaderAndIsr(1, List(0, 1)))
replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
replicaManager.becomeLeaderOrFollower(1,
makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
(_, _) => ())
// We must set up the metadata cache to handle the append and verification.
val metadataResponseTopic = Seq(new MetadataResponseTopic()
.setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
.setPartitions(Seq(
new MetadataResponsePartition()
.setPartitionIndex(0)
.setLeaderId(0)).asJava))
val node = new Node(0, "host1", 0)
replicaManager.becomeLeaderOrFollower(1,
makeLeaderAndIsrRequest(topicIds(tp1.topic), tp1, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
(_, _) => ())
when(metadataCache.contains(tp)).thenReturn(true)
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName)).thenReturn(Some(node))
when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName)).thenReturn(None)
// If we supply no transactional ID and idempotent records, we do not verify.
val idempotentRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
new SimpleRecord("message".getBytes))
appendRecords(replicaManager, tp0, idempotentRecords)
verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any[AddPartitionsToTxnManager.AppendCallback]())
assertNull(getVerificationGuard(replicaManager, tp0, producerId))
// We will attempt to schedule to the request handler thread using a non request handler thread. Set this to avoid error.
KafkaRequestHandler.setBypassThreadCheck(true)
// Append some transactional records.
val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
new SimpleRecord(s"message $sequence".getBytes))
val result = appendRecords(replicaManager, tp, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
// If we supply a transactional ID and some transactional and some idempotent records, we should only verify the topic partition with transactional records.
val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1,
new SimpleRecord("message".getBytes))
val transactionToAdd = new AddPartitionsToTxnTransaction()
.setTransactionalId(transactionalId)
@ -2185,27 +2176,103 @@ class ReplicaManagerTest {
.setProducerEpoch(producerEpoch)
.setVerifyOnly(true)
.setTopics(new AddPartitionsToTxnTopicCollection(
Seq(new AddPartitionsToTxnTopic().setName(tp.topic).setPartitions(Collections.singletonList(tp.partition))).iterator.asJava
Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
))
val idempotentRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
new SimpleRecord("message".getBytes))
appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> transactionalRecords, tp1 -> idempotentRecords2), transactionalId, Some(0))
verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
assertNotNull(getVerificationGuard(replicaManager, tp0, producerId))
assertNull(getVerificationGuard(replicaManager, tp1, producerId))
} finally {
replicaManager.shutdown(checkpointHW = false)
}
}
@Test
def testTransactionVerificationFlow(): Unit = {
val tp0 = new TopicPartition(topic, 0)
val producerId = 24L
val producerEpoch = 0.toShort
val sequence = 6
val node = new Node(0, "host1", 0)
val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0), node)
try {
replicaManager.becomeLeaderOrFollower(1,
makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))),
(_, _) => ())
// Append some transactional records.
val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
new SimpleRecord("message".getBytes))
val transactionToAdd = new AddPartitionsToTxnTransaction()
.setTransactionalId(transactionalId)
.setProducerId(producerId)
.setProducerEpoch(producerEpoch)
.setVerifyOnly(true)
.setTopics(new AddPartitionsToTxnTopicCollection(
Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
))
val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
// We should add these partitions to the manager to verify.
val result = appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
val appendCallback = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback.capture())
val verificationGuard = getVerificationGuard(replicaManager, tp0, producerId)
assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
// Confirm we did not write to the log and instead returned error.
val callback: AddPartitionsToTxnManager.AppendCallback = appendCallback.getValue()
callback(Map(tp -> Errors.INVALID_RECORD).toMap)
callback(Map(tp0 -> Errors.INVALID_RECORD).toMap)
assertEquals(Errors.INVALID_RECORD, result.assertFired.error)
assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
// If we supply no transactional ID and idempotent records, we do not verify, so counter stays the same.
val idempotentRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1,
// This time verification is successful.
appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
val appendCallback2 = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnManager.AppendCallback])
verify(addPartitionsToTxnManager, times(2)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), appendCallback2.capture())
assertEquals(verificationGuard, getVerificationGuard(replicaManager, tp0, producerId))
val callback2: AddPartitionsToTxnManager.AppendCallback = appendCallback2.getValue()
callback2(Map.empty[TopicPartition, Errors].toMap)
assertEquals(null, getVerificationGuard(replicaManager, tp0, producerId))
assertTrue(replicaManager.localLog(tp0).get.hasOngoingTransaction(producerId))
} finally {
replicaManager.shutdown(checkpointHW = false)
}
}
@Test
def testTransactionVerificationGuardOnMultiplePartitions(): Unit = {
val mockTimer = new MockTimer(time)
val tp0 = new TopicPartition(topic, 0)
val tp1 = new TopicPartition(topic, 1)
val producerId = 24L
val producerEpoch = 0.toShort
val sequence = 0
val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer)
try {
replicaManager.becomeLeaderOrFollower(1,
makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(0, List(0, 1))),
(_, _) => ())
replicaManager.becomeLeaderOrFollower(1,
makeLeaderAndIsrRequest(topicIds(tp1.topic), tp1, Seq(0, 1), LeaderAndIsr(0, List(0, 1))),
(_, _) => ())
val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
new SimpleRecord(s"message $sequence".getBytes))
appendRecords(replicaManager, tp, idempotentRecords2)
verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
// If we supply a transactional ID and some transactional and some idempotent records, we should only verify the topic partition with transactional records.
appendRecordsToMultipleTopics(replicaManager, Map(tp -> transactionalRecords, new TopicPartition(topic, 1) -> idempotentRecords2), transactionalId, Some(0))
verify(addPartitionsToTxnManager, times(2)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
appendRecordsToMultipleTopics(replicaManager, Map(tp0 -> transactionalRecords, tp1 -> transactionalRecords), transactionalId, Some(0)).onFire { responses =>
responses.foreach {
entry => assertEquals(Errors.NONE, entry._2)
}
}
} finally {
replicaManager.shutdown(checkpointHW = false)
}
@ -2220,21 +2287,10 @@ class ReplicaManagerTest {
val producerEpoch = 0.toShort
val sequence = 0
val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
val metadataCache = mock(classOf[MetadataCache])
val node = new Node(0, "host1", 0)
val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
val replicaManager = new ReplicaManager(
metrics = metrics,
config = config,
time = time,
scheduler = new MockScheduler(time),
logManager = mockLogMgr,
quotaManagers = quotaManager,
metadataCache = metadataCache,
logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
alterPartitionManager = alterPartitionManager,
addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0, tp1), node)
try {
replicaManager.becomeLeaderOrFollower(1,
@ -2254,13 +2310,49 @@ class ReplicaManagerTest {
assertThrows(classOf[InvalidPidMappingException],
() => appendRecordsToMultipleTopics(replicaManager, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0)))
// We should not add these partitions to the manager to verify.
verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any())
} finally {
replicaManager.shutdown(checkpointHW = false)
}
}
@Test
def testDisabledVerification(): Unit = {
def testTransactionVerificationWhenNotLeader(): Unit = {
val tp0 = new TopicPartition(topic, 0)
val producerId = 24L
val producerEpoch = 0.toShort
val sequence = 6
val node = new Node(0, "host1", 0)
val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp0), node)
try {
// Append some transactional records.
val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
new SimpleRecord("message".getBytes))
val transactionToAdd = new AddPartitionsToTxnTransaction()
.setTransactionalId(transactionalId)
.setProducerId(producerId)
.setProducerEpoch(producerEpoch)
.setVerifyOnly(true)
.setTopics(new AddPartitionsToTxnTopicCollection(
Seq(new AddPartitionsToTxnTopic().setName(tp0.topic).setPartitions(Collections.singletonList(tp0.partition))).iterator.asJava
))
// We should not add these partitions to the manager to verify, but instead throw an error.
appendRecords(replicaManager, tp0, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0)).onFire { response =>
assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, response.error)
}
verify(addPartitionsToTxnManager, times(0)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]())
} finally {
replicaManager.shutdown(checkpointHW = false)
}
}
@Test
def testDisabledTransactionVerification(): Unit = {
val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect)
props.put("transaction.partition.verification.enable", "false")
val config = KafkaConfig.fromProps(props)
@ -2271,36 +2363,21 @@ class ReplicaManagerTest {
val producerEpoch = 0.toShort
val sequence = 0
val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
val metadataCache = mock(classOf[MetadataCache])
val node = new Node(0, "host1", 0)
val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager])
val replicaManager = new ReplicaManager(
metrics = metrics,
config = config,
time = time,
scheduler = new MockScheduler(time),
logManager = mockLogMgr,
quotaManagers = quotaManager,
metadataCache = metadataCache,
logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
alterPartitionManager = alterPartitionManager,
addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
val replicaManager = setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager, List(tp), node, config = config)
try {
val becomeLeaderRequest = makeLeaderAndIsrRequest(topicIds(tp.topic), tp, Seq(0, 1), LeaderAndIsr(0, List(0, 1)))
replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
when(metadataCache.contains(tp)).thenReturn(true)
val transactionalRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence,
new SimpleRecord(s"message $sequence".getBytes))
appendRecords(replicaManager, tp, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))
assertNull(getVerificationGuard(replicaManager, tp, producerId))
// We should not add these partitions to the manager to verify.
verify(metadataCache, times(0)).getTopicMetadata(any(), any(), any(), any())
verify(metadataCache, times(0)).getAliveBrokerNode(any(), any())
verify(metadataCache, times(0)).getAliveBrokerNode(any(), any())
verify(addPartitionsToTxnManager, times(0)).addTxnData(any(), any(), any())
} finally {
replicaManager.shutdown(checkpointHW = false)
@ -2631,9 +2708,11 @@ class ReplicaManagerTest {
transactionalId: String,
transactionStatePartition: Option[Int],
origin: AppendOrigin = AppendOrigin.CLIENT,
requiredAcks: Short = -1): Unit = {
requiredAcks: Short = -1): CallbackResult[Map[TopicPartition, PartitionResponse]] = {
val result = new CallbackResult[Map[TopicPartition, PartitionResponse]]()
def appendCallback(responses: Map[TopicPartition, PartitionResponse]): Unit = {
responses.foreach( response => responses.get(response._1).isDefined)
responses.foreach( response => assertTrue(responses.get(response._1).isDefined))
result.fire(responses)
}
replicaManager.appendRecords(
@ -2645,6 +2724,8 @@ class ReplicaManagerTest {
responseCallback = appendCallback,
transactionalId = transactionalId,
transactionStatePartition = transactionStatePartition)
result
}
private def fetchPartitionAsConsumer(
@ -2763,6 +2844,48 @@ class ReplicaManagerTest {
)
}
private def getVerificationGuard(replicaManager: ReplicaManager,
tp: TopicPartition,
producerId: Long): Object = {
replicaManager.getPartitionOrException(tp).log.get.getOrMaybeCreateVerificationGuard(producerId)
}
private def setUpReplicaManagerWithMockedAddPartitionsToTxnManager(addPartitionsToTxnManager: AddPartitionsToTxnManager,
transactionalTopicPartitions: List[TopicPartition],
node: Node,
config: KafkaConfig = config): ReplicaManager = {
val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)))
val metadataCache = mock(classOf[MetadataCache])
val replicaManager = new ReplicaManager(
metrics = metrics,
config = config,
time = time,
scheduler = new MockScheduler(time),
logManager = mockLogMgr,
quotaManagers = quotaManager,
metadataCache = metadataCache,
logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size),
alterPartitionManager = alterPartitionManager,
addPartitionsToTxnManager = Some(addPartitionsToTxnManager))
val metadataResponseTopic = Seq(new MetadataResponseTopic()
.setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
.setPartitions(Seq(
new MetadataResponsePartition()
.setPartitionIndex(0)
.setLeaderId(0)).asJava))
transactionalTopicPartitions.foreach(tp => when(metadataCache.contains(tp)).thenReturn(true))
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
when(metadataCache.getAliveBrokerNode(0, config.interBrokerListenerName)).thenReturn(Some(node))
when(metadataCache.getAliveBrokerNode(1, config.interBrokerListenerName)).thenReturn(None)
// We will attempt to schedule onto the request handler thread from a non-request-handler thread. Bypass the thread check to avoid an error.
KafkaRequestHandler.setBypassThreadCheck(true)
replicaManager
}
private def setupReplicaManagerWithMockedPurgatories(
timer: MockTimer,
brokerId: Int = 0,
@ -2796,6 +2919,18 @@ class ReplicaManagerTest {
val mockDelayedElectLeaderPurgatory = new DelayedOperationPurgatory[DelayedElectLeader](
purgatoryName = "DelayedElectLeader", timer, reaperEnabled = false)
// Set up transactions
val metadataResponseTopic = Seq(new MetadataResponseTopic()
.setName(Topic.TRANSACTION_STATE_TOPIC_NAME)
.setPartitions(Seq(
new MetadataResponsePartition()
.setPartitionIndex(0)
.setLeaderId(0)).asJava))
when(metadataCache.contains(new TopicPartition(topic, 0))).thenReturn(true)
when(metadataCache.getTopicMetadata(Set(Topic.TRANSACTION_STATE_TOPIC_NAME), config.interBrokerListenerName)).thenReturn(metadataResponseTopic)
// Transactional appends attempt to schedule onto the request handler thread from a non-request-handler thread. Bypass the thread check to avoid an error.
KafkaRequestHandler.setBypassThreadCheck(true)
new ReplicaManager(
metrics = metrics,
config = config,
@ -2812,7 +2947,8 @@ class ReplicaManagerTest {
delayedDeleteRecordsPurgatoryParam = Some(mockDeleteRecordsPurgatory),
delayedElectLeaderPurgatoryParam = Some(mockDelayedElectLeaderPurgatory),
threadNamePrefix = Option(this.getClass.getName),
remoteLogManager = if (enableRemoteStorage) Some(mockRemoteLogManager) else None) {
remoteLogManager = if (enableRemoteStorage) Some(mockRemoteLogManager) else None,
addPartitionsToTxnManager = Some(addPartitionsToTxnManager)) {
override protected def createReplicaFetcherManager(
metrics: Metrics,

View File

@ -1407,7 +1407,8 @@ object TestUtils extends Logging {
cleanerConfig: CleanerConfig = new CleanerConfig(false),
time: MockTime = new MockTime(),
interBrokerProtocolVersion: MetadataVersion = MetadataVersion.latest,
recoveryThreadsPerDataDir: Int = 4): LogManager = {
recoveryThreadsPerDataDir: Int = 4,
transactionVerificationEnabled: Boolean = false): LogManager = {
new LogManager(logDirs = logDirs.map(_.getAbsoluteFile),
initialOfflineDirs = Array.empty[File],
configRepository = configRepository,
@ -1419,7 +1420,7 @@ object TestUtils extends Logging {
flushStartOffsetCheckpointMs = 10000L,
retentionCheckMs = 1000L,
maxTransactionTimeoutMs = 5 * 60 * 1000,
producerStateManagerConfig = new ProducerStateManagerConfig(kafka.server.Defaults.ProducerIdExpirationMs, false),
producerStateManagerConfig = new ProducerStateManagerConfig(kafka.server.Defaults.ProducerIdExpirationMs, transactionVerificationEnabled),
producerIdExpirationCheckIntervalMs = kafka.server.Defaults.ProducerIdExpirationCheckIntervalMs,
scheduler = time.scheduler,
time = time,

View File

@ -108,7 +108,7 @@ public class CheckpointBench {
JavaConverters.seqAsJavaList(brokerProperties.logDirs()).stream().map(File::new).collect(Collectors.toList());
this.logManager = TestUtils.createLogManager(JavaConverters.asScalaBuffer(files),
new LogConfig(new Properties()), new MockConfigRepository(), new CleanerConfig(1, 4 * 1024 * 1024L, 0.9d,
1024 * 1024, 32 * 1024 * 1024, Double.MAX_VALUE, 15 * 1000, true), time, MetadataVersion.latest(), 4);
1024 * 1024, 32 * 1024 * 1024, Double.MAX_VALUE, 15 * 1000, true), time, MetadataVersion.latest(), 4, false);
scheduler.startup();
final BrokerTopicStats brokerTopicStats = new BrokerTopicStats();
final MetadataCache metadataCache =

View File

@ -118,6 +118,10 @@ public abstract class InterBrokerSendThread extends ShutdownableThread {
// DisconnectException is expected when NetworkClient#initiateClose is called
return;
}
if (t instanceof InterruptedException && !isRunning()) {
// InterruptedException is expected when shutting down. Rethrow it so ShutdownableThread can handle it.
throw t;
}
log.error("unhandled exception caught in InterBrokerSendThread", t);
// rethrow any unhandled exceptions as FatalExitError so the JVM will be terminated
// as we will be in an unknown state with potentially some requests dropped and not

View File

@ -35,6 +35,7 @@ import java.util.Collections;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.apache.kafka.clients.ClientRequest;
import org.apache.kafka.clients.ClientResponse;
import org.apache.kafka.clients.KafkaClient;
@ -47,6 +48,8 @@ import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.requests.AbstractRequest;
import org.apache.kafka.common.utils.Time;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentMatchers;
public class InterBrokerSendThreadTest {
@ -299,6 +302,40 @@ public class InterBrokerSendThreadTest {
assertTrue(completionHandler.executedWithDisconnectedResponse);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testInterruption(boolean isShuttingDown) throws InterruptedException, IOException {
Exception interrupted = new InterruptedException();
// InterBrokerSendThread#shutdown calls NetworkClient#initiateClose first, so NetworkClient#poll
// can throw InterruptedException if it handles a callback request that throws one
when(networkClient.poll(anyLong(), anyLong())).thenAnswer(t -> {
throw interrupted;
});
AtomicReference<Throwable> exception = new AtomicReference<>();
final InterBrokerSendThread thread =
new TestInterBrokerSendThread(networkClient, t -> {
if (isShuttingDown)
assertTrue(t instanceof InterruptedException);
else
assertTrue(t instanceof FatalExitError);
exception.getAndSet(t);
});
if (isShuttingDown)
thread.shutdown();
thread.pollOnce(100);
verify(networkClient).poll(anyLong(), anyLong());
if (isShuttingDown) {
verify(networkClient).initiateClose();
verify(networkClient).close();
}
verifyNoMoreInteractions(networkClient);
assertNotNull(exception.get());
}
private static class StubRequestBuilder<T extends AbstractRequest>
extends AbstractRequest.Builder<T> {

View File

@ -113,6 +113,8 @@ public class ProducerStateManager {
private final Map<Long, ProducerStateEntry> producers = new HashMap<>();
private final Map<Long, VerificationStateEntry> verificationStates = new HashMap<>();
// ongoing transactions sorted by the first offset of the transaction
private final TreeMap<Long, TxnMetadata> ongoingTxns = new TreeMap<>();
@ -184,6 +186,26 @@ public class ProducerStateManager {
producerIdCount = 0;
}
/**
* Return the VerificationStateEntry for the given producer ID if it exists. Otherwise, create and return a new entry if createIfAbsent is true; if not, return null.
*/
public VerificationStateEntry verificationStateEntry(long producerId, boolean createIfAbsent) {
return verificationStates.computeIfAbsent(producerId, pid -> {
if (createIfAbsent)
return new VerificationStateEntry(time.milliseconds());
else {
return null;
}
});
}
/**
* Clear the verificationStateEntry for the given producer ID.
*/
public void clearVerificationStateEntry(long producerId) {
verificationStates.remove(producerId);
}
/**
* Load producer state snapshots by scanning the logDir.
*/
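The two new methods above cover the lifecycle of a verification entry: verification passes createIfAbsent = true to create or reuse the entry for a producer id, while a lookup-only caller passes false and gets null back, since computeIfAbsent does not insert a mapping when the function returns null. A minimal sketch of that intended usage follows; the wrapper class and the append-time check are illustrative only, and the real call sites in the log layer are not part of this hunk.

    import org.apache.kafka.storage.internals.log.ProducerStateManager;
    import org.apache.kafka.storage.internals.log.VerificationStateEntry;

    public class VerificationLifecycleSketch {
        // producerStateManager is assumed to be an already-initialized ProducerStateManager.
        static void sketch(ProducerStateManager producerStateManager, long producerId) {
            // Verification path: create the entry (and its guard) on the first attempt,
            // or reuse the existing entry so retries share the same guard object.
            VerificationStateEntry entry = producerStateManager.verificationStateEntry(producerId, true);
            Object verificationGuard = entry.verificationGuard();

            // Append path: look up only; a missing entry simply yields null.
            VerificationStateEntry pending = producerStateManager.verificationStateEntry(producerId, false);
            if (pending != null && pending.verificationGuard() == verificationGuard) {
                // Still the transaction we verified; append the data, then drop the entry.
                producerStateManager.clearVerificationStateEntry(producerId);
            }
        }
    }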
@ -338,6 +360,7 @@ public class ProducerStateManager {
/**
* Expire any producer ids which have been idle longer than the configured maximum expiration timeout.
* Also expire any verification state entries that are lingering as unverified.
*/
public void removeExpiredProducers(long currentTimeMs) {
List<Long> keys = producers.entrySet().stream()
@ -345,6 +368,12 @@ public class ProducerStateManager {
.map(Map.Entry::getKey)
.collect(Collectors.toList());
removeProducerIds(keys);
List<Long> verificationKeys = verificationStates.entrySet().stream()
.filter(entry -> currentTimeMs - entry.getValue().timestamp() >= producerStateManagerConfig.producerIdExpirationMs())
.map(Map.Entry::getKey)
.collect(Collectors.toList());
verificationKeys.forEach(verificationStates::remove);
}
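The expiration pass applies the same age-based predicate to unverified entries as to producer entries, keyed off producer.id.expiration.ms. Below is a self-contained illustration of that filter with made-up timestamps and a plain map standing in for verificationStates; it is not broker code.

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;

    public class VerificationExpirationSketch {
        public static void main(String[] args) {
            long producerIdExpirationMs = 86_400_000L; // producer.id.expiration.ms, one day by default
            Map<Long, Long> verificationTimestamps = new HashMap<>(); // producerId -> entry creation time
            verificationTimestamps.put(1L, 0L);          // verified long ago, never appended
            verificationTimestamps.put(2L, 86_000_000L); // still within the window

            long currentTimeMs = 86_500_000L;
            List<Long> expired = verificationTimestamps.entrySet().stream()
                .filter(e -> currentTimeMs - e.getValue() >= producerIdExpirationMs)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
            expired.forEach(verificationTimestamps::remove);

            System.out.println(verificationTimestamps.keySet()); // prints [2]
        }
    }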
/**

View File

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.storage.internals.log;
/**
* This class represents the verification state of a specific producer id.
* It contains a verification guard object that is used to uniquely identify the transaction we want to verify.
* After verifying, we retain this object until we append to the log. This prevents any race conditions where the transaction
* may end via a control marker before we write to the log. This mechanism is used to prevent hanging transactions.
* We remove the verification guard object whenever we write data to the transaction or write an end marker for the transaction.
* Any lingering entries that are never verified are removed via the producer state entry cleanup mechanism.
*/
public class VerificationStateEntry {
private final long timestamp;
private final Object verificationGuard;
public VerificationStateEntry(long timestamp) {
this.timestamp = timestamp;
this.verificationGuard = new Object();
}
public long timestamp() {
return timestamp;
}
public Object verificationGuard() {
return verificationGuard;
}
}
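Because every VerificationStateEntry allocates a fresh guard object, a guard captured at verification time matches by reference only while the same entry is still present at append time. A small standalone sketch of that identity check (the class and variable names here are illustrative, not broker code):

    import org.apache.kafka.storage.internals.log.VerificationStateEntry;

    public class VerificationGuardSketch {
        public static void main(String[] args) {
            VerificationStateEntry verified = new VerificationStateEntry(System.currentTimeMillis());
            Object guardAtVerification = verified.verificationGuard();

            // The entry we verified against is still in place, so the guard matches by reference.
            System.out.println(guardAtVerification == verified.verificationGuard()); // true

            // If the entry was cleared (e.g. an end marker was written) and later recreated,
            // the new entry carries a different guard object, so the stale guard no longer matches.
            VerificationStateEntry recreated = new VerificationStateEntry(System.currentTimeMillis());
            System.out.println(guardAtVerification == recreated.verificationGuard()); // false
        }
    }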