From b0ca05b751f0bb5ccbe80b108f90bd7186cf2bd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Armando=20Garc=C3=ADa=20Sancio?= Date: Tue, 25 Feb 2025 20:09:19 -0500 Subject: [PATCH] KAFKA-18723; Better handle invalid records during replication (#18852) For the KRaft implementation, there is a race between the network thread, which reads bytes from the log segments, and the KRaft driver thread, which truncates the log and appends records to it. This race can cause the network thread to send corrupted or inconsistent records. The corrupted records case is handled by catching and logging the CorruptRecordException. The inconsistent records case is handled by only appending record batches whose partition leader epoch is less than or equal to the fetching replica's epoch, provided the epoch did not change between the request and the response. For the ISR implementation, there is also a race between the network thread and the replica fetcher thread, which truncates the log and appends records to it. This race can cause the network thread to send corrupted or inconsistent records. The replica fetcher thread already handles the corrupted records case. The inconsistent records case is handled by only appending record batches whose partition leader epoch is less than or equal to the leader epoch in the FETCH request. Reviewers: Jun Rao, Alyssa Huang, Chia-Ping Tsai --- build.gradle | 8 + .../common/record/DefaultRecordBatch.java | 3 +- .../kafka/common/record/MemoryRecords.java | 2 +- .../common/record/ArbitraryMemoryRecords.java | 39 ++ .../record/InvalidMemoryRecordsProvider.java | 132 +++++++ .../main/scala/kafka/cluster/Partition.scala | 20 +- .../src/main/scala/kafka/log/UnifiedLog.scala | 137 ++++--- .../scala/kafka/raft/KafkaMetadataLog.scala | 18 +- .../kafka/server/AbstractFetcherThread.scala | 21 +- .../server/ReplicaAlterLogDirsThread.scala | 11 +- .../kafka/server/ReplicaFetcherThread.scala | 11 +- .../kafka/raft/KafkaMetadataLogTest.scala | 95 ++++- .../unit/kafka/cluster/PartitionTest.scala | 108 +++-- .../scala/unit/kafka/log/LogCleanerTest.scala | 22 +- .../unit/kafka/log/LogConcurrencyTest.scala | 11 +- .../scala/unit/kafka/log/LogLoaderTest.scala | 32 +- .../scala/unit/kafka/log/UnifiedLogTest.scala | 369 +++++++++++++----- .../server/AbstractFetcherManagerTest.scala | 9 +- .../server/AbstractFetcherThreadTest.scala | 85 +++- .../unit/kafka/server/MockFetcherThread.scala | 35 +- .../server/ReplicaFetcherThreadTest.scala | 33 +- .../kafka/server/ReplicaManagerTest.scala | 9 +- .../ReplicaFetcherThreadBenchmark.java | 8 +- .../PartitionMakeFollowerBenchmark.java | 2 +- .../apache/kafka/raft/KafkaRaftClient.java | 52 ++- .../org/apache/kafka/raft/ReplicatedLog.java | 9 +- .../kafka/raft/KafkaRaftClientFetchTest.java | 151 +++++++ .../kafka/raft/KafkaRaftClientTest.java | 2 +- .../java/org/apache/kafka/raft/MockLog.java | 46 ++- .../org/apache/kafka/raft/MockLogTest.java | 125 +++++- 30 files changed, 1305 insertions(+), 300 deletions(-) create mode 100644 clients/src/test/java/org/apache/kafka/common/record/ArbitraryMemoryRecords.java create mode 100644 clients/src/test/java/org/apache/kafka/common/record/InvalidMemoryRecordsProvider.java create mode 100644 raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientFetchTest.java diff --git a/build.gradle b/build.gradle index ae6718ab676..6b4725d5429 100644 --- a/build.gradle +++ b/build.gradle @@ -975,6 +975,7 @@ project(':core') { testImplementation project(':server').sourceSets.test.output testImplementation libs.bcpkix 
testImplementation libs.mockitoCore + testImplementation libs.jqwik testImplementation(libs.apacheda) { exclude group: 'xml-apis', module: 'xml-apis' // `mina-core` is a transitive dependency for `apacheds` and `apacheda`. @@ -1184,6 +1185,12 @@ project(':core') { ) } + test { + useJUnitPlatform { + includeEngines 'jqwik', 'junit-jupiter' + } + } + tasks.create(name: "copyDependantTestLibs", type: Copy) { from (configurations.testRuntimeClasspath) { include('*.jar') @@ -1550,6 +1557,7 @@ project(':clients') { testImplementation libs.jose4j testImplementation libs.junitJupiter testImplementation libs.reload4j + testImplementation libs.jqwik testImplementation libs.mockitoCore testImplementation libs.mockitoJunitJupiter // supports MockitoExtension diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java index 27629330d7a..6c77aada120 100644 --- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java +++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java @@ -158,7 +158,7 @@ public class DefaultRecordBatch extends AbstractRecordBatch implements MutableRe /** * Gets the base timestamp of the batch which is used to calculate the record timestamps from the deltas. - * + * * @return The base timestamp */ public long baseTimestamp() { @@ -501,6 +501,7 @@ public class DefaultRecordBatch extends AbstractRecordBatch implements MutableRe public String toString() { return "RecordBatch(magic=" + magic() + ", offsets=[" + baseOffset() + ", " + lastOffset() + "], " + "sequence=[" + baseSequence() + ", " + lastSequence() + "], " + + "partitionLeaderEpoch=" + partitionLeaderEpoch() + ", " + "isTransactional=" + isTransactional() + ", isControlBatch=" + isControlBatch() + ", " + "compression=" + compressionType() + ", timestampType=" + timestampType() + ", crc=" + checksum() + ")"; } diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java index c01bca2496e..610b6c8bcc8 100644 --- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java +++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java @@ -617,7 +617,7 @@ public class MemoryRecords extends AbstractRecords { return withRecords(magic, initialOffset, compression, TimestampType.CREATE_TIME, records); } - public static MemoryRecords withRecords(long initialOffset, Compression compression, Integer partitionLeaderEpoch, SimpleRecord... records) { + public static MemoryRecords withRecords(long initialOffset, Compression compression, int partitionLeaderEpoch, SimpleRecord... records) { return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, initialOffset, compression, TimestampType.CREATE_TIME, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, partitionLeaderEpoch, false, records); } diff --git a/clients/src/test/java/org/apache/kafka/common/record/ArbitraryMemoryRecords.java b/clients/src/test/java/org/apache/kafka/common/record/ArbitraryMemoryRecords.java new file mode 100644 index 00000000000..30eec866a6c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/ArbitraryMemoryRecords.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import net.jqwik.api.Arbitraries; +import net.jqwik.api.Arbitrary; +import net.jqwik.api.ArbitrarySupplier; + +import java.nio.ByteBuffer; +import java.util.Random; + +public final class ArbitraryMemoryRecords implements ArbitrarySupplier { + @Override + public Arbitrary get() { + return Arbitraries.randomValue(ArbitraryMemoryRecords::buildRandomRecords); + } + + private static MemoryRecords buildRandomRecords(Random random) { + int size = random.nextInt(128) + 1; + byte[] bytes = new byte[size]; + random.nextBytes(bytes); + + return MemoryRecords.readableRecords(ByteBuffer.wrap(bytes)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/InvalidMemoryRecordsProvider.java b/clients/src/test/java/org/apache/kafka/common/record/InvalidMemoryRecordsProvider.java new file mode 100644 index 00000000000..3bf7e822427 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/InvalidMemoryRecordsProvider.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.errors.CorruptRecordException; + +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.stream.Stream; + +public final class InvalidMemoryRecordsProvider implements ArgumentsProvider { + // Use a baseOffset that's not zero so that it is less likely to match the LEO + private static final long BASE_OFFSET = 1234; + private static final int EPOCH = 4321; + + /** + * Returns a stream of arguments for invalid memory records and the expected exception. + * + * The first object in the {@code Arguments} is a {@code MemoryRecords}. + * + * The second object in the {@code Arguments} is an {@code Optional>} which is + * the expected exception from the log layer. 
+ */ + @Override + public Stream provideArguments(ExtensionContext context) { + return Stream.of( + Arguments.of(MemoryRecords.readableRecords(notEnoughBytes()), Optional.empty()), + Arguments.of(MemoryRecords.readableRecords(recordsSizeTooSmall()), Optional.of(CorruptRecordException.class)), + Arguments.of(MemoryRecords.readableRecords(notEnoughBytesToMagic()), Optional.empty()), + Arguments.of(MemoryRecords.readableRecords(negativeMagic()), Optional.of(CorruptRecordException.class)), + Arguments.of(MemoryRecords.readableRecords(largeMagic()), Optional.of(CorruptRecordException.class)), + Arguments.of(MemoryRecords.readableRecords(lessBytesThanRecordSize()), Optional.empty()) + ); + } + + private static ByteBuffer notEnoughBytes() { + ByteBuffer buffer = ByteBuffer.allocate(Records.LOG_OVERHEAD - 1); + buffer.limit(buffer.capacity()); + + return buffer; + } + + private static ByteBuffer recordsSizeTooSmall() { + ByteBuffer buffer = ByteBuffer.allocate(256); + // Write the base offset + buffer.putLong(BASE_OFFSET); + // Write record size + buffer.putInt(LegacyRecord.RECORD_OVERHEAD_V0 - 1); + buffer.position(0); + buffer.limit(buffer.capacity()); + + return buffer; + } + + private static ByteBuffer notEnoughBytesToMagic() { + ByteBuffer buffer = ByteBuffer.allocate(256); + // Write the base offset + buffer.putLong(BASE_OFFSET); + // Write record size + buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD); + buffer.position(0); + buffer.limit(Records.HEADER_SIZE_UP_TO_MAGIC - 1); + + return buffer; + } + + private static ByteBuffer negativeMagic() { + ByteBuffer buffer = ByteBuffer.allocate(256); + // Write the base offset + buffer.putLong(BASE_OFFSET); + // Write record size + buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD); + // Write the epoch + buffer.putInt(EPOCH); + // Write magic + buffer.put((byte) -1); + buffer.position(0); + buffer.limit(buffer.capacity()); + + return buffer; + } + + private static ByteBuffer largeMagic() { + ByteBuffer buffer = ByteBuffer.allocate(256); + // Write the base offset + buffer.putLong(BASE_OFFSET); + // Write record size + buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD); + // Write the epoch + buffer.putInt(EPOCH); + // Write magic + buffer.put((byte) (RecordBatch.CURRENT_MAGIC_VALUE + 1)); + buffer.position(0); + buffer.limit(buffer.capacity()); + + return buffer; + } + + private static ByteBuffer lessBytesThanRecordSize() { + ByteBuffer buffer = ByteBuffer.allocate(256); + // Write the base offset + buffer.putLong(BASE_OFFSET); + // Write record size + buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD); + // Write the epoch + buffer.putInt(EPOCH); + // Write magic + buffer.put(RecordBatch.CURRENT_MAGIC_VALUE); + buffer.position(0); + buffer.limit(buffer.capacity() - Records.LOG_OVERHEAD - 1); + + return buffer; + } +} diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala index 454eec0d948..e063a2f2d61 100755 --- a/core/src/main/scala/kafka/cluster/Partition.scala +++ b/core/src/main/scala/kafka/cluster/Partition.scala @@ -1318,27 +1318,35 @@ class Partition(val topicPartition: TopicPartition, } } - private def doAppendRecordsToFollowerOrFutureReplica(records: MemoryRecords, isFuture: Boolean): Option[LogAppendInfo] = { + private def doAppendRecordsToFollowerOrFutureReplica( + records: MemoryRecords, + isFuture: Boolean, + partitionLeaderEpoch: Int + ): Option[LogAppendInfo] = { if (isFuture) { // The read lock is needed to handle race condition if request handler 
thread tries to // remove future replica after receiving AlterReplicaLogDirsRequest. inReadLock(leaderIsrUpdateLock) { // Note the replica may be undefined if it is removed by a non-ReplicaAlterLogDirsThread before // this method is called - futureLog.map { _.appendAsFollower(records) } + futureLog.map { _.appendAsFollower(records, partitionLeaderEpoch) } } } else { // The lock is needed to prevent the follower replica from being updated while ReplicaAlterDirThread // is executing maybeReplaceCurrentWithFutureReplica() to replace follower replica with the future replica. futureLogLock.synchronized { - Some(localLogOrException.appendAsFollower(records)) + Some(localLogOrException.appendAsFollower(records, partitionLeaderEpoch)) } } } - def appendRecordsToFollowerOrFutureReplica(records: MemoryRecords, isFuture: Boolean): Option[LogAppendInfo] = { + def appendRecordsToFollowerOrFutureReplica( + records: MemoryRecords, + isFuture: Boolean, + partitionLeaderEpoch: Int + ): Option[LogAppendInfo] = { try { - doAppendRecordsToFollowerOrFutureReplica(records, isFuture) + doAppendRecordsToFollowerOrFutureReplica(records, isFuture, partitionLeaderEpoch) } catch { case e: UnexpectedAppendOffsetException => val log = if (isFuture) futureLocalLogOrException else localLogOrException @@ -1356,7 +1364,7 @@ class Partition(val topicPartition: TopicPartition, info(s"Unexpected offset in append to $topicPartition. First offset ${e.firstOffset} is less than log start offset ${log.logStartOffset}." + s" Since this is the first record to be appended to the $replicaName's log, will start the log from offset ${e.firstOffset}.") truncateFullyAndStartAt(e.firstOffset, isFuture) - doAppendRecordsToFollowerOrFutureReplica(records, isFuture) + doAppendRecordsToFollowerOrFutureReplica(records, isFuture, partitionLeaderEpoch) } else throw e } diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala index f435bca52ee..d648fff6194 100644 --- a/core/src/main/scala/kafka/log/UnifiedLog.scala +++ b/core/src/main/scala/kafka/log/UnifiedLog.scala @@ -710,6 +710,7 @@ class UnifiedLog(@volatile var logStartOffset: Long, * Append this message set to the active segment of the local log, assigning offsets and Partition Leader Epochs * * @param records The records to append + * @param leaderEpoch the epoch of the replica appending * @param origin Declares the origin of the append which affects required validations * @param interBrokerProtocolVersion Inter-broker message protocol version * @param requestLocal request local instance @@ -730,15 +731,16 @@ class UnifiedLog(@volatile var logStartOffset: Long, * Append this message set to the active segment of the local log without assigning offsets or Partition Leader Epochs * * @param records The records to append + * @param leaderEpoch the epoch of the replica appending * @throws KafkaStorageException If the append fails due to an I/O error. * @return Information about the appended messages including the first and last offset. */ - def appendAsFollower(records: MemoryRecords): LogAppendInfo = { + def appendAsFollower(records: MemoryRecords, leaderEpoch: Int): LogAppendInfo = { append(records, origin = AppendOrigin.REPLICATION, interBrokerProtocolVersion = MetadataVersion.latestProduction, validateAndAssignOffsets = false, - leaderEpoch = -1, + leaderEpoch = leaderEpoch, requestLocal = None, verificationGuard = VerificationGuard.SENTINEL, // disable to check the validation of record size since the record is already accepted by leader. 
@@ -1124,63 +1126,85 @@ class UnifiedLog(@volatile var logStartOffset: Long, var shallowOffsetOfMaxTimestamp = -1L var readFirstMessage = false var lastOffsetOfFirstBatch = -1L + var skipRemainingBatches = false records.batches.forEach { batch => if (origin == AppendOrigin.RAFT_LEADER && batch.partitionLeaderEpoch != leaderEpoch) { - throw new InvalidRecordException("Append from Raft leader did not set the batch epoch correctly") + throw new InvalidRecordException( + s"Append from Raft leader did not set the batch epoch correctly, expected $leaderEpoch " + + s"but the batch has ${batch.partitionLeaderEpoch}" + ) } // we only validate V2 and higher to avoid potential compatibility issues with older clients - if (batch.magic >= RecordBatch.MAGIC_VALUE_V2 && origin == AppendOrigin.CLIENT && batch.baseOffset != 0) + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2 && origin == AppendOrigin.CLIENT && batch.baseOffset != 0) { throw new InvalidRecordException(s"The baseOffset of the record batch in the append to $topicPartition should " + s"be 0, but it is ${batch.baseOffset}") - - // update the first offset if on the first message. For magic versions older than 2, we use the last offset - // to avoid the need to decompress the data (the last offset can be obtained directly from the wrapper message). - // For magic version 2, we can get the first offset directly from the batch header. - // When appending to the leader, we will update LogAppendInfo.baseOffset with the correct value. In the follower - // case, validation will be more lenient. - // Also indicate whether we have the accurate first offset or not - if (!readFirstMessage) { - if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) - firstOffset = batch.baseOffset - lastOffsetOfFirstBatch = batch.lastOffset - readFirstMessage = true } - // check that offsets are monotonically increasing - if (lastOffset >= batch.lastOffset) - monotonic = false + /* During replication of uncommitted data it is possible for the remote replica to send record batches after it lost + * leadership. This can happen if sending FETCH responses is slow. There is a race between sending the FETCH + * response and the replica truncating and appending to the log. The replicating replica resolves this issue by only + * persisting up to the current leader epoch used in the fetch request. See KAFKA-18723 for more details. + */ + skipRemainingBatches = skipRemainingBatches || hasHigherPartitionLeaderEpoch(batch, origin, leaderEpoch) + if (skipRemainingBatches) { + info( + s"Skipping batch $batch from an origin of $origin because its partition leader epoch " + + s"${batch.partitionLeaderEpoch} is higher than the replica's current leader epoch " + + s"$leaderEpoch" + ) + } else { + // update the first offset if on the first message. For magic versions older than 2, we use the last offset + // to avoid the need to decompress the data (the last offset can be obtained directly from the wrapper message). + // For magic version 2, we can get the first offset directly from the batch header. + // When appending to the leader, we will update LogAppendInfo.baseOffset with the correct value. In the follower + // case, validation will be more lenient. 
+ // Also indicate whether we have the accurate first offset or not + if (!readFirstMessage) { + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) { + firstOffset = batch.baseOffset + } + lastOffsetOfFirstBatch = batch.lastOffset + readFirstMessage = true + } - // update the last offset seen - lastOffset = batch.lastOffset - lastLeaderEpoch = batch.partitionLeaderEpoch + // check that offsets are monotonically increasing + if (lastOffset >= batch.lastOffset) { + monotonic = false + } - // Check if the message sizes are valid. - val batchSize = batch.sizeInBytes - if (!ignoreRecordSize && batchSize > config.maxMessageSize) { - brokerTopicStats.topicStats(topicPartition.topic).bytesRejectedRate.mark(records.sizeInBytes) - brokerTopicStats.allTopicsStats.bytesRejectedRate.mark(records.sizeInBytes) - throw new RecordTooLargeException(s"The record batch size in the append to $topicPartition is $batchSize bytes " + - s"which exceeds the maximum configured value of ${config.maxMessageSize}.") + // update the last offset seen + lastOffset = batch.lastOffset + lastLeaderEpoch = batch.partitionLeaderEpoch + + // Check if the message sizes are valid. + val batchSize = batch.sizeInBytes + if (!ignoreRecordSize && batchSize > config.maxMessageSize) { + brokerTopicStats.topicStats(topicPartition.topic).bytesRejectedRate.mark(records.sizeInBytes) + brokerTopicStats.allTopicsStats.bytesRejectedRate.mark(records.sizeInBytes) + throw new RecordTooLargeException(s"The record batch size in the append to $topicPartition is $batchSize bytes " + + s"which exceeds the maximum configured value of ${config.maxMessageSize}.") + } + + // check the validity of the message by checking CRC + if (!batch.isValid) { + brokerTopicStats.allTopicsStats.invalidMessageCrcRecordsPerSec.mark() + throw new CorruptRecordException(s"Record is corrupt (stored crc = ${batch.checksum()}) in topic partition $topicPartition.") + } + + if (batch.maxTimestamp > maxTimestamp) { + maxTimestamp = batch.maxTimestamp + shallowOffsetOfMaxTimestamp = lastOffset + } + + validBytesCount += batchSize + + val batchCompression = CompressionType.forId(batch.compressionType.id) + // sourceCompression is only used on the leader path, which only contains one batch if version is v2 or messages are compressed + if (batchCompression != CompressionType.NONE) { + sourceCompression = batchCompression + } } - - // check the validity of the message by checking CRC - if (!batch.isValid) { - brokerTopicStats.allTopicsStats.invalidMessageCrcRecordsPerSec.mark() - throw new CorruptRecordException(s"Record is corrupt (stored crc = ${batch.checksum()}) in topic partition $topicPartition.") - } - - if (batch.maxTimestamp > maxTimestamp) { - maxTimestamp = batch.maxTimestamp - shallowOffsetOfMaxTimestamp = lastOffset - } - - validBytesCount += batchSize - - val batchCompression = CompressionType.forId(batch.compressionType.id) - // sourceCompression is only used on the leader path, which only contains one batch if version is v2 or messages are compressed - if (batchCompression != CompressionType.NONE) - sourceCompression = batchCompression } if (requireOffsetsMonotonic && !monotonic) @@ -1197,6 +1221,25 @@ class UnifiedLog(@volatile var logStartOffset: Long, validBytesCount, lastOffsetOfFirstBatch, Collections.emptyList[RecordError], LeaderHwChange.NONE) } + /** + * Return true if the record batch has a higher leader epoch than the specified leader epoch + * + * @param batch the batch to validate + * @param origin the reason for appending the record batch + * @param 
leaderEpoch the epoch to compare + * @return true if the append reason is replication and the batch's partition leader epoch is + * greater than the specified leaderEpoch, otherwise false + */ + private def hasHigherPartitionLeaderEpoch( + batch: RecordBatch, + origin: AppendOrigin, + leaderEpoch: Int + ): Boolean = { + origin == AppendOrigin.REPLICATION && + batch.partitionLeaderEpoch() != RecordBatch.NO_PARTITION_LEADER_EPOCH && + batch.partitionLeaderEpoch() > leaderEpoch + } + /** * Trim any invalid bytes from the end of this message set (if there are any) * diff --git a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala index 5e107aa1487..e4ffbb69b14 100644 --- a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala +++ b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala @@ -26,6 +26,7 @@ import kafka.server.{BrokerTopicStats, RequestLocal} import kafka.utils.{CoreUtils, Logging} import org.apache.kafka.common.config.TopicConfig import org.apache.kafka.common.errors.InvalidConfigurationException +import org.apache.kafka.common.errors.CorruptRecordException import org.apache.kafka.common.record.{MemoryRecords, Records} import org.apache.kafka.common.utils.Time import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid} @@ -87,8 +88,9 @@ final class KafkaMetadataLog private ( } override def appendAsLeader(records: Records, epoch: Int): LogAppendInfo = { - if (records.sizeInBytes == 0) + if (records.sizeInBytes == 0) { throw new IllegalArgumentException("Attempt to append an empty record set") + } handleAndConvertLogAppendInfo( log.appendAsLeader(records.asInstanceOf[MemoryRecords], @@ -99,18 +101,20 @@ final class KafkaMetadataLog private ( ) } - override def appendAsFollower(records: Records): LogAppendInfo = { - if (records.sizeInBytes == 0) + override def appendAsFollower(records: Records, epoch: Int): LogAppendInfo = { + if (records.sizeInBytes == 0) { throw new IllegalArgumentException("Attempt to append an empty record set") + } - handleAndConvertLogAppendInfo(log.appendAsFollower(records.asInstanceOf[MemoryRecords])) + handleAndConvertLogAppendInfo(log.appendAsFollower(records.asInstanceOf[MemoryRecords], epoch)) } private def handleAndConvertLogAppendInfo(appendInfo: internals.log.LogAppendInfo): LogAppendInfo = { - if (appendInfo.firstOffset != UnifiedLog.UnknownOffset) + if (appendInfo.firstOffset == UnifiedLog.UnknownOffset) { + throw new CorruptRecordException(s"Append failed unexpectedly $appendInfo") + } else { new LogAppendInfo(appendInfo.firstOffset, appendInfo.lastOffset) - else - throw new KafkaException(s"Append failed unexpectedly") + } } override def lastFetchedEpoch: Int = { diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala index 989030f4439..cf0e0a582f5 100755 --- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala +++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala @@ -78,9 +78,12 @@ abstract class AbstractFetcherThread(name: String, /* callbacks to be defined in subclass */ // process fetched data - protected def processPartitionData(topicPartition: TopicPartition, - fetchOffset: Long, - partitionData: FetchData): Option[LogAppendInfo] + protected def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] protected def truncate(topicPartition: TopicPartition, truncationState: 
OffsetTruncationState): Unit @@ -335,7 +338,9 @@ abstract class AbstractFetcherThread(name: String, // In this case, we only want to process the fetch response if the partition state is ready for fetch and // the current offset is the same as the offset requested. val fetchPartitionData = sessionPartitions.get(topicPartition) - if (fetchPartitionData != null && fetchPartitionData.fetchOffset == currentFetchState.fetchOffset && currentFetchState.isReadyForFetch) { + if (fetchPartitionData != null && + fetchPartitionData.fetchOffset == currentFetchState.fetchOffset && + currentFetchState.isReadyForFetch) { Errors.forCode(partitionData.errorCode) match { case Errors.NONE => try { @@ -350,10 +355,16 @@ abstract class AbstractFetcherThread(name: String, .setLeaderEpoch(partitionData.divergingEpoch.epoch) .setEndOffset(partitionData.divergingEpoch.endOffset) } else { - // Once we hand off the partition data to the subclass, we can't mess with it any more in this thread + /* Once we hand off the partition data to the subclass, we can't mess with it any more in this thread + * + * When appending batches to the log only append record batches up to the leader epoch when the FETCH + * request was handled. This is done to make sure that logs are not inconsistent because of log + * truncation and append after the FETCH request was handled. See KAFKA-18723 for more details. + */ val logAppendInfoOpt = processPartitionData( topicPartition, currentFetchState.fetchOffset, + fetchPartitionData.currentLeaderEpoch.orElse(currentFetchState.currentLeaderEpoch), partitionData ) diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala index 95c7a5ac3d4..4f838bae315 100644 --- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala +++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala @@ -65,9 +65,12 @@ class ReplicaAlterLogDirsThread(name: String, } // process fetched data - override def processPartitionData(topicPartition: TopicPartition, - fetchOffset: Long, - partitionData: FetchData): Option[LogAppendInfo] = { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = { val partition = replicaMgr.getPartitionOrException(topicPartition) val futureLog = partition.futureLocalLogOrException val records = toMemoryRecords(FetchResponse.recordsOrFail(partitionData)) @@ -77,7 +80,7 @@ class ReplicaAlterLogDirsThread(name: String, topicPartition, fetchOffset, futureLog.logEndOffset)) val logAppendInfo = if (records.sizeInBytes() > 0) - partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = true) + partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = true, partitionLeaderEpoch) else None diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala index bb073682bdf..da9e37f7d91 100644 --- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala +++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala @@ -100,9 +100,12 @@ class ReplicaFetcherThread(name: String, } // process fetched data - override def processPartitionData(topicPartition: TopicPartition, - fetchOffset: Long, - partitionData: FetchData): Option[LogAppendInfo] = { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData 
+ ): Option[LogAppendInfo] = { val logTrace = isTraceEnabled val partition = replicaMgr.getPartitionOrException(topicPartition) val log = partition.localLogOrException @@ -119,7 +122,7 @@ class ReplicaFetcherThread(name: String, .format(log.logEndOffset, topicPartition, records.sizeInBytes, partitionData.highWatermark)) // Append the leader's messages to the log - val logAppendInfo = partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false) + val logAppendInfo = partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false, partitionLeaderEpoch) if (logTrace) trace("Follower has replica log end offset %d after appending %d bytes of messages for partition %s" diff --git a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala index 1c65fd5073c..8271aa6870b 100644 --- a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala +++ b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala @@ -20,9 +20,12 @@ import kafka.log.UnifiedLog import kafka.server.{KafkaConfig, KafkaRaftServer} import kafka.utils.TestUtils import org.apache.kafka.common.compress.Compression +import org.apache.kafka.common.errors.CorruptRecordException import org.apache.kafka.common.errors.{InvalidConfigurationException, RecordTooLargeException} import org.apache.kafka.common.protocol import org.apache.kafka.common.protocol.{ObjectSerializationCache, Writable} +import org.apache.kafka.common.record.ArbitraryMemoryRecords +import org.apache.kafka.common.record.InvalidMemoryRecordsProvider import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord} import org.apache.kafka.common.utils.Utils import org.apache.kafka.raft._ @@ -34,7 +37,14 @@ import org.apache.kafka.snapshot.{FileRawSnapshotWriter, RawSnapshotReader, RawS import org.apache.kafka.storage.internals.log.{LogConfig, LogStartOffsetIncrementReason} import org.apache.kafka.test.TestUtils.assertOptional import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.function.Executable import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource + +import net.jqwik.api.AfterFailureMode +import net.jqwik.api.ForAll +import net.jqwik.api.Property import java.io.File import java.nio.ByteBuffer @@ -109,12 +119,93 @@ final class KafkaMetadataLogTest { classOf[RuntimeException], () => { log.appendAsFollower( - MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo) + MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo), + currentEpoch ) } ) } + @Test + def testEmptyAppendNotAllowed(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + assertThrows(classOf[IllegalArgumentException], () => log.appendAsFollower(MemoryRecords.EMPTY, 1)); + assertThrows(classOf[IllegalArgumentException], () => log.appendAsLeader(MemoryRecords.EMPTY, 1)); + } + + @ParameterizedTest + @ArgumentsSource(classOf[InvalidMemoryRecordsProvider]) + def testInvalidMemoryRecords(records: MemoryRecords, expectedException: Optional[Class[Exception]]): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + val previousEndOffset = log.endOffset().offset() + + val action: Executable = () => log.appendAsFollower(records, Int.MaxValue) + if (expectedException.isPresent()) { + assertThrows(expectedException.get, action) + } else { + assertThrows(classOf[CorruptRecordException], action) + } + + assertEquals(previousEndOffset, 
log.endOffset().offset()) + } + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + def testRandomRecords( + @ForAll(supplier = classOf[ArbitraryMemoryRecords]) records: MemoryRecords + ): Unit = { + val tempDir = TestUtils.tempDir() + try { + val log = buildMetadataLog(tempDir, mockTime) + val previousEndOffset = log.endOffset().offset() + + assertThrows( + classOf[CorruptRecordException], + () => log.appendAsFollower(records, Int.MaxValue) + ) + + assertEquals(previousEndOffset, log.endOffset().offset()) + } finally { + Utils.delete(tempDir) + } + } + + @Test + def testInvalidLeaderEpoch(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + val previousEndOffset = log.endOffset().offset() + val epoch = log.lastFetchedEpoch() + 1 + val numberOfRecords = 10 + + val batchWithValidEpoch = MemoryRecords.withRecords( + previousEndOffset, + Compression.NONE, + epoch, + (0 until numberOfRecords).map(number => new SimpleRecord(number.toString.getBytes)): _* + ) + + val batchWithInvalidEpoch = MemoryRecords.withRecords( + previousEndOffset + numberOfRecords, + Compression.NONE, + epoch + 1, + (0 until numberOfRecords).map(number => new SimpleRecord(number.toString.getBytes)): _* + ) + + val buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + batchWithInvalidEpoch.sizeInBytes()) + buffer.put(batchWithValidEpoch.buffer()) + buffer.put(batchWithInvalidEpoch.buffer()) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + + log.appendAsFollower(records, epoch) + + // Check that only the first batch was appended + assertEquals(previousEndOffset + numberOfRecords, log.endOffset().offset()) + // Check that the last fetched epoch matches the first batch + assertEquals(epoch, log.lastFetchedEpoch()) + } + @Test def testCreateSnapshot(): Unit = { val numberOfRecords = 10 @@ -1051,4 +1142,4 @@ object KafkaMetadataLogTest { } dir } -} \ No newline at end of file +} diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala index d0a630b4431..c8b4ab8de09 100644 --- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala +++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala @@ -416,6 +416,7 @@ class PartitionTest extends AbstractPartitionTest { def testMakeFollowerWithWithFollowerAppendRecords(): Unit = { val appendSemaphore = new Semaphore(0) val mockTime = new MockTime() + val prevLeaderEpoch = 0 partition = new Partition( topicPartition, @@ -467,24 +468,38 @@ class PartitionTest extends AbstractPartitionTest { } partition.createLogIfNotExists(isNew = true, isFutureReplica = false, offsetCheckpoints, None) + var partitionState = new LeaderAndIsrPartitionState() + .setControllerEpoch(0) + .setLeader(2) + .setLeaderEpoch(prevLeaderEpoch) + .setIsr(List[Integer](0, 1, 2, brokerId).asJava) + .setPartitionEpoch(1) + .setReplicas(List[Integer](0, 1, 2, brokerId).asJava) + .setIsNew(false) + assertTrue(partition.makeFollower(partitionState, offsetCheckpoints, None)) val appendThread = new Thread { override def run(): Unit = { - val records = createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes), - new SimpleRecord("k2".getBytes, "v2".getBytes)), - baseOffset = 0) - partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false) + val records = createRecords( + List( + new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes) + ), + baseOffset = 0, + partitionLeaderEpoch = prevLeaderEpoch + ) + 
partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false, prevLeaderEpoch) } } appendThread.start() TestUtils.waitUntilTrue(() => appendSemaphore.hasQueuedThreads, "follower log append is not called.") - val partitionState = new LeaderAndIsrPartitionState() + partitionState = new LeaderAndIsrPartitionState() .setControllerEpoch(0) .setLeader(2) - .setLeaderEpoch(1) + .setLeaderEpoch(prevLeaderEpoch + 1) .setIsr(List[Integer](0, 1, 2, brokerId).asJava) - .setPartitionEpoch(1) + .setPartitionEpoch(2) .setReplicas(List[Integer](0, 1, 2, brokerId).asJava) .setIsNew(false) assertTrue(partition.makeFollower(partitionState, offsetCheckpoints, None)) @@ -524,15 +539,22 @@ class PartitionTest extends AbstractPartitionTest { // Write to the future replica as if the log had been compacted, and do not roll the segment val buffer = ByteBuffer.allocate(1024) - val builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.CREATE_TIME, 0L, RecordBatch.NO_TIMESTAMP, 0) + val builder = MemoryRecords.builder( + buffer, + RecordBatch.CURRENT_MAGIC_VALUE, + Compression.NONE, + TimestampType.CREATE_TIME, + 0L, // baseOffset + RecordBatch.NO_TIMESTAMP, + 0 // partitionLeaderEpoch + ) builder.appendWithOffset(2L, new SimpleRecord("k1".getBytes, "v3".getBytes)) builder.appendWithOffset(5L, new SimpleRecord("k2".getBytes, "v6".getBytes)) builder.appendWithOffset(6L, new SimpleRecord("k3".getBytes, "v7".getBytes)) builder.appendWithOffset(7L, new SimpleRecord("k4".getBytes, "v8".getBytes)) val futureLog = partition.futureLocalLogOrException - futureLog.appendAsFollower(builder.build()) + futureLog.appendAsFollower(builder.build(), 0) assertTrue(partition.maybeReplaceCurrentWithFutureReplica()) } @@ -934,6 +956,18 @@ class PartitionTest extends AbstractPartitionTest { def testAppendRecordsAsFollowerBelowLogStartOffset(): Unit = { partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) val log = partition.localLogOrException + val epoch = 1 + + // Start off as follower + val partitionState = new LeaderAndIsrPartitionState() + .setControllerEpoch(0) + .setLeader(1) + .setLeaderEpoch(epoch) + .setIsr(List[Integer](0, 1, 2, brokerId).asJava) + .setPartitionEpoch(1) + .setReplicas(List[Integer](0, 1, 2, brokerId).asJava) + .setIsNew(false) + partition.makeFollower(partitionState, offsetCheckpoints, None) val initialLogStartOffset = 5L partition.truncateFullyAndStartAt(initialLogStartOffset, isFuture = false) @@ -943,9 +977,14 @@ class PartitionTest extends AbstractPartitionTest { s"Log start offset after truncate fully and start at $initialLogStartOffset:") // verify that we cannot append records that do not contain log start offset even if the log is empty - assertThrows(classOf[UnexpectedAppendOffsetException], () => + assertThrows( + classOf[UnexpectedAppendOffsetException], // append one record with offset = 3 - partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 3L), isFuture = false) + () => partition.appendRecordsToFollowerOrFutureReplica( + createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 3L), + isFuture = false, + partitionLeaderEpoch = epoch + ) ) assertEquals(initialLogStartOffset, log.logEndOffset, s"Log end offset should not change after failure to append") @@ -957,12 +996,16 @@ class PartitionTest extends AbstractPartitionTest { new SimpleRecord("k2".getBytes, "v2".getBytes), new 
SimpleRecord("k3".getBytes, "v3".getBytes)), baseOffset = newLogStartOffset) - partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false) + partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false, partitionLeaderEpoch = epoch) assertEquals(7L, log.logEndOffset, s"Log end offset after append of 3 records with base offset $newLogStartOffset:") assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset after append of 3 records with base offset $newLogStartOffset:") // and we can append more records after that - partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 7L), isFuture = false) + partition.appendRecordsToFollowerOrFutureReplica( + createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 7L), + isFuture = false, + partitionLeaderEpoch = epoch + ) assertEquals(8L, log.logEndOffset, s"Log end offset after append of 1 record at offset 7:") assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset not expected to change:") @@ -970,11 +1013,18 @@ class PartitionTest extends AbstractPartitionTest { val records2 = createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes), new SimpleRecord("k2".getBytes, "v2".getBytes)), baseOffset = 3L) - assertThrows(classOf[UnexpectedAppendOffsetException], () => partition.appendRecordsToFollowerOrFutureReplica(records2, isFuture = false)) + assertThrows( + classOf[UnexpectedAppendOffsetException], + () => partition.appendRecordsToFollowerOrFutureReplica(records2, isFuture = false, partitionLeaderEpoch = epoch) + ) assertEquals(8L, log.logEndOffset, s"Log end offset should not change after failure to append") // we still can append to next offset - partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 8L), isFuture = false) + partition.appendRecordsToFollowerOrFutureReplica( + createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 8L), + isFuture = false, + partitionLeaderEpoch = epoch + ) assertEquals(9L, log.logEndOffset, s"Log end offset after append of 1 record at offset 8:") assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset not expected to change:") } @@ -1057,9 +1107,13 @@ class PartitionTest extends AbstractPartitionTest { @Test def testAppendRecordsToFollowerWithNoReplicaThrowsException(): Unit = { - assertThrows(classOf[NotLeaderOrFollowerException], () => - partition.appendRecordsToFollowerOrFutureReplica( - createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 0L), isFuture = false) + assertThrows( + classOf[NotLeaderOrFollowerException], + () => partition.appendRecordsToFollowerOrFutureReplica( + createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 0L), + isFuture = false, + partitionLeaderEpoch = 0 + ) ) } @@ -3514,13 +3568,16 @@ class PartitionTest extends AbstractPartitionTest { partition.createLogIfNotExists(isNew = true, isFutureReplica = false, offsetCheckpoints, topicId = None) assertTrue(partition.log.isDefined) + val replicas = Seq(brokerId, brokerId + 1) + val isr = replicas + val epoch = 0 partition.makeLeader( new LeaderAndIsrPartitionState() .setControllerEpoch(0) .setLeader(brokerId) - .setLeaderEpoch(0) - .setIsr(List(brokerId, brokerId + 1).map(Int.box).asJava) - .setReplicas(List(brokerId, brokerId + 1).map(Int.box).asJava) + .setLeaderEpoch(epoch) + .setIsr(isr.map(Int.box).asJava) + 
.setReplicas(replicas.map(Int.box).asJava) .setPartitionEpoch(1) .setIsNew(true), offsetCheckpoints, @@ -3551,7 +3608,8 @@ class PartitionTest extends AbstractPartitionTest { partition.appendRecordsToFollowerOrFutureReplica( records = records, - isFuture = true + isFuture = true, + partitionLeaderEpoch = epoch ) listener.verify() @@ -3696,9 +3754,9 @@ class PartitionTest extends AbstractPartitionTest { _topicId = None, keepPartitionMetadataFile = true) { - override def appendAsFollower(records: MemoryRecords): LogAppendInfo = { + override def appendAsFollower(records: MemoryRecords, epoch: Int): LogAppendInfo = { appendSemaphore.acquire() - val appendInfo = super.appendAsFollower(records) + val appendInfo = super.appendAsFollower(records, epoch) appendInfo } } diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala index cb22931bb88..734845c440e 100644 --- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala +++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala @@ -1427,7 +1427,7 @@ class LogCleanerTest extends Logging { log.appendAsLeader(TestUtils.singletonRecords(value = v, key = k), leaderEpoch = 0) //0 to Int.MaxValue is Int.MaxValue+1 message, -1 will be the last message of i-th segment val records = messageWithOffset(k, v, (i + 1L) * (Int.MaxValue + 1L) -1 ) - log.appendAsFollower(records) + log.appendAsFollower(records, Int.MaxValue) assertEquals(i + 1, log.numberOfSegments) } @@ -1481,7 +1481,7 @@ class LogCleanerTest extends Logging { // forward offset and append message to next segment at offset Int.MaxValue val records = messageWithOffset("hello".getBytes, "hello".getBytes, Int.MaxValue - 1) - log.appendAsFollower(records) + log.appendAsFollower(records, Int.MaxValue) log.appendAsLeader(TestUtils.singletonRecords(value = "hello".getBytes, key = "hello".getBytes), leaderEpoch = 0) assertEquals(Int.MaxValue, log.activeSegment.offsetIndex.lastOffset) @@ -1530,14 +1530,14 @@ class LogCleanerTest extends Logging { val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) val record1 = messageWithOffset("hello".getBytes, "hello".getBytes, 0) - log.appendAsFollower(record1) + log.appendAsFollower(record1, Int.MaxValue) val record2 = messageWithOffset("hello".getBytes, "hello".getBytes, 1) - log.appendAsFollower(record2) + log.appendAsFollower(record2, Int.MaxValue) log.roll(Some(Int.MaxValue/2)) // starting a new log segment at offset Int.MaxValue/2 val record3 = messageWithOffset("hello".getBytes, "hello".getBytes, Int.MaxValue/2) - log.appendAsFollower(record3) + log.appendAsFollower(record3, Int.MaxValue) val record4 = messageWithOffset("hello".getBytes, "hello".getBytes, Int.MaxValue.toLong + 1) - log.appendAsFollower(record4) + log.appendAsFollower(record4, Int.MaxValue) assertTrue(log.logEndOffset - 1 - log.logStartOffset > Int.MaxValue, "Actual offset range should be > Int.MaxValue") assertTrue(log.logSegments.asScala.last.offsetIndex.lastOffset - log.logStartOffset <= Int.MaxValue, @@ -1848,8 +1848,8 @@ class LogCleanerTest extends Logging { val noDupSetOffset = 50 val noDupSet = noDupSetKeys zip (noDupSetOffset until noDupSetOffset + noDupSetKeys.size) - log.appendAsFollower(invalidCleanedMessage(dupSetOffset, dupSet, codec)) - log.appendAsFollower(invalidCleanedMessage(noDupSetOffset, noDupSet, codec)) + log.appendAsFollower(invalidCleanedMessage(dupSetOffset, dupSet, codec), Int.MaxValue) + log.appendAsFollower(invalidCleanedMessage(noDupSetOffset, noDupSet, codec), 
Int.MaxValue) log.roll() @@ -1932,7 +1932,7 @@ class LogCleanerTest extends Logging { log.roll(Some(11L)) // active segment record - log.appendAsFollower(messageWithOffset(1015, 1015, 11L)) + log.appendAsFollower(messageWithOffset(1015, 1015, 11L), Int.MaxValue) val (nextDirtyOffset, _) = cleaner.clean(LogToClean(log.topicPartition, log, 0L, log.activeSegment.baseOffset, needCompactionNow = true)) assertEquals(log.activeSegment.baseOffset, nextDirtyOffset, @@ -1951,7 +1951,7 @@ class LogCleanerTest extends Logging { log.roll(Some(30L)) // active segment record - log.appendAsFollower(messageWithOffset(1015, 1015, 30L)) + log.appendAsFollower(messageWithOffset(1015, 1015, 30L), Int.MaxValue) val (nextDirtyOffset, _) = cleaner.clean(LogToClean(log.topicPartition, log, 0L, log.activeSegment.baseOffset, needCompactionNow = true)) assertEquals(log.activeSegment.baseOffset, nextDirtyOffset, @@ -2168,7 +2168,7 @@ class LogCleanerTest extends Logging { private def writeToLog(log: UnifiedLog, keysAndValues: Iterable[(Int, Int)], offsetSeq: Iterable[Long]): Iterable[Long] = { for (((key, value), offset) <- keysAndValues.zip(offsetSeq)) - yield log.appendAsFollower(messageWithOffset(key, value, offset)).lastOffset + yield log.appendAsFollower(messageWithOffset(key, value, offset), Int.MaxValue).lastOffset } private def invalidCleanedMessage(initialOffset: Long, diff --git a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala index 65187b707ad..f3455cdb713 100644 --- a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala @@ -125,9 +125,14 @@ class LogConcurrencyTest { log.appendAsLeader(TestUtils.records(records), leaderEpoch) log.maybeIncrementHighWatermark(logEndOffsetMetadata) } else { - log.appendAsFollower(TestUtils.records(records, - baseOffset = logEndOffset, - partitionLeaderEpoch = leaderEpoch)) + log.appendAsFollower( + TestUtils.records( + records, + baseOffset = logEndOffset, + partitionLeaderEpoch = leaderEpoch + ), + Int.MaxValue + ) log.updateHighWatermark(logEndOffset) } diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala index 5287ef67e86..2e98182f29f 100644 --- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala +++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala @@ -1075,17 +1075,17 @@ class LogLoaderTest { val set3 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 3, Compression.NONE, 0, new SimpleRecord("v4".getBytes(), "k4".getBytes())) val set4 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 4, Compression.NONE, 0, new SimpleRecord("v5".getBytes(), "k5".getBytes())) //Writes into an empty log with baseOffset 0 - log.appendAsFollower(set1) + log.appendAsFollower(set1, Int.MaxValue) assertEquals(0L, log.activeSegment.baseOffset) //This write will roll the segment, yielding a new segment with base offset = max(1, Integer.MAX_VALUE+2) = Integer.MAX_VALUE+2 - log.appendAsFollower(set2) + log.appendAsFollower(set2, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) assertTrue(LogFileUtils.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 2).exists) //This will go into the existing log - log.appendAsFollower(set3) + log.appendAsFollower(set3, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) //This will go into the existing log - log.appendAsFollower(set4) + 
log.appendAsFollower(set4, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) log.close() val indexFiles = logDir.listFiles.filter(file => file.getName.contains(".index")) @@ -1138,17 +1138,17 @@ class LogLoaderTest { new SimpleRecord("v7".getBytes(), "k7".getBytes()), new SimpleRecord("v8".getBytes(), "k8".getBytes())) //Writes into an empty log with baseOffset 0 - log.appendAsFollower(set1) + log.appendAsFollower(set1, Int.MaxValue) assertEquals(0L, log.activeSegment.baseOffset) //This write will roll the segment, yielding a new segment with base offset = max(1, Integer.MAX_VALUE+2) = Integer.MAX_VALUE+2 - log.appendAsFollower(set2) + log.appendAsFollower(set2, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) assertTrue(LogFileUtils.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 2).exists) //This will go into the existing log - log.appendAsFollower(set3) + log.appendAsFollower(set3, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) //This will go into the existing log - log.appendAsFollower(set4) + log.appendAsFollower(set4, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) log.close() val indexFiles = logDir.listFiles.filter(file => file.getName.contains(".index")) @@ -1178,18 +1178,18 @@ class LogLoaderTest { new SimpleRecord("v7".getBytes(), "k7".getBytes()), new SimpleRecord("v8".getBytes(), "k8".getBytes())) //Writes into an empty log with baseOffset 0 - log.appendAsFollower(set1) + log.appendAsFollower(set1, Int.MaxValue) assertEquals(0L, log.activeSegment.baseOffset) //This write will roll the segment, yielding a new segment with base offset = max(1, 3) = 3 - log.appendAsFollower(set2) + log.appendAsFollower(set2, Int.MaxValue) assertEquals(3, log.activeSegment.baseOffset) assertTrue(LogFileUtils.producerSnapshotFile(logDir, 3).exists) //This will also roll the segment, yielding a new segment with base offset = max(5, Integer.MAX_VALUE+4) = Integer.MAX_VALUE+4 - log.appendAsFollower(set3) + log.appendAsFollower(set3, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 4, log.activeSegment.baseOffset) assertTrue(LogFileUtils.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 4).exists) //This will go into the existing log - log.appendAsFollower(set4) + log.appendAsFollower(set4, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 4, log.activeSegment.baseOffset) log.close() val indexFiles = logDir.listFiles.filter(file => file.getName.contains(".index")) @@ -1379,16 +1379,16 @@ class LogLoaderTest { val log = createLog(logDir, new LogConfig(new Properties)) val leaderEpochCache = log.leaderEpochCache.get val firstBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 1, offset = 0) - log.appendAsFollower(records = firstBatch) + log.appendAsFollower(records = firstBatch, Int.MaxValue) val secondBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 2, offset = 1) - log.appendAsFollower(records = secondBatch) + log.appendAsFollower(records = secondBatch, Int.MaxValue) val thirdBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 2, offset = 2) - log.appendAsFollower(records = thirdBatch) + log.appendAsFollower(records = thirdBatch, Int.MaxValue) val fourthBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 3, offset = 3) - log.appendAsFollower(records = fourthBatch) + log.appendAsFollower(records = fourthBatch, 
Int.MaxValue) assertEquals(java.util.Arrays.asList(new EpochEntry(1, 0), new EpochEntry(2, 1), new EpochEntry(3, 3)), leaderEpochCache.epochEntries) diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala index 13102074a88..5bc8c44ec5c 100755 --- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala +++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala @@ -44,11 +44,16 @@ import org.apache.kafka.storage.log.metrics.BrokerTopicMetrics import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource import org.junit.jupiter.params.provider.{EnumSource, ValueSource} import org.mockito.ArgumentMatchers import org.mockito.ArgumentMatchers.{any, anyLong} import org.mockito.Mockito.{doThrow, mock, spy, when} +import net.jqwik.api.AfterFailureMode +import net.jqwik.api.ForAll +import net.jqwik.api.Property + import java.io._ import java.nio.ByteBuffer import java.nio.file.Files @@ -301,7 +306,7 @@ class UnifiedLogTest { assertHighWatermark(3L) // Update high watermark as follower - log.appendAsFollower(records(3L)) + log.appendAsFollower(records(3L), leaderEpoch) log.updateHighWatermark(6L) assertHighWatermark(6L) @@ -577,6 +582,7 @@ class UnifiedLogTest { @Test def testRollSegmentThatAlreadyExists(): Unit = { val logConfig = LogTestUtils.createLogConfig(segmentMs = 1 * 60 * 60L) + val partitionLeaderEpoch = 0 // create a log val log = createLog(logDir, logConfig) @@ -589,16 +595,16 @@ class UnifiedLogTest { // should be able to append records to active segment val records = TestUtils.records( List(new SimpleRecord(mockTime.milliseconds, "k1".getBytes, "v1".getBytes)), - baseOffset = 0L, partitionLeaderEpoch = 0) - log.appendAsFollower(records) + baseOffset = 0L, partitionLeaderEpoch = partitionLeaderEpoch) + log.appendAsFollower(records, partitionLeaderEpoch) assertEquals(1, log.numberOfSegments, "Expect one segment.") assertEquals(0L, log.activeSegment.baseOffset) // make sure we can append more records val records2 = TestUtils.records( List(new SimpleRecord(mockTime.milliseconds + 10, "k2".getBytes, "v2".getBytes)), - baseOffset = 1L, partitionLeaderEpoch = 0) - log.appendAsFollower(records2) + baseOffset = 1L, partitionLeaderEpoch = partitionLeaderEpoch) + log.appendAsFollower(records2, partitionLeaderEpoch) assertEquals(2, log.logEndOffset, "Expect two records in the log") assertEquals(0, LogTestUtils.readLog(log, 0, 1).records.batches.iterator.next().lastOffset) @@ -613,8 +619,8 @@ class UnifiedLogTest { log.activeSegment.offsetIndex.resize(0) val records3 = TestUtils.records( List(new SimpleRecord(mockTime.milliseconds + 12, "k3".getBytes, "v3".getBytes)), - baseOffset = 2L, partitionLeaderEpoch = 0) - log.appendAsFollower(records3) + baseOffset = 2L, partitionLeaderEpoch = partitionLeaderEpoch) + log.appendAsFollower(records3, partitionLeaderEpoch) assertTrue(log.activeSegment.offsetIndex.maxEntries > 1) assertEquals(2, LogTestUtils.readLog(log, 2, 1).records.batches.iterator.next().lastOffset) assertEquals(2, log.numberOfSegments, "Expect two segments.") @@ -791,17 +797,25 @@ class UnifiedLogTest { val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) val log = createLog(logDir, logConfig) val pid = 1L - val epoch = 0.toShort + val producerEpoch = 0.toShort + val partitionLeaderEpoch = 0 val seq = 0 val baseOffset = 23L // create a batch with a 
couple gaps to simulate compaction - val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List( - new SimpleRecord(mockTime.milliseconds(), "a".getBytes), - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes), - new SimpleRecord(mockTime.milliseconds(), "c".getBytes), - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes))) - records.batches.forEach(_.setPartitionLeaderEpoch(0)) + val records = TestUtils.records( + producerId = pid, + producerEpoch = producerEpoch, + sequence = seq, + baseOffset = baseOffset, + records = List( + new SimpleRecord(mockTime.milliseconds(), "a".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes), + new SimpleRecord(mockTime.milliseconds(), "c".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes) + ) + ) + records.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch)) val filtered = ByteBuffer.allocate(2048) records.filterTo(new TopicPartition("foo", 0), new RecordFilter(0, 0) { @@ -812,14 +826,18 @@ class UnifiedLogTest { filtered.flip() val filteredRecords = MemoryRecords.readableRecords(filtered) - log.appendAsFollower(filteredRecords) + log.appendAsFollower(filteredRecords, partitionLeaderEpoch) // append some more data and then truncate to force rebuilding of the PID map - val moreRecords = TestUtils.records(baseOffset = baseOffset + 4, records = List( - new SimpleRecord(mockTime.milliseconds(), "e".getBytes), - new SimpleRecord(mockTime.milliseconds(), "f".getBytes))) - moreRecords.batches.forEach(_.setPartitionLeaderEpoch(0)) - log.appendAsFollower(moreRecords) + val moreRecords = TestUtils.records( + baseOffset = baseOffset + 4, + records = List( + new SimpleRecord(mockTime.milliseconds(), "e".getBytes), + new SimpleRecord(mockTime.milliseconds(), "f".getBytes) + ) + ) + moreRecords.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch)) + log.appendAsFollower(moreRecords, partitionLeaderEpoch) log.truncateTo(baseOffset + 4) @@ -835,15 +853,23 @@ class UnifiedLogTest { val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) val log = createLog(logDir, logConfig) val pid = 1L - val epoch = 0.toShort + val producerEpoch = 0.toShort + val partitionLeaderEpoch = 0 val seq = 0 val baseOffset = 23L // create an empty batch - val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List( - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "a".getBytes), - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes))) - records.batches.forEach(_.setPartitionLeaderEpoch(0)) + val records = TestUtils.records( + producerId = pid, + producerEpoch = producerEpoch, + sequence = seq, + baseOffset = baseOffset, + records = List( + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "a".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes) + ) + ) + records.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch)) val filtered = ByteBuffer.allocate(2048) records.filterTo(new TopicPartition("foo", 0), new RecordFilter(0, 0) { @@ -854,14 +880,18 @@ class UnifiedLogTest { filtered.flip() val filteredRecords = MemoryRecords.readableRecords(filtered) - log.appendAsFollower(filteredRecords) + log.appendAsFollower(filteredRecords, partitionLeaderEpoch) // append some more data and then truncate to force rebuilding of the PID map - val 
moreRecords = TestUtils.records(baseOffset = baseOffset + 2, records = List( - new SimpleRecord(mockTime.milliseconds(), "e".getBytes), - new SimpleRecord(mockTime.milliseconds(), "f".getBytes))) - moreRecords.batches.forEach(_.setPartitionLeaderEpoch(0)) - log.appendAsFollower(moreRecords) + val moreRecords = TestUtils.records( + baseOffset = baseOffset + 2, + records = List( + new SimpleRecord(mockTime.milliseconds(), "e".getBytes), + new SimpleRecord(mockTime.milliseconds(), "f".getBytes) + ) + ) + moreRecords.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch)) + log.appendAsFollower(moreRecords, partitionLeaderEpoch) log.truncateTo(baseOffset + 2) @@ -877,17 +907,25 @@ class UnifiedLogTest { val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) val log = createLog(logDir, logConfig) val pid = 1L - val epoch = 0.toShort + val producerEpoch = 0.toShort + val partitionLeaderEpoch = 0 val seq = 0 val baseOffset = 23L // create a batch with a couple gaps to simulate compaction - val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List( - new SimpleRecord(mockTime.milliseconds(), "a".getBytes), - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes), - new SimpleRecord(mockTime.milliseconds(), "c".getBytes), - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes))) - records.batches.forEach(_.setPartitionLeaderEpoch(0)) + val records = TestUtils.records( + producerId = pid, + producerEpoch = producerEpoch, + sequence = seq, + baseOffset = baseOffset, + records = List( + new SimpleRecord(mockTime.milliseconds(), "a".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes), + new SimpleRecord(mockTime.milliseconds(), "c".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes) + ) + ) + records.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch)) val filtered = ByteBuffer.allocate(2048) records.filterTo(new TopicPartition("foo", 0), new RecordFilter(0, 0) { @@ -898,7 +936,7 @@ class UnifiedLogTest { filtered.flip() val filteredRecords = MemoryRecords.readableRecords(filtered) - log.appendAsFollower(filteredRecords) + log.appendAsFollower(filteredRecords, partitionLeaderEpoch) val activeProducers = log.activeProducersWithLastSequence assertTrue(activeProducers.contains(pid)) @@ -1328,33 +1366,44 @@ class UnifiedLogTest { // create a log val log = createLog(logDir, new LogConfig(new Properties)) - val epoch: Short = 0 + val producerEpoch: Short = 0 + val partitionLeaderEpoch = 0 val buffer = ByteBuffer.allocate(512) - var builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), 1L, epoch, 0, false, 0) + var builder = MemoryRecords.builder( + buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, + TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), 1L, producerEpoch, 0, false, + partitionLeaderEpoch + ) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), 2L, epoch, 0, false, 0) + TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), 2L, producerEpoch, 0, false, + partitionLeaderEpoch) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() - builder = MemoryRecords.builder(buffer, 
RecordBatch.MAGIC_VALUE_V2, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), 3L, epoch, 0, false, 0) + builder = MemoryRecords.builder( + buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, + TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), 3L, producerEpoch, 0, false, + partitionLeaderEpoch + ) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() - builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), 4L, epoch, 0, false, 0) + builder = MemoryRecords.builder( + buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, + TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), 4L, producerEpoch, 0, false, + partitionLeaderEpoch + ) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() buffer.flip() val memoryRecords = MemoryRecords.readableRecords(buffer) - log.appendAsFollower(memoryRecords) + log.appendAsFollower(memoryRecords, partitionLeaderEpoch) log.flush(false) val fetchedData = LogTestUtils.readLog(log, 0, Int.MaxValue) @@ -1373,7 +1422,7 @@ class UnifiedLogTest { def testDuplicateAppendToFollower(): Unit = { val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) val log = createLog(logDir, logConfig) - val epoch: Short = 0 + val producerEpoch: Short = 0 val pid = 1L val baseSequence = 0 val partitionLeaderEpoch = 0 @@ -1381,10 +1430,32 @@ class UnifiedLogTest { // this is a bit contrived. to trigger the duplicate case for a follower append, we have to append // a batch with matching sequence numbers, but valid increasing offsets assertEquals(0L, log.logEndOffset) - log.appendAsFollower(MemoryRecords.withIdempotentRecords(0L, Compression.NONE, pid, epoch, baseSequence, - partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) - log.appendAsFollower(MemoryRecords.withIdempotentRecords(2L, Compression.NONE, pid, epoch, baseSequence, - partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + log.appendAsFollower( + MemoryRecords.withIdempotentRecords( + 0L, + Compression.NONE, + pid, + producerEpoch, + baseSequence, + partitionLeaderEpoch, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes) + ), + partitionLeaderEpoch + ) + log.appendAsFollower( + MemoryRecords.withIdempotentRecords( + 2L, + Compression.NONE, + pid, + producerEpoch, + baseSequence, + partitionLeaderEpoch, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes) + ), + partitionLeaderEpoch + ) // Ensure that even the duplicate sequences are accepted on the follower. 
assertEquals(4L, log.logEndOffset) @@ -1397,48 +1468,49 @@ class UnifiedLogTest { val pid1 = 1L val pid2 = 2L - val epoch: Short = 0 + val producerEpoch: Short = 0 val buffer = ByteBuffer.allocate(512) // pid1 seq = 0 var builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), pid1, epoch, 0) + TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), pid1, producerEpoch, 0) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() // pid2 seq = 0 builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), pid2, epoch, 0) + TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), pid2, producerEpoch, 0) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() // pid1 seq = 1 builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), pid1, epoch, 1) + TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), pid1, producerEpoch, 1) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() // pid2 seq = 1 builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), pid2, epoch, 1) + TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), pid2, producerEpoch, 1) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() // // pid1 seq = 1 (duplicate) builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 4L, mockTime.milliseconds(), pid1, epoch, 1) + TimestampType.LOG_APPEND_TIME, 4L, mockTime.milliseconds(), pid1, producerEpoch, 1) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() buffer.flip() + val epoch = 0 val records = MemoryRecords.readableRecords(buffer) - records.batches.forEach(_.setPartitionLeaderEpoch(0)) + records.batches.forEach(_.setPartitionLeaderEpoch(epoch)) // Ensure that batches with duplicates are accepted on the follower. 
assertEquals(0L, log.logEndOffset) - log.appendAsFollower(records) + log.appendAsFollower(records, epoch) assertEquals(5L, log.logEndOffset) } @@ -1580,8 +1652,12 @@ class UnifiedLogTest { val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes)) // now test the case that we give the offsets and use non-sequential offsets - for (i <- records.indices) - log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i))) + for (i <- records.indices) { + log.appendAsFollower( + MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i)), + Int.MaxValue + ) + } for (i <- 50 until messageIds.max) { val idx = messageIds.indexWhere(_ >= i) val read = LogTestUtils.readLog(log, i, 100).records.records.iterator.next() @@ -1628,8 +1704,12 @@ class UnifiedLogTest { val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes)) // now test the case that we give the offsets and use non-sequential offsets - for (i <- records.indices) - log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i))) + for (i <- records.indices) { + log.appendAsFollower( + MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i)), + Int.MaxValue + ) + } for (i <- 50 until messageIds.max) { val idx = messageIds.indexWhere(_ >= i) @@ -1653,8 +1733,12 @@ class UnifiedLogTest { val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes)) // now test the case that we give the offsets and use non-sequential offsets - for (i <- records.indices) - log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i))) + for (i <- records.indices) { + log.appendAsFollower( + MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i)), + Int.MaxValue + ) + } for (i <- 50 until messageIds.max) { assertEquals(MemoryRecords.EMPTY, LogTestUtils.readLog(log, i, maxLength = 0, minOneMessage = false).records) @@ -1902,9 +1986,94 @@ class UnifiedLogTest { val log = createLog(logDir, LogTestUtils.createLogConfig(maxMessageBytes = second.sizeInBytes - 1)) - log.appendAsFollower(first) + log.appendAsFollower(first, Int.MaxValue) // the second record is larger then limit but appendAsFollower does not validate the size. 
- log.appendAsFollower(second) + log.appendAsFollower(second, Int.MaxValue) + } + + @ParameterizedTest + @ArgumentsSource(classOf[InvalidMemoryRecordsProvider]) + def testInvalidMemoryRecords(records: MemoryRecords, expectedException: Optional[Class[Exception]]): Unit = { + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig) + val previousEndOffset = log.logEndOffsetMetadata.messageOffset + + if (expectedException.isPresent()) { + assertThrows( + expectedException.get(), + () => log.appendAsFollower(records, Int.MaxValue) + ) + } else { + log.appendAsFollower(records, Int.MaxValue) + } + + assertEquals(previousEndOffset, log.logEndOffsetMetadata.messageOffset) + } + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + def testRandomRecords( + @ForAll(supplier = classOf[ArbitraryMemoryRecords]) records: MemoryRecords + ): Unit = { + val tempDir = TestUtils.tempDir() + val logDir = TestUtils.randomPartitionLogDir(tempDir) + try { + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig) + val previousEndOffset = log.logEndOffsetMetadata.messageOffset + + // Depending on the corruption, unified log sometimes throws and sometimes returns an + // empty set of batches + assertThrows( + classOf[CorruptRecordException], + () => { + val info = log.appendAsFollower(records, Int.MaxValue) + if (info.firstOffset == UnifiedLog.UnknownOffset) { + throw new CorruptRecordException("Unknown offset is test") + } + } + ) + + assertEquals(previousEndOffset, log.logEndOffsetMetadata.messageOffset) + } finally { + Utils.delete(tempDir) + } + } + + @Test + def testInvalidLeaderEpoch(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig) + val previousEndOffset = log.logEndOffsetMetadata.messageOffset + val epoch = log.latestEpoch.getOrElse(0) + 1 + val numberOfRecords = 10 + + val batchWithValidEpoch = MemoryRecords.withRecords( + previousEndOffset, + Compression.NONE, + epoch, + (0 until numberOfRecords).map(number => new SimpleRecord(number.toString.getBytes)): _* + ) + + val batchWithInvalidEpoch = MemoryRecords.withRecords( + previousEndOffset + numberOfRecords, + Compression.NONE, + epoch + 1, + (0 until numberOfRecords).map(number => new SimpleRecord(number.toString.getBytes)): _* + ) + + val buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + batchWithInvalidEpoch.sizeInBytes()) + buffer.put(batchWithValidEpoch.buffer()) + buffer.put(batchWithInvalidEpoch.buffer()) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + + log.appendAsFollower(records, epoch) + + // Check that only the first batch was appended + assertEquals(previousEndOffset + numberOfRecords, log.logEndOffsetMetadata.messageOffset) + // Check that the last fetched epoch matches the first batch + assertEquals(epoch, log.latestEpoch.get) } @Test @@ -2005,7 +2174,7 @@ class UnifiedLogTest { val messages = (0 until numMessages).map { i => MemoryRecords.withRecords(100 + i, Compression.NONE, 0, new SimpleRecord(mockTime.milliseconds + i, i.toString.getBytes())) } - messages.foreach(log.appendAsFollower) + messages.foreach(message => log.appendAsFollower(message, Int.MaxValue)) val timeIndexEntries = log.logSegments.asScala.foldLeft(0) { (entries, segment) => entries + segment.timeIndex.entries } assertEquals(numMessages - 1, timeIndexEntries, s"There should be ${numMessages - 1} time index entries") assertEquals(mockTime.milliseconds + numMessages - 1, 
log.activeSegment.timeIndex.lastEntry.timestamp, @@ -2367,20 +2536,22 @@ class UnifiedLogTest { def testAppendWithOutOfOrderOffsetsThrowsException(): Unit = { val log = createLog(logDir, new LogConfig(new Properties)) + val epoch = 0 val appendOffsets = Seq(0L, 1L, 3L, 2L, 4L) val buffer = ByteBuffer.allocate(512) for (offset <- appendOffsets) { val builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, TimestampType.LOG_APPEND_TIME, offset, mockTime.milliseconds(), - 1L, 0, 0, false, 0) + 1L, 0, 0, false, epoch) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() } buffer.flip() val memoryRecords = MemoryRecords.readableRecords(buffer) - assertThrows(classOf[OffsetsOutOfOrderException], () => - log.appendAsFollower(memoryRecords) + assertThrows( + classOf[OffsetsOutOfOrderException], + () => log.appendAsFollower(memoryRecords, epoch) ) } @@ -2395,9 +2566,11 @@ class UnifiedLogTest { for (magic <- magicVals; compressionType <- compressionTypes) { val compression = Compression.of(compressionType).build() val invalidRecord = MemoryRecords.withRecords(magic, compression, new SimpleRecord(1.toString.getBytes)) - assertThrows(classOf[UnexpectedAppendOffsetException], - () => log.appendAsFollower(invalidRecord), - () => s"Magic=$magic, compressionType=$compressionType") + assertThrows( + classOf[UnexpectedAppendOffsetException], + () => log.appendAsFollower(invalidRecord, Int.MaxValue), + () => s"Magic=$magic, compressionType=$compressionType" + ) } } @@ -2418,7 +2591,10 @@ class UnifiedLogTest { magicValue = magic, codec = Compression.of(compressionType).build(), baseOffset = firstOffset) - val exception = assertThrows(classOf[UnexpectedAppendOffsetException], () => log.appendAsFollower(records = batch)) + val exception = assertThrows( + classOf[UnexpectedAppendOffsetException], + () => log.appendAsFollower(records = batch, Int.MaxValue) + ) assertEquals(firstOffset, exception.firstOffset, s"Magic=$magic, compressionType=$compressionType, UnexpectedAppendOffsetException#firstOffset") assertEquals(firstOffset + 2, exception.lastOffset, s"Magic=$magic, compressionType=$compressionType, UnexpectedAppendOffsetException#lastOffset") } @@ -2517,9 +2693,16 @@ class UnifiedLogTest { log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 5) assertEquals(Some(5), log.leaderEpochCache.flatMap(_.latestEpoch.asScala)) - log.appendAsFollower(TestUtils.records(List(new SimpleRecord("foo".getBytes())), - baseOffset = 1L, - magicValue = RecordVersion.V1.value)) + log.appendAsFollower( + TestUtils.records( + List( + new SimpleRecord("foo".getBytes()) + ), + baseOffset = 1L, + magicValue = RecordVersion.V1.value + ), + 5 + ) assertEquals(None, log.leaderEpochCache.flatMap(_.latestEpoch.asScala)) } @@ -2903,7 +3086,7 @@ class UnifiedLogTest { //When appending as follower (assignOffsets = false) for (i <- records.indices) - log.appendAsFollower(recordsForEpoch(i)) + log.appendAsFollower(recordsForEpoch(i), i) assertEquals(Some(42), log.latestEpoch) } @@ -2971,7 +3154,7 @@ class UnifiedLogTest { def append(epoch: Int, startOffset: Long, count: Int): Unit = { for (i <- 0 until count) - log.appendAsFollower(createRecords(startOffset + i, epoch)) + log.appendAsFollower(createRecords(startOffset + i, epoch), epoch) } //Given 2 segments, 10 messages per segment @@ -3205,7 +3388,7 @@ class UnifiedLogTest { buffer.flip() - appendAsFollower(log, MemoryRecords.readableRecords(buffer)) + appendAsFollower(log, 
MemoryRecords.readableRecords(buffer), epoch) val abortedTransactions = LogTestUtils.allAbortedTransactions(log) val expectedTransactions = List( @@ -3289,7 +3472,7 @@ class UnifiedLogTest { appendEndTxnMarkerToBuffer(buffer, pid, epoch, 10L, ControlRecordType.COMMIT, leaderEpoch = 1) buffer.flip() - log.appendAsFollower(MemoryRecords.readableRecords(buffer)) + log.appendAsFollower(MemoryRecords.readableRecords(buffer), epoch) LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 2, leaderEpoch = 1) LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 2, leaderEpoch = 1) @@ -3410,10 +3593,16 @@ class UnifiedLogTest { val log = createLog(logDir, logConfig) // append a few records - appendAsFollower(log, MemoryRecords.withRecords(Compression.NONE, - new SimpleRecord("a".getBytes), - new SimpleRecord("b".getBytes), - new SimpleRecord("c".getBytes)), 5) + appendAsFollower( + log, + MemoryRecords.withRecords( + Compression.NONE, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes), + new SimpleRecord("c".getBytes) + ), + 5 + ) log.updateHighWatermark(2L) @@ -4438,9 +4627,9 @@ class UnifiedLogTest { builder.close() } - private def appendAsFollower(log: UnifiedLog, records: MemoryRecords, leaderEpoch: Int = 0): Unit = { + private def appendAsFollower(log: UnifiedLog, records: MemoryRecords, leaderEpoch: Int): Unit = { records.batches.forEach(_.setPartitionLeaderEpoch(leaderEpoch)) - log.appendAsFollower(records) + log.appendAsFollower(records, leaderEpoch) } private def createLog(dir: File, diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala index f9750618ae0..414c9f6d189 100644 --- a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala @@ -328,9 +328,12 @@ class AbstractFetcherManagerTest { fetchBackOffMs = 0, brokerTopicStats = new BrokerTopicStats) { - override protected def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = { - None - } + override protected def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = None override protected def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {} diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala index b98c1ddfd03..2efc31d8aaa 100644 --- a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala +++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala @@ -668,6 +668,7 @@ class AbstractFetcherThreadTest { @Test def testFollowerFetchOutOfRangeLow(): Unit = { + val leaderEpoch = 4 val partition = new TopicPartition("topic", 0) val mockLeaderEndpoint = new MockLeaderEndPoint(truncateOnFetch = truncateOnFetch, version = version) val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint) @@ -677,14 +678,19 @@ class AbstractFetcherThreadTest { val replicaLog = Seq( mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes))) - val replicaState = PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L) + val replicaState 
= PartitionState(replicaLog, leaderEpoch = leaderEpoch, highWatermark = 0L) fetcher.setReplicaState(partition, replicaState) - fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 0))) + fetcher.addPartitions( + Map( + partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = leaderEpoch) + ) + ) val leaderLog = Seq( - mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + mkBatch(baseOffset = 2, leaderEpoch = leaderEpoch, new SimpleRecord("c".getBytes)) + ) - val leaderState = PartitionState(leaderLog, leaderEpoch = 0, highWatermark = 2L) + val leaderState = PartitionState(leaderLog, leaderEpoch = leaderEpoch, highWatermark = 2L) fetcher.mockLeader.setLeaderState(partition, leaderState) fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState) @@ -711,6 +717,7 @@ class AbstractFetcherThreadTest { @Test def testRetryAfterUnknownLeaderEpochInLatestOffsetFetch(): Unit = { + val leaderEpoch = 4 val partition = new TopicPartition("topic", 0) val mockLeaderEndPoint = new MockLeaderEndPoint(truncateOnFetch = truncateOnFetch, version = version) { val tries = new AtomicInteger(0) @@ -725,16 +732,18 @@ class AbstractFetcherThreadTest { // The follower begins from an offset which is behind the leader's log start offset val replicaLog = Seq( - mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes))) + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)) + ) - val replicaState = PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L) + val replicaState = PartitionState(replicaLog, leaderEpoch = leaderEpoch, highWatermark = 0L) fetcher.setReplicaState(partition, replicaState) - fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 0))) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = leaderEpoch))) val leaderLog = Seq( - mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)) + ) - val leaderState = PartitionState(leaderLog, leaderEpoch = 0, highWatermark = 2L) + val leaderState = PartitionState(leaderLog, leaderEpoch = leaderEpoch, highWatermark = 2L) fetcher.mockLeader.setLeaderState(partition, leaderState) fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState) @@ -752,6 +761,46 @@ class AbstractFetcherThreadTest { assertEquals(leaderState.highWatermark, replicaState.highWatermark) } + @Test + def testReplicateBatchesUpToLeaderEpoch(): Unit = { + val leaderEpoch = 4 + val partition = new TopicPartition("topic", 0) + val mockLeaderEndpoint = new MockLeaderEndPoint(version = version) + val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint) + val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine, failedPartitions = failedPartitions) + + val replicaState = PartitionState(Seq(), leaderEpoch = leaderEpoch, highWatermark = 0L) + fetcher.setReplicaState(partition, replicaState) + fetcher.addPartitions( + Map( + partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = leaderEpoch) + ) + ) + + val leaderLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = leaderEpoch - 1, new SimpleRecord("c".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = leaderEpoch, new SimpleRecord("d".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = leaderEpoch + 1, new 
SimpleRecord("e".getBytes)) + ) + + val leaderState = PartitionState(leaderLog, leaderEpoch = leaderEpoch, highWatermark = 0L) + fetcher.mockLeader.setLeaderState(partition, leaderState) + fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState) + + assertEquals(Option(Fetching), fetcher.fetchState(partition).map(_.state)) + assertEquals(0, replicaState.logStartOffset) + assertEquals(List(), replicaState.log.toList) + + TestUtils.waitUntilTrue(() => { + fetcher.doWork() + fetcher.replicaPartitionState(partition).log == fetcher.mockLeader.leaderPartitionState(partition).log.dropRight(1) + }, "Failed to reconcile leader and follower logs up to the leader epoch") + + assertEquals(leaderState.logStartOffset, replicaState.logStartOffset) + assertEquals(leaderState.logEndOffset - 1, replicaState.logEndOffset) + assertEquals(leaderState.highWatermark, replicaState.highWatermark) + } + @Test def testCorruptMessage(): Unit = { val partition = new TopicPartition("topic", 0) @@ -941,11 +990,16 @@ class AbstractFetcherThreadTest { val mockLeaderEndpoint = new MockLeaderEndPoint(truncateOnFetch = truncateOnFetch, version = version) val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint) val fetcherForAppend = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine, failedPartitions = failedPartitions) { - override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = { if (topicPartition == partition1) { throw new KafkaException() } else { - super.processPartitionData(topicPartition, fetchOffset, partitionData) + super.processPartitionData(topicPartition, fetchOffset, partitionLeaderEpoch, partitionData) } } } @@ -1050,9 +1104,14 @@ class AbstractFetcherThreadTest { val mockLeaderEndpoint = new MockLeaderEndPoint(truncateOnFetch = truncateOnFetch, version = version) val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint) val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine) { - override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = { processPartitionDataCalls += 1 - super.processPartitionData(topicPartition, fetchOffset, partitionData) + super.processPartitionData(topicPartition, fetchOffset, partitionLeaderEpoch, partitionData) } override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = { diff --git a/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala b/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala index a68428775a9..09db8275c04 100644 --- a/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala +++ b/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala @@ -65,9 +65,12 @@ class MockFetcherThread(val mockLeader: MockLeaderEndPoint, partitions } - override def processPartitionData(topicPartition: TopicPartition, - fetchOffset: Long, - partitionData: FetchData): Option[LogAppendInfo] = { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + leaderEpochForReplica: Int, + partitionData: 
FetchData + ): Option[LogAppendInfo] = { val state = replicaPartitionState(topicPartition) if (leader.isTruncationOnFetchSupported && FetchResponse.isDivergingEpoch(partitionData)) { @@ -86,17 +89,24 @@ class MockFetcherThread(val mockLeader: MockLeaderEndPoint, var shallowOffsetOfMaxTimestamp = -1L var lastOffset = state.logEndOffset var lastEpoch: OptionalInt = OptionalInt.empty() + var skipRemainingBatches = false for (batch <- batches) { batch.ensureValid() - if (batch.maxTimestamp > maxTimestamp) { - maxTimestamp = batch.maxTimestamp - shallowOffsetOfMaxTimestamp = batch.baseOffset + + skipRemainingBatches = skipRemainingBatches || hasHigherPartitionLeaderEpoch(batch, leaderEpochForReplica); + if (skipRemainingBatches) { + info(s"Skipping batch $batch because leader epoch is $leaderEpochForReplica") + } else { + if (batch.maxTimestamp > maxTimestamp) { + maxTimestamp = batch.maxTimestamp + shallowOffsetOfMaxTimestamp = batch.baseOffset + } + state.log.append(batch) + state.logEndOffset = batch.nextOffset + lastOffset = batch.lastOffset + lastEpoch = OptionalInt.of(batch.partitionLeaderEpoch) } - state.log.append(batch) - state.logEndOffset = batch.nextOffset - lastOffset = batch.lastOffset - lastEpoch = OptionalInt.of(batch.partitionLeaderEpoch) } state.logStartOffset = partitionData.logStartOffset @@ -114,6 +124,11 @@ class MockFetcherThread(val mockLeader: MockLeaderEndPoint, batches.headOption.map(_.lastOffset).getOrElse(-1))) } + private def hasHigherPartitionLeaderEpoch(batch: RecordBatch, leaderEpoch: Int): Boolean = { + batch.partitionLeaderEpoch() != RecordBatch.NO_PARTITION_LEADER_EPOCH && + batch.partitionLeaderEpoch() > leaderEpoch + } + override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = { val state = replicaPartitionState(topicPartition) state.log = state.log.takeWhile { batch => diff --git a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala index 6fc2ad4e2f1..ba399290925 100644 --- a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala +++ b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala @@ -594,9 +594,22 @@ class ReplicaFetcherThreadTest { val fetchSessionHandler = new FetchSessionHandler(logContext, brokerEndPoint.id) val leader = new RemoteLeaderEndPoint(logContext.logPrefix, mockNetwork, fetchSessionHandler, config, replicaManager, quota, () => config.interBrokerProtocolVersion, () => 1) - val thread = new ReplicaFetcherThread("bob", leader, config, failedPartitions, - replicaManager, quota, logContext.logPrefix, () => config.interBrokerProtocolVersion) { - override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = None + val thread = new ReplicaFetcherThread( + "bob", + leader, + config, + failedPartitions, + replicaManager, + quota, + logContext.logPrefix, + () => config.interBrokerProtocolVersion + ) { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = None } thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), initialLEO), t1p1 -> initialFetchState(Some(topicId1), initialLEO))) val partitions = Set(t1p0, t1p1) @@ -692,7 +705,7 @@ class ReplicaFetcherThreadTest { when(replicaManager.getPartitionOrException(t1p0)).thenReturn(partition) 
when(partition.localLogOrException).thenReturn(log) - when(partition.appendRecordsToFollowerOrFutureReplica(any(), any())).thenReturn(None) + when(partition.appendRecordsToFollowerOrFutureReplica(any(), any(), any())).thenReturn(None) val logContext = new LogContext(s"[ReplicaFetcher replicaId=${config.brokerId}, leaderId=${brokerEndPoint.id}, fetcherId=0] ") @@ -773,7 +786,7 @@ class ReplicaFetcherThreadTest { when(replicaManager.brokerTopicStats).thenReturn(mock(classOf[BrokerTopicStats])) when(partition.localLogOrException).thenReturn(log) - when(partition.appendRecordsToFollowerOrFutureReplica(any(), any())).thenReturn(Some(new LogAppendInfo( + when(partition.appendRecordsToFollowerOrFutureReplica(any(), any(), any())).thenReturn(Some(new LogAppendInfo( -1, 0, OptionalInt.empty, @@ -1310,7 +1323,7 @@ class ReplicaFetcherThreadTest { val partition: Partition = mock(classOf[Partition]) when(partition.localLogOrException).thenReturn(log) - when(partition.appendRecordsToFollowerOrFutureReplica(any[MemoryRecords], any[Boolean])).thenReturn(appendInfo) + when(partition.appendRecordsToFollowerOrFutureReplica(any[MemoryRecords], any[Boolean], any[Int])).thenReturn(appendInfo) // In Scala 2.12, the partitionsWithNewHighWatermark buffer is cleared before the replicaManager mock is verified. // Capture the argument at the time of invocation. @@ -1342,8 +1355,8 @@ class ReplicaFetcherThreadTest { .setRecords(records) .setHighWatermark(highWatermarkReceivedFromLeader) - thread.processPartitionData(tp0, 0, partitionData.setPartitionIndex(0)) - thread.processPartitionData(tp1, 0, partitionData.setPartitionIndex(1)) + thread.processPartitionData(tp0, 0, Int.MaxValue, partitionData.setPartitionIndex(0)) + thread.processPartitionData(tp1, 0, Int.MaxValue, partitionData.setPartitionIndex(1)) verify(replicaManager, times(0)).completeDelayedFetchRequests(any[Seq[TopicPartition]]) thread.doWork() @@ -1393,7 +1406,7 @@ class ReplicaFetcherThreadTest { when(partition.localLogOrException).thenReturn(log) when(partition.isReassigning).thenReturn(isReassigning) when(partition.isAddingLocalReplica).thenReturn(isReassigning) - when(partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false)).thenReturn(None) + when(partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false, Int.MaxValue)).thenReturn(None) val replicaManager: ReplicaManager = mock(classOf[ReplicaManager]) when(replicaManager.getPartitionOrException(any[TopicPartition])).thenReturn(partition) @@ -1417,7 +1430,7 @@ class ReplicaFetcherThreadTest { .setLastStableOffset(0) .setLogStartOffset(0) .setRecords(records) - thread.processPartitionData(t1p0, 0, partitionData) + thread.processPartitionData(t1p0, 0, Int.MaxValue, partitionData) if (isReassigning) assertEquals(records.sizeInBytes(), brokerTopicStats.allTopicsStats.reassignmentBytesInPerSec.get.count()) diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala index b1fffa88518..20ce9a0a0ce 100644 --- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala @@ -5794,9 +5794,12 @@ class ReplicaManagerTest { replicaManager.getPartition(topicPartition) match { case HostedPartition.Online(partition) => partition.appendRecordsToFollowerOrFutureReplica( - records = MemoryRecords.withRecords(Compression.NONE, 0, - new SimpleRecord("first message".getBytes)), - isFuture = false + records = MemoryRecords.withRecords( 
+ Compression.NONE, 0, + new SimpleRecord("first message".getBytes) + ), + isFuture = false, + partitionLeaderEpoch = 0 ) case _ => diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java index 5b30ed08fdd..770f7940aa3 100644 --- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java @@ -381,8 +381,12 @@ public class ReplicaFetcherThreadBenchmark { } @Override - public Option processPartitionData(TopicPartition topicPartition, long fetchOffset, - FetchResponseData.PartitionData partitionData) { + public Option processPartitionData( + TopicPartition topicPartition, + long fetchOffset, + int partitionLeaderEpoch, + FetchResponseData.PartitionData partitionData + ) { return Option.empty(); } } diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java index 7000b99e24a..eda418294f7 100644 --- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java @@ -137,7 +137,7 @@ public class PartitionMakeFollowerBenchmark { int initialOffSet = 0; while (true) { MemoryRecords memoryRecords = MemoryRecords.withRecords(initialOffSet, Compression.NONE, 0, simpleRecords); - partition.appendRecordsToFollowerOrFutureReplica(memoryRecords, false); + partition.appendRecordsToFollowerOrFutureReplica(memoryRecords, false, Integer.MAX_VALUE); initialOffSet = initialOffSet + 2; } }); diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java index ffed1cfc387..624c16b008b 100644 --- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java +++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java @@ -16,6 +16,7 @@ */ package org.apache.kafka.raft; +import org.apache.kafka.common.InvalidRecordException; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.Node; import org.apache.kafka.common.TopicPartition; @@ -23,6 +24,7 @@ import org.apache.kafka.common.Uuid; import org.apache.kafka.common.compress.Compression; import org.apache.kafka.common.config.ConfigException; import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.CorruptRecordException; import org.apache.kafka.common.errors.NotLeaderOrFollowerException; import org.apache.kafka.common.feature.SupportedVersionRange; import org.apache.kafka.common.memory.MemoryPool; @@ -50,6 +52,7 @@ import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.DefaultRecordBatch; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.Records; import org.apache.kafka.common.record.UnalignedMemoryRecords; @@ -93,6 +96,7 @@ import org.apache.kafka.snapshot.SnapshotWriter; import org.slf4j.Logger; import java.net.InetSocketAddress; +import java.nio.ByteBuffer; import java.util.Collection; import java.util.Collections; import 
java.util.IdentityHashMap; @@ -1674,10 +1678,7 @@ public final class KafkaRaftClient implements RaftClient { } } } else { - Records records = FetchResponse.recordsOrFail(partitionResponse); - if (records.sizeInBytes() > 0) { - appendAsFollower(records); - } + appendAsFollower(FetchResponse.recordsOrFail(partitionResponse)); OptionalLong highWatermark = partitionResponse.highWatermark() < 0 ? OptionalLong.empty() : OptionalLong.of(partitionResponse.highWatermark()); @@ -1691,10 +1692,36 @@ public final class KafkaRaftClient implements RaftClient { } } - private void appendAsFollower( - Records records - ) { - LogAppendInfo info = log.appendAsFollower(records); + private static String convertToHexadecimal(Records records) { + ByteBuffer buffer = ((MemoryRecords) records).buffer(); + int size = Math.min(buffer.remaining(), DefaultRecordBatch.RECORD_BATCH_OVERHEAD); + buffer.limit(buffer.position() + size); + + StringBuilder builder = new StringBuilder(); + while (buffer.hasRemaining()) { + builder.append(String.format("%02x", buffer.get())); + } + + return builder.toString(); + } + + private void appendAsFollower(Records records) { + if (records.sizeInBytes() == 0) { + // Nothing to do if there are no bytes in the response + return; + } + + try { + LogAppendInfo info = log.appendAsFollower(records, quorum.epoch()); + kafkaRaftMetrics.updateFetchedRecords(info.lastOffset - info.firstOffset + 1); + } catch (CorruptRecordException | InvalidRecordException e) { + logger.info( + "Failed to append the records with the batch header '{}' to the log", + convertToHexadecimal(records), + e + ); + } + if (quorum.isVoter() || followersAlwaysFlush) { // the leader only requires that voters have flushed their log before sending a Fetch // request. Because of reconfiguration some observers (that are getting added to the @@ -1706,14 +1733,11 @@ public final class KafkaRaftClient implements RaftClient { partitionState.updateState(); OffsetAndEpoch endOffset = endOffset(); - kafkaRaftMetrics.updateFetchedRecords(info.lastOffset - info.firstOffset + 1); kafkaRaftMetrics.updateLogEnd(endOffset); logger.trace("Follower end offset updated to {} after append", endOffset); } - private LogAppendInfo appendAsLeader( - Records records - ) { + private LogAppendInfo appendAsLeader(Records records) { LogAppendInfo info = log.appendAsLeader(records, quorum.epoch()); partitionState.updateState(); @@ -3331,6 +3355,10 @@ public final class KafkaRaftClient implements RaftClient { () -> new NotLeaderException("Append failed because the replica is not the current leader") ); + if (records.isEmpty()) { + throw new IllegalArgumentException("Append failed because there are no records"); + } + BatchAccumulator accumulator = leaderState.accumulator(); boolean isFirstAppend = accumulator.isEmpty(); final long offset = accumulator.append(epoch, records, true); diff --git a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java index a22f7fd73cd..8f5ba31a45d 100644 --- a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java +++ b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java @@ -31,6 +31,8 @@ public interface ReplicatedLog extends AutoCloseable { * be written atomically in a single batch or the call will fail and raise an * exception. 
* + * @param records records batches to append + * @param epoch the epoch of the replica * @return the metadata information of the appended batch * @throws IllegalArgumentException if the record set is empty * @throws RuntimeException if the batch base offset doesn't match the log end offset @@ -42,11 +44,16 @@ public interface ReplicatedLog extends AutoCloseable { * difference from appendAsLeader is that we do not need to assign the epoch * or do additional validation. * + * The log will append record batches up to and including batches that have a partition + * leader epoch less than or equal to the passed epoch. + * + * @param records records batches to append + * @param epoch the epoch of the replica * @return the metadata information of the appended batch * @throws IllegalArgumentException if the record set is empty * @throws RuntimeException if the batch base offset doesn't match the log end offset */ - LogAppendInfo appendAsFollower(Records records); + LogAppendInfo appendAsFollower(Records records, int epoch); /** * Read a set of records within a range of offsets. diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientFetchTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientFetchTest.java new file mode 100644 index 00000000000..de420a34c12 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientFetchTest.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.compress.Compression; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.ArbitraryMemoryRecords; +import org.apache.kafka.common.record.InvalidMemoryRecordsProvider; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.SimpleRecord; + +import net.jqwik.api.AfterFailureMode; +import net.jqwik.api.ForAll; +import net.jqwik.api.Property; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public final class KafkaRaftClientFetchTest { + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + void testRandomRecords( + @ForAll(supplier = ArbitraryMemoryRecords.class) MemoryRecords memoryRecords + ) throws Exception { + testFetchResponseWithInvalidRecord(memoryRecords, Integer.MAX_VALUE); + } + + @ParameterizedTest + @ArgumentsSource(InvalidMemoryRecordsProvider.class) + void testInvalidMemoryRecords(MemoryRecords records, Optional> expectedException) throws Exception { + // CorruptRecordException are handled by the KafkaRaftClient so ignore the expected exception + testFetchResponseWithInvalidRecord(records, Integer.MAX_VALUE); + } + + private static void testFetchResponseWithInvalidRecord(MemoryRecords records, int epoch) throws Exception { + int localId = KafkaRaftClientTest.randomReplicaId(); + ReplicaKey local = KafkaRaftClientTest.replicaKey(localId, true); + ReplicaKey electedLeader = KafkaRaftClientTest.replicaKey(localId + 1, true); + + RaftClientTestContext context = new RaftClientTestContext.Builder( + local.id(), + local.directoryId().get() + ) + .withBootstrapSnapshot( + Optional.of(VoterSetTest.voterSet(Stream.of(local, electedLeader))) + ) + .withElectedLeader(epoch, electedLeader.id()) + .withKip853Rpc(true) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + long oldLogEndOffset = context.log.endOffset().offset(); + + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, electedLeader.id(), records, 0L, Errors.NONE) + ); + + context.client.poll(); + + assertEquals(oldLogEndOffset, context.log.endOffset().offset()); + } + + @Test + void testReplicationOfHigherPartitionLeaderEpoch() throws Exception { + int epoch = 2; + int localId = KafkaRaftClientTest.randomReplicaId(); + ReplicaKey local = KafkaRaftClientTest.replicaKey(localId, true); + ReplicaKey electedLeader = KafkaRaftClientTest.replicaKey(localId + 1, true); + + RaftClientTestContext context = new RaftClientTestContext.Builder( + local.id(), + local.directoryId().get() + ) + .withBootstrapSnapshot( + Optional.of(VoterSetTest.voterSet(Stream.of(local, electedLeader))) + ) + .withElectedLeader(epoch, electedLeader.id()) + .withKip853Rpc(true) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + long oldLogEndOffset = context.log.endOffset().offset(); + int numberOfRecords = 10; + MemoryRecords batchWithValidEpoch = MemoryRecords.withRecords( + oldLogEndOffset, + 
Compression.NONE, + epoch, + IntStream + .range(0, numberOfRecords) + .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes())) + .toArray(SimpleRecord[]::new) + ); + + MemoryRecords batchWithInvalidEpoch = MemoryRecords.withRecords( + oldLogEndOffset + numberOfRecords, + Compression.NONE, + epoch + 1, + IntStream + .range(0, numberOfRecords) + .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes())) + .toArray(SimpleRecord[]::new) + ); + + ByteBuffer buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + batchWithInvalidEpoch.sizeInBytes()); + buffer.put(batchWithValidEpoch.buffer()); + buffer.put(batchWithInvalidEpoch.buffer()); + buffer.flip(); + + MemoryRecords records = MemoryRecords.readableRecords(buffer); + + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, electedLeader.id(), records, 0L, Errors.NONE) + ); + + context.client.poll(); + + // Check that only the first batch was appended because the second batch has a greater epoch + assertEquals(oldLogEndOffset + numberOfRecords, context.log.endOffset().offset()); + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java index c1c5857148c..248fc8cf564 100644 --- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java @@ -4495,7 +4495,7 @@ public class KafkaRaftClientTest { return ReplicaKey.of(id, directoryId); } - private static int randomReplicaId() { + public static int randomReplicaId() { return ThreadLocalRandom.current().nextInt(1025); } } diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLog.java b/raft/src/test/java/org/apache/kafka/raft/MockLog.java index 29281fa633f..00d5f8e5088 100644 --- a/raft/src/test/java/org/apache/kafka/raft/MockLog.java +++ b/raft/src/test/java/org/apache/kafka/raft/MockLog.java @@ -19,6 +19,7 @@ package org.apache.kafka.raft; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.compress.Compression; +import org.apache.kafka.common.errors.CorruptRecordException; import org.apache.kafka.common.errors.OffsetOutOfRangeException; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.MemoryRecordsBuilder; @@ -279,7 +280,7 @@ public class MockLog implements ReplicatedLog { @Override public LogAppendInfo appendAsLeader(Records records, int epoch) { - return append(records, OptionalInt.of(epoch)); + return append(records, epoch, true); } private long appendBatch(LogBatch batch) { @@ -292,16 +293,18 @@ public class MockLog implements ReplicatedLog { } @Override - public LogAppendInfo appendAsFollower(Records records) { - return append(records, OptionalInt.empty()); + public LogAppendInfo appendAsFollower(Records records, int epoch) { + return append(records, epoch, false); } - private LogAppendInfo append(Records records, OptionalInt epoch) { - if (records.sizeInBytes() == 0) + private LogAppendInfo append(Records records, int epoch, boolean isLeader) { + if (records.sizeInBytes() == 0) { throw new IllegalArgumentException("Attempt to append an empty record set"); + } long baseOffset = endOffset().offset(); long lastOffset = baseOffset; + boolean hasBatches = false; for (RecordBatch batch : records.batches()) { if (batch.baseOffset() != endOffset().offset()) { /* KafkaMetadataLog throws an 
kafka.common.UnexpectedAppendOffsetException this is the @@ -314,26 +317,47 @@ public class MockLog implements ReplicatedLog { endOffset().offset() ) ); + } else if (isLeader && epoch != batch.partitionLeaderEpoch()) { + // the partition leader epoch is set and does not match the one set in the batch + throw new RuntimeException( + String.format( + "Epoch %s doesn't match batch leader epoch %s", + epoch, + batch.partitionLeaderEpoch() + ) + ); + } else if (!isLeader && batch.partitionLeaderEpoch() > epoch) { + /* To avoid inconsistent log replication, follower should only append record + * batches with an epoch less than or equal to the leader epoch. There is more + * details on this issue and scenario in KAFKA-18723. + */ + break; } + hasBatches = true; LogBatch logBatch = new LogBatch( - epoch.orElseGet(batch::partitionLeaderEpoch), + batch.partitionLeaderEpoch(), batch.isControlBatch(), buildEntries(batch, Record::offset) ); if (logger.isDebugEnabled()) { - String nodeState = "Follower"; - if (epoch.isPresent()) { - nodeState = "Leader"; - } - logger.debug("{} appending to the log {}", nodeState, logBatch); + logger.debug( + "{} appending to the log {}", + isLeader ? "Leader" : "Follower", + logBatch + ); } appendBatch(logBatch); lastOffset = logBatch.last().offset; } + if (!hasBatches) { + // This emulates the default handling when records doesn't have enough bytes for a batch + throw new CorruptRecordException("Append failed unexpectedly"); + } + return new LogAppendInfo(baseOffset, lastOffset); } diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java index 08e19866d9b..eef35268c7d 100644 --- a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java @@ -19,9 +19,12 @@ package org.apache.kafka.raft; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.compress.Compression; +import org.apache.kafka.common.errors.CorruptRecordException; import org.apache.kafka.common.errors.OffsetOutOfRangeException; import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.record.ArbitraryMemoryRecords; import org.apache.kafka.common.record.ControlRecordUtils; +import org.apache.kafka.common.record.InvalidMemoryRecordsProvider; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.Record; import org.apache.kafka.common.record.RecordBatch; @@ -32,9 +35,16 @@ import org.apache.kafka.common.utils.Utils; import org.apache.kafka.snapshot.RawSnapshotReader; import org.apache.kafka.snapshot.RawSnapshotWriter; +import net.jqwik.api.AfterFailureMode; +import net.jqwik.api.ForAll; +import net.jqwik.api.Property; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ArgumentsSource; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -44,6 +54,7 @@ import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.stream.IntStream; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -169,14 +180,17 @@ public class MockLogTest { assertThrows( RuntimeException.class, () -> log.appendAsLeader( - 
-                MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo),
-                currentEpoch)
+                MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo),
+                currentEpoch
+            )
         );
 
         assertThrows(
             RuntimeException.class,
             () -> log.appendAsFollower(
-                MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo))
+                MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo),
+                currentEpoch
+            )
         );
     }
 
@@ -187,7 +201,13 @@ public class MockLogTest {
         LeaderChangeMessage messageData = new LeaderChangeMessage().setLeaderId(0);
         ByteBuffer buffer = ByteBuffer.allocate(256);
         log.appendAsLeader(
-            MemoryRecords.withLeaderChangeMessage(initialOffset, 0L, 2, buffer, messageData),
+            MemoryRecords.withLeaderChangeMessage(
+                initialOffset,
+                0L,
+                currentEpoch,
+                buffer,
+                messageData
+            ),
             currentEpoch
         );
 
@@ -221,7 +241,10 @@ public class MockLogTest {
         }
 
         log.truncateToLatestSnapshot();
-        log.appendAsFollower(MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo));
+        log.appendAsFollower(
+            MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo),
+            epoch
+        );
 
         assertEquals(initialOffset, log.startOffset());
         assertEquals(initialOffset + 1, log.endOffset().offset());
@@ -368,10 +391,82 @@ public class MockLogTest {
 
     @Test
     public void testEmptyAppendNotAllowed() {
-        assertThrows(IllegalArgumentException.class, () -> log.appendAsFollower(MemoryRecords.EMPTY));
+        assertThrows(IllegalArgumentException.class, () -> log.appendAsFollower(MemoryRecords.EMPTY, 1));
         assertThrows(IllegalArgumentException.class, () -> log.appendAsLeader(MemoryRecords.EMPTY, 1));
     }
 
+    @ParameterizedTest
+    @ArgumentsSource(InvalidMemoryRecordsProvider.class)
+    void testInvalidMemoryRecords(MemoryRecords records, Optional<Class<Exception>> expectedException) {
+        long previousEndOffset = log.endOffset().offset();
+
+        Executable action = () -> log.appendAsFollower(records, Integer.MAX_VALUE);
+        if (expectedException.isPresent()) {
+            assertThrows(expectedException.get(), action);
+        } else {
+            assertThrows(CorruptRecordException.class, action);
+        }
+
+        assertEquals(previousEndOffset, log.endOffset().offset());
+    }
+
+    @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY)
+    void testRandomRecords(
+        @ForAll(supplier = ArbitraryMemoryRecords.class) MemoryRecords records
+    ) {
+        try (MockLog log = new MockLog(topicPartition, topicId, new LogContext())) {
+            long previousEndOffset = log.endOffset().offset();
+
+            assertThrows(
+                CorruptRecordException.class,
+                () -> log.appendAsFollower(records, Integer.MAX_VALUE)
+            );
+
+            assertEquals(previousEndOffset, log.endOffset().offset());
+        }
+    }
+
+    @Test
+    void testInvalidLeaderEpoch() {
+        long previousEndOffset = log.endOffset().offset();
+        int epoch = log.lastFetchedEpoch() + 1;
+        int numberOfRecords = 10;
+
+        MemoryRecords batchWithValidEpoch = MemoryRecords.withRecords(
+            previousEndOffset,
+            Compression.NONE,
+            epoch,
+            IntStream
+                .range(0, numberOfRecords)
+                .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes()))
+                .toArray(SimpleRecord[]::new)
+        );
+
+        MemoryRecords batchWithInvalidEpoch = MemoryRecords.withRecords(
+            previousEndOffset + numberOfRecords,
+            Compression.NONE,
+            epoch + 1,
+            IntStream
+                .range(0, numberOfRecords)
+                .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes()))
+                .toArray(SimpleRecord[]::new)
+        );
+
+        ByteBuffer buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + batchWithInvalidEpoch.sizeInBytes());
+        buffer.put(batchWithValidEpoch.buffer());
+        buffer.put(batchWithInvalidEpoch.buffer());
+        buffer.flip();
+
+        MemoryRecords records = MemoryRecords.readableRecords(buffer);
+
+        log.appendAsFollower(records, epoch);
+
+        // Check that only the first batch was appended
+        assertEquals(previousEndOffset + numberOfRecords, log.endOffset().offset());
+        // Check that the last fetched epoch matches the first batch
+        assertEquals(epoch, log.lastFetchedEpoch());
+    }
+
     @Test
     public void testReadOutOfRangeOffset() {
         final long initialOffset = 5L;
@@ -383,12 +478,19 @@ public class MockLogTest {
         }
 
         log.truncateToLatestSnapshot();
-        log.appendAsFollower(MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo));
+        log.appendAsFollower(
+            MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo),
+            epoch
+        );
 
-        assertThrows(OffsetOutOfRangeException.class, () -> log.read(log.startOffset() - 1,
-            Isolation.UNCOMMITTED));
-        assertThrows(OffsetOutOfRangeException.class, () -> log.read(log.endOffset().offset() + 1,
-            Isolation.UNCOMMITTED));
+        assertThrows(
+            OffsetOutOfRangeException.class,
+            () -> log.read(log.startOffset() - 1, Isolation.UNCOMMITTED)
+        );
+        assertThrows(
+            OffsetOutOfRangeException.class,
+            () -> log.read(log.endOffset().offset() + 1, Isolation.UNCOMMITTED)
+        );
     }
 
     @Test
@@ -948,6 +1050,7 @@ public class MockLogTest {
             MemoryRecords.withRecords(
                 log.endOffset().offset(),
                 Compression.NONE,
+                epoch,
                 records.toArray(new SimpleRecord[records.size()])
             ),
             epoch