diff --git a/build.gradle b/build.gradle index a4bed55f1d1..9d04500a286 100644 --- a/build.gradle +++ b/build.gradle @@ -1037,6 +1037,7 @@ project(':core') { testImplementation project(':test-common:test-common-util') testImplementation libs.bcpkix testImplementation libs.mockitoCore + testImplementation libs.jqwik testImplementation(libs.apacheda) { exclude group: 'xml-apis', module: 'xml-apis' // `mina-core` is a transitive dependency for `apacheds` and `apacheda`. @@ -1231,6 +1232,12 @@ project(':core') { ) } + test { + useJUnitPlatform { + includeEngines 'jqwik', 'junit-jupiter' + } + } + tasks.create(name: "copyDependantTestLibs", type: Copy) { from (configurations.testRuntimeClasspath) { include('*.jar') @@ -1802,6 +1809,7 @@ project(':clients') { testImplementation libs.jacksonJakartarsJsonProvider testImplementation libs.jose4j testImplementation libs.junitJupiter + testImplementation libs.jqwik testImplementation libs.spotbugs testImplementation libs.mockitoCore testImplementation libs.mockitoJunitJupiter // supports MockitoExtension diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java index 912c3490f43..d6e9cc6bd7f 100644 --- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java +++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java @@ -159,7 +159,7 @@ public class DefaultRecordBatch extends AbstractRecordBatch implements MutableRe /** * Gets the base timestamp of the batch which is used to calculate the record timestamps from the deltas. - * + * * @return The base timestamp */ public long baseTimestamp() { @@ -502,6 +502,7 @@ public class DefaultRecordBatch extends AbstractRecordBatch implements MutableRe public String toString() { return "RecordBatch(magic=" + magic() + ", offsets=[" + baseOffset() + ", " + lastOffset() + "], " + "sequence=[" + baseSequence() + ", " + lastSequence() + "], " + + "partitionLeaderEpoch=" + partitionLeaderEpoch() + ", " + "isTransactional=" + isTransactional() + ", isControlBatch=" + isControlBatch() + ", " + "compression=" + compressionType() + ", timestampType=" + timestampType() + ", crc=" + checksum() + ")"; } diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java index 1aad97d5920..c06188edf22 100644 --- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java +++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java @@ -32,9 +32,6 @@ import org.apache.kafka.common.utils.ByteBufferOutputStream; import org.apache.kafka.common.utils.CloseableIterator; import org.apache.kafka.common.utils.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.GatheringByteChannel; @@ -49,7 +46,6 @@ import java.util.Objects; * or one of the {@link #builder(ByteBuffer, byte, Compression, TimestampType, long)} variants. 
*/ public class MemoryRecords extends AbstractRecords { - private static final Logger log = LoggerFactory.getLogger(MemoryRecords.class); public static final MemoryRecords EMPTY = MemoryRecords.readableRecords(ByteBuffer.allocate(0)); private final ByteBuffer buffer; @@ -596,7 +592,7 @@ public class MemoryRecords extends AbstractRecords { return withRecords(magic, initialOffset, compression, TimestampType.CREATE_TIME, records); } - public static MemoryRecords withRecords(long initialOffset, Compression compression, Integer partitionLeaderEpoch, SimpleRecord... records) { + public static MemoryRecords withRecords(long initialOffset, Compression compression, int partitionLeaderEpoch, SimpleRecord... records) { return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, initialOffset, compression, TimestampType.CREATE_TIME, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, partitionLeaderEpoch, false, records); } diff --git a/clients/src/test/java/org/apache/kafka/common/record/ArbitraryMemoryRecords.java b/clients/src/test/java/org/apache/kafka/common/record/ArbitraryMemoryRecords.java new file mode 100644 index 00000000000..30eec866a6c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/ArbitraryMemoryRecords.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import net.jqwik.api.Arbitraries; +import net.jqwik.api.Arbitrary; +import net.jqwik.api.ArbitrarySupplier; + +import java.nio.ByteBuffer; +import java.util.Random; + +public final class ArbitraryMemoryRecords implements ArbitrarySupplier { + @Override + public Arbitrary get() { + return Arbitraries.randomValue(ArbitraryMemoryRecords::buildRandomRecords); + } + + private static MemoryRecords buildRandomRecords(Random random) { + int size = random.nextInt(128) + 1; + byte[] bytes = new byte[size]; + random.nextBytes(bytes); + + return MemoryRecords.readableRecords(ByteBuffer.wrap(bytes)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/InvalidMemoryRecordsProvider.java b/clients/src/test/java/org/apache/kafka/common/record/InvalidMemoryRecordsProvider.java new file mode 100644 index 00000000000..0f9446a6391 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/InvalidMemoryRecordsProvider.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.errors.CorruptRecordException; + +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.stream.Stream; + +public final class InvalidMemoryRecordsProvider implements ArgumentsProvider { + // Use a baseOffset that's not zero so that it is less likely to match the LEO + private static final long BASE_OFFSET = 1234; + private static final int EPOCH = 4321; + + /** + * Returns a stream of arguments for invalid memory records and the expected exception. + * + * The first object in the {@code Arguments} is a {@code MemoryRecords}. + * + * The second object in the {@code Arguments} is an {@code Optional>} which is + * the expected exception from the log layer. + */ + @Override + public Stream provideArguments(ExtensionContext context) { + return Stream.of( + Arguments.of(MemoryRecords.readableRecords(notEnoughBytes()), Optional.empty()), + Arguments.of(MemoryRecords.readableRecords(recordsSizeTooSmall()), Optional.of(CorruptRecordException.class)), + Arguments.of(MemoryRecords.readableRecords(notEnoughBytesToMagic()), Optional.empty()), + Arguments.of(MemoryRecords.readableRecords(negativeMagic()), Optional.of(CorruptRecordException.class)), + Arguments.of(MemoryRecords.readableRecords(largeMagic()), Optional.of(CorruptRecordException.class)), + Arguments.of(MemoryRecords.readableRecords(lessBytesThanRecordSize()), Optional.empty()) + ); + } + + private static ByteBuffer notEnoughBytes() { + var buffer = ByteBuffer.allocate(Records.LOG_OVERHEAD - 1); + buffer.limit(buffer.capacity()); + + return buffer; + } + + private static ByteBuffer recordsSizeTooSmall() { + var buffer = ByteBuffer.allocate(256); + // Write the base offset + buffer.putLong(BASE_OFFSET); + // Write record size + buffer.putInt(LegacyRecord.RECORD_OVERHEAD_V0 - 1); + buffer.position(0); + buffer.limit(buffer.capacity()); + + return buffer; + } + + private static ByteBuffer notEnoughBytesToMagic() { + var buffer = ByteBuffer.allocate(256); + // Write the base offset + buffer.putLong(BASE_OFFSET); + // Write record size + buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD); + buffer.position(0); + buffer.limit(Records.HEADER_SIZE_UP_TO_MAGIC - 1); + + return buffer; + } + + private static ByteBuffer negativeMagic() { + var buffer = ByteBuffer.allocate(256); + // Write the base offset + buffer.putLong(BASE_OFFSET); + // Write record size + buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD); + // Write the epoch + buffer.putInt(EPOCH); + // Write magic + buffer.put((byte) -1); + buffer.position(0); + buffer.limit(buffer.capacity()); + + return buffer; + } + + private static ByteBuffer largeMagic() { + var buffer = ByteBuffer.allocate(256); + // Write the base offset + buffer.putLong(BASE_OFFSET); + // Write record size + buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD); + // Write the epoch + buffer.putInt(EPOCH); + // Write magic + 
buffer.put((byte) (RecordBatch.CURRENT_MAGIC_VALUE + 1)); + buffer.position(0); + buffer.limit(buffer.capacity()); + + return buffer; + } + + private static ByteBuffer lessBytesThanRecordSize() { + var buffer = ByteBuffer.allocate(256); + // Write the base offset + buffer.putLong(BASE_OFFSET); + // Write record size + buffer.putInt(buffer.capacity() - Records.LOG_OVERHEAD); + // Write the epoch + buffer.putInt(EPOCH); + // Write magic + buffer.put(RecordBatch.CURRENT_MAGIC_VALUE); + buffer.position(0); + buffer.limit(buffer.capacity() - Records.LOG_OVERHEAD - 1); + + return buffer; + } +} diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala index 2a10afb3b5d..b2394cbb944 100755 --- a/core/src/main/scala/kafka/cluster/Partition.scala +++ b/core/src/main/scala/kafka/cluster/Partition.scala @@ -1302,27 +1302,35 @@ class Partition(val topicPartition: TopicPartition, } } - private def doAppendRecordsToFollowerOrFutureReplica(records: MemoryRecords, isFuture: Boolean): Option[LogAppendInfo] = { + private def doAppendRecordsToFollowerOrFutureReplica( + records: MemoryRecords, + isFuture: Boolean, + partitionLeaderEpoch: Int + ): Option[LogAppendInfo] = { if (isFuture) { // The read lock is needed to handle race condition if request handler thread tries to // remove future replica after receiving AlterReplicaLogDirsRequest. inReadLock(leaderIsrUpdateLock) { // Note the replica may be undefined if it is removed by a non-ReplicaAlterLogDirsThread before // this method is called - futureLog.map { _.appendAsFollower(records) } + futureLog.map { _.appendAsFollower(records, partitionLeaderEpoch) } } } else { // The lock is needed to prevent the follower replica from being updated while ReplicaAlterDirThread // is executing maybeReplaceCurrentWithFutureReplica() to replace follower replica with the future replica. futureLogLock.synchronized { - Some(localLogOrException.appendAsFollower(records)) + Some(localLogOrException.appendAsFollower(records, partitionLeaderEpoch)) } } } - def appendRecordsToFollowerOrFutureReplica(records: MemoryRecords, isFuture: Boolean): Option[LogAppendInfo] = { + def appendRecordsToFollowerOrFutureReplica( + records: MemoryRecords, + isFuture: Boolean, + partitionLeaderEpoch: Int + ): Option[LogAppendInfo] = { try { - doAppendRecordsToFollowerOrFutureReplica(records, isFuture) + doAppendRecordsToFollowerOrFutureReplica(records, isFuture, partitionLeaderEpoch) } catch { case e: UnexpectedAppendOffsetException => val log = if (isFuture) futureLocalLogOrException else localLogOrException @@ -1340,7 +1348,7 @@ class Partition(val topicPartition: TopicPartition, info(s"Unexpected offset in append to $topicPartition. First offset ${e.firstOffset} is less than log start offset ${log.logStartOffset}." 
+ s" Since this is the first record to be appended to the $replicaName's log, will start the log from offset ${e.firstOffset}.") truncateFullyAndStartAt(e.firstOffset, isFuture) - doAppendRecordsToFollowerOrFutureReplica(records, isFuture) + doAppendRecordsToFollowerOrFutureReplica(records, isFuture, partitionLeaderEpoch) } else throw e } diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala index b3c447faac4..fbacbe0af1a 100644 --- a/core/src/main/scala/kafka/log/UnifiedLog.scala +++ b/core/src/main/scala/kafka/log/UnifiedLog.scala @@ -669,6 +669,7 @@ class UnifiedLog(@volatile var logStartOffset: Long, * Append this message set to the active segment of the local log, assigning offsets and Partition Leader Epochs * * @param records The records to append + * @param leaderEpoch the epoch of the replica appending * @param origin Declares the origin of the append which affects required validations * @param requestLocal request local instance * @throws KafkaStorageException If the append fails due to an I/O error. @@ -699,14 +700,15 @@ class UnifiedLog(@volatile var logStartOffset: Long, * Append this message set to the active segment of the local log without assigning offsets or Partition Leader Epochs * * @param records The records to append + * @param leaderEpoch the epoch of the replica appending * @throws KafkaStorageException If the append fails due to an I/O error. * @return Information about the appended messages including the first and last offset. */ - def appendAsFollower(records: MemoryRecords): LogAppendInfo = { + def appendAsFollower(records: MemoryRecords, leaderEpoch: Int): LogAppendInfo = { append(records, origin = AppendOrigin.REPLICATION, validateAndAssignOffsets = false, - leaderEpoch = -1, + leaderEpoch = leaderEpoch, requestLocal = None, verificationGuard = VerificationGuard.SENTINEL, // disable to check the validation of record size since the record is already accepted by leader. @@ -1085,63 +1087,85 @@ class UnifiedLog(@volatile var logStartOffset: Long, var shallowOffsetOfMaxTimestamp = -1L var readFirstMessage = false var lastOffsetOfFirstBatch = -1L + var skipRemainingBatches = false records.batches.forEach { batch => if (origin == AppendOrigin.RAFT_LEADER && batch.partitionLeaderEpoch != leaderEpoch) { - throw new InvalidRecordException("Append from Raft leader did not set the batch epoch correctly") + throw new InvalidRecordException( + s"Append from Raft leader did not set the batch epoch correctly, expected $leaderEpoch " + + s"but the batch has ${batch.partitionLeaderEpoch}" + ) } // we only validate V2 and higher to avoid potential compatibility issues with older clients - if (batch.magic >= RecordBatch.MAGIC_VALUE_V2 && origin == AppendOrigin.CLIENT && batch.baseOffset != 0) + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2 && origin == AppendOrigin.CLIENT && batch.baseOffset != 0) { throw new InvalidRecordException(s"The baseOffset of the record batch in the append to $topicPartition should " + s"be 0, but it is ${batch.baseOffset}") - - // update the first offset if on the first message. For magic versions older than 2, we use the last offset - // to avoid the need to decompress the data (the last offset can be obtained directly from the wrapper message). - // For magic version 2, we can get the first offset directly from the batch header. - // When appending to the leader, we will update LogAppendInfo.baseOffset with the correct value. In the follower - // case, validation will be more lenient. 
- // Also indicate whether we have the accurate first offset or not - if (!readFirstMessage) { - if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) - firstOffset = batch.baseOffset - lastOffsetOfFirstBatch = batch.lastOffset - readFirstMessage = true } - // check that offsets are monotonically increasing - if (lastOffset >= batch.lastOffset) - monotonic = false + /* During replication of uncommitted data it is possible for the remote replica to send record batches after it lost + * leadership. This can happen if sending FETCH responses is slow. There is a race between sending the FETCH + * response and the replica truncating and appending to the log. The replicating replica resolves this issue by only + * persisting up to the current leader epoch used in the fetch request. See KAFKA-18723 for more details. + */ + skipRemainingBatches = skipRemainingBatches || hasHigherPartitionLeaderEpoch(batch, origin, leaderEpoch) + if (skipRemainingBatches) { + info( + s"Skipping batch $batch from an origin of $origin because its partition leader epoch " + + s"${batch.partitionLeaderEpoch} is higher than the replica's current leader epoch " + + s"$leaderEpoch" + ) + } else { + // update the first offset if on the first message. For magic versions older than 2, we use the last offset + // to avoid the need to decompress the data (the last offset can be obtained directly from the wrapper message). + // For magic version 2, we can get the first offset directly from the batch header. + // When appending to the leader, we will update LogAppendInfo.baseOffset with the correct value. In the follower + // case, validation will be more lenient. + // Also indicate whether we have the accurate first offset or not + if (!readFirstMessage) { + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) { + firstOffset = batch.baseOffset + } + lastOffsetOfFirstBatch = batch.lastOffset + readFirstMessage = true + } - // update the last offset seen - lastOffset = batch.lastOffset - lastLeaderEpoch = batch.partitionLeaderEpoch + // check that offsets are monotonically increasing + if (lastOffset >= batch.lastOffset) { + monotonic = false + } - // Check if the message sizes are valid. - val batchSize = batch.sizeInBytes - if (!ignoreRecordSize && batchSize > config.maxMessageSize) { - brokerTopicStats.topicStats(topicPartition.topic).bytesRejectedRate.mark(records.sizeInBytes) - brokerTopicStats.allTopicsStats.bytesRejectedRate.mark(records.sizeInBytes) - throw new RecordTooLargeException(s"The record batch size in the append to $topicPartition is $batchSize bytes " + - s"which exceeds the maximum configured value of ${config.maxMessageSize}.") + // update the last offset seen + lastOffset = batch.lastOffset + lastLeaderEpoch = batch.partitionLeaderEpoch + + // Check if the message sizes are valid. 
+ val batchSize = batch.sizeInBytes + if (!ignoreRecordSize && batchSize > config.maxMessageSize) { + brokerTopicStats.topicStats(topicPartition.topic).bytesRejectedRate.mark(records.sizeInBytes) + brokerTopicStats.allTopicsStats.bytesRejectedRate.mark(records.sizeInBytes) + throw new RecordTooLargeException(s"The record batch size in the append to $topicPartition is $batchSize bytes " + + s"which exceeds the maximum configured value of ${config.maxMessageSize}.") + } + + // check the validity of the message by checking CRC + if (!batch.isValid) { + brokerTopicStats.allTopicsStats.invalidMessageCrcRecordsPerSec.mark() + throw new CorruptRecordException(s"Record is corrupt (stored crc = ${batch.checksum()}) in topic partition $topicPartition.") + } + + if (batch.maxTimestamp > maxTimestamp) { + maxTimestamp = batch.maxTimestamp + shallowOffsetOfMaxTimestamp = lastOffset + } + + validBytesCount += batchSize + + val batchCompression = CompressionType.forId(batch.compressionType.id) + // sourceCompression is only used on the leader path, which only contains one batch if version is v2 or messages are compressed + if (batchCompression != CompressionType.NONE) { + sourceCompression = batchCompression + } } - - // check the validity of the message by checking CRC - if (!batch.isValid) { - brokerTopicStats.allTopicsStats.invalidMessageCrcRecordsPerSec.mark() - throw new CorruptRecordException(s"Record is corrupt (stored crc = ${batch.checksum()}) in topic partition $topicPartition.") - } - - if (batch.maxTimestamp > maxTimestamp) { - maxTimestamp = batch.maxTimestamp - shallowOffsetOfMaxTimestamp = lastOffset - } - - validBytesCount += batchSize - - val batchCompression = CompressionType.forId(batch.compressionType.id) - // sourceCompression is only used on the leader path, which only contains one batch if version is v2 or messages are compressed - if (batchCompression != CompressionType.NONE) - sourceCompression = batchCompression } if (requireOffsetsMonotonic && !monotonic) @@ -1158,6 +1182,25 @@ class UnifiedLog(@volatile var logStartOffset: Long, validBytesCount, lastOffsetOfFirstBatch, Collections.emptyList[RecordError], LeaderHwChange.NONE) } + /** + * Return true if the record batch has a higher leader epoch than the specified leader epoch + * + * @param batch the batch to validate + * @param origin the reason for appending the record batch + * @param leaderEpoch the epoch to compare + * @return true if the append reason is replication and the batch's partition leader epoch is + * greater than the specified leaderEpoch, otherwise false + */ + private def hasHigherPartitionLeaderEpoch( + batch: RecordBatch, + origin: AppendOrigin, + leaderEpoch: Int + ): Boolean = { + origin == AppendOrigin.REPLICATION && + batch.partitionLeaderEpoch() != RecordBatch.NO_PARTITION_LEADER_EPOCH && + batch.partitionLeaderEpoch() > leaderEpoch + } + /** * Trim any invalid bytes from the end of this message set (if there are any) * @@ -1295,7 +1338,7 @@ class UnifiedLog(@volatile var logStartOffset: Long, val asyncOffsetReadFutureHolder = remoteOffsetReader.get.asyncOffsetRead(topicPartition, targetTimestamp, logStartOffset, leaderEpochCache, () => searchOffsetInLocalLog(targetTimestamp, localLogStartOffset())) - + new OffsetResultHolder(Optional.empty(), Optional.of(asyncOffsetReadFutureHolder)) } else { new OffsetResultHolder(searchOffsetInLocalLog(targetTimestamp, logStartOffset)) diff --git a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala index 
3f6c2044df5..be03e0723af 100644 --- a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala +++ b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala @@ -25,6 +25,7 @@ import kafka.raft.KafkaMetadataLog.UnknownReason import kafka.utils.Logging import org.apache.kafka.common.config.TopicConfig import org.apache.kafka.common.errors.InvalidConfigurationException +import org.apache.kafka.common.errors.CorruptRecordException import org.apache.kafka.common.record.{MemoryRecords, Records} import org.apache.kafka.common.utils.{Time, Utils} import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid} @@ -89,8 +90,9 @@ final class KafkaMetadataLog private ( } override def appendAsLeader(records: Records, epoch: Int): LogAppendInfo = { - if (records.sizeInBytes == 0) + if (records.sizeInBytes == 0) { throw new IllegalArgumentException("Attempt to append an empty record set") + } handleAndConvertLogAppendInfo( log.appendAsLeader(records.asInstanceOf[MemoryRecords], @@ -101,18 +103,20 @@ final class KafkaMetadataLog private ( ) } - override def appendAsFollower(records: Records): LogAppendInfo = { - if (records.sizeInBytes == 0) + override def appendAsFollower(records: Records, epoch: Int): LogAppendInfo = { + if (records.sizeInBytes == 0) { throw new IllegalArgumentException("Attempt to append an empty record set") + } - handleAndConvertLogAppendInfo(log.appendAsFollower(records.asInstanceOf[MemoryRecords])) + handleAndConvertLogAppendInfo(log.appendAsFollower(records.asInstanceOf[MemoryRecords], epoch)) } private def handleAndConvertLogAppendInfo(appendInfo: internals.log.LogAppendInfo): LogAppendInfo = { - if (appendInfo.firstOffset != JUnifiedLog.UNKNOWN_OFFSET) + if (appendInfo.firstOffset == JUnifiedLog.UNKNOWN_OFFSET) { + throw new CorruptRecordException(s"Append failed unexpectedly $appendInfo") + } else { new LogAppendInfo(appendInfo.firstOffset, appendInfo.lastOffset) - else - throw new KafkaException(s"Append failed unexpectedly") + } } override def lastFetchedEpoch: Int = { diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala index be663d19ec8..7a98c83e7f4 100755 --- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala +++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala @@ -78,9 +78,12 @@ abstract class AbstractFetcherThread(name: String, /* callbacks to be defined in subclass */ // process fetched data - protected def processPartitionData(topicPartition: TopicPartition, - fetchOffset: Long, - partitionData: FetchData): Option[LogAppendInfo] + protected def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] protected def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit @@ -333,7 +336,9 @@ abstract class AbstractFetcherThread(name: String, // In this case, we only want to process the fetch response if the partition state is ready for fetch and // the current offset is the same as the offset requested. 
val fetchPartitionData = sessionPartitions.get(topicPartition) - if (fetchPartitionData != null && fetchPartitionData.fetchOffset == currentFetchState.fetchOffset && currentFetchState.isReadyForFetch) { + if (fetchPartitionData != null && + fetchPartitionData.fetchOffset == currentFetchState.fetchOffset && + currentFetchState.isReadyForFetch) { Errors.forCode(partitionData.errorCode) match { case Errors.NONE => try { @@ -348,10 +353,16 @@ abstract class AbstractFetcherThread(name: String, .setLeaderEpoch(partitionData.divergingEpoch.epoch) .setEndOffset(partitionData.divergingEpoch.endOffset) } else { - // Once we hand off the partition data to the subclass, we can't mess with it any more in this thread + /* Once we hand off the partition data to the subclass, we can't mess with it any more in this thread + * + * When appending batches to the log only append record batches up to the leader epoch when the FETCH + * request was handled. This is done to make sure that logs are not inconsistent because of log + * truncation and append after the FETCH request was handled. See KAFKA-18723 for more details. + */ val logAppendInfoOpt = processPartitionData( topicPartition, currentFetchState.fetchOffset, + fetchPartitionData.currentLeaderEpoch.orElse(currentFetchState.currentLeaderEpoch), partitionData ) diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala index 56492de3485..5f5373b3641 100644 --- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala +++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala @@ -66,9 +66,12 @@ class ReplicaAlterLogDirsThread(name: String, } // process fetched data - override def processPartitionData(topicPartition: TopicPartition, - fetchOffset: Long, - partitionData: FetchData): Option[LogAppendInfo] = { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = { val partition = replicaMgr.getPartitionOrException(topicPartition) val futureLog = partition.futureLocalLogOrException val records = toMemoryRecords(FetchResponse.recordsOrFail(partitionData)) @@ -78,7 +81,7 @@ class ReplicaAlterLogDirsThread(name: String, topicPartition, fetchOffset, futureLog.logEndOffset)) val logAppendInfo = if (records.sizeInBytes() > 0) - partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = true) + partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = true, partitionLeaderEpoch) else None diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala index 7f0c6d41dbd..4c11301c567 100644 --- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala +++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala @@ -98,9 +98,12 @@ class ReplicaFetcherThread(name: String, } // process fetched data - override def processPartitionData(topicPartition: TopicPartition, - fetchOffset: Long, - partitionData: FetchData): Option[LogAppendInfo] = { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = { val logTrace = isTraceEnabled val partition = replicaMgr.getPartitionOrException(topicPartition) val log = partition.localLogOrException @@ -117,7 +120,7 @@ class ReplicaFetcherThread(name: String, .format(log.logEndOffset, 
topicPartition, records.sizeInBytes, partitionData.highWatermark)) // Append the leader's messages to the log - val logAppendInfo = partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false) + val logAppendInfo = partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false, partitionLeaderEpoch) if (logTrace) trace("Follower has replica log end offset %d after appending %d bytes of messages for partition %s" diff --git a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala index 5f1752a54a6..263c35f5bfa 100644 --- a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala +++ b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala @@ -19,9 +19,12 @@ package kafka.raft import kafka.server.{KafkaConfig, KafkaRaftServer} import kafka.utils.TestUtils import org.apache.kafka.common.compress.Compression +import org.apache.kafka.common.errors.CorruptRecordException import org.apache.kafka.common.errors.{InvalidConfigurationException, RecordTooLargeException} import org.apache.kafka.common.protocol import org.apache.kafka.common.protocol.{ObjectSerializationCache, Writable} +import org.apache.kafka.common.record.ArbitraryMemoryRecords +import org.apache.kafka.common.record.InvalidMemoryRecordsProvider import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord} import org.apache.kafka.common.utils.Utils import org.apache.kafka.raft._ @@ -33,7 +36,14 @@ import org.apache.kafka.snapshot.{FileRawSnapshotWriter, RawSnapshotReader, RawS import org.apache.kafka.storage.internals.log.{LogConfig, LogStartOffsetIncrementReason, UnifiedLog} import org.apache.kafka.test.TestUtils.assertOptional import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.function.Executable import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource + +import net.jqwik.api.AfterFailureMode +import net.jqwik.api.ForAll +import net.jqwik.api.Property import java.io.File import java.nio.ByteBuffer @@ -108,12 +118,93 @@ final class KafkaMetadataLogTest { classOf[RuntimeException], () => { log.appendAsFollower( - MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo) + MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo), + currentEpoch ) } ) } + @Test + def testEmptyAppendNotAllowed(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + assertThrows(classOf[IllegalArgumentException], () => log.appendAsFollower(MemoryRecords.EMPTY, 1)); + assertThrows(classOf[IllegalArgumentException], () => log.appendAsLeader(MemoryRecords.EMPTY, 1)); + } + + @ParameterizedTest + @ArgumentsSource(classOf[InvalidMemoryRecordsProvider]) + def testInvalidMemoryRecords(records: MemoryRecords, expectedException: Optional[Class[Exception]]): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + val previousEndOffset = log.endOffset().offset() + + val action: Executable = () => log.appendAsFollower(records, Int.MaxValue) + if (expectedException.isPresent()) { + assertThrows(expectedException.get, action) + } else { + assertThrows(classOf[CorruptRecordException], action) + } + + assertEquals(previousEndOffset, log.endOffset().offset()) + } + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + def testRandomRecords( + @ForAll(supplier = classOf[ArbitraryMemoryRecords]) records: MemoryRecords + ): Unit = { + val tempDir = 
TestUtils.tempDir() + try { + val log = buildMetadataLog(tempDir, mockTime) + val previousEndOffset = log.endOffset().offset() + + assertThrows( + classOf[CorruptRecordException], + () => log.appendAsFollower(records, Int.MaxValue) + ) + + assertEquals(previousEndOffset, log.endOffset().offset()) + } finally { + Utils.delete(tempDir) + } + } + + @Test + def testInvalidLeaderEpoch(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + val previousEndOffset = log.endOffset().offset() + val epoch = log.lastFetchedEpoch() + 1 + val numberOfRecords = 10 + + val batchWithValidEpoch = MemoryRecords.withRecords( + previousEndOffset, + Compression.NONE, + epoch, + (0 until numberOfRecords).map(number => new SimpleRecord(number.toString.getBytes)): _* + ) + + val batchWithInvalidEpoch = MemoryRecords.withRecords( + previousEndOffset + numberOfRecords, + Compression.NONE, + epoch + 1, + (0 until numberOfRecords).map(number => new SimpleRecord(number.toString.getBytes)): _* + ) + + val buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + batchWithInvalidEpoch.sizeInBytes()) + buffer.put(batchWithValidEpoch.buffer()) + buffer.put(batchWithInvalidEpoch.buffer()) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + + log.appendAsFollower(records, epoch) + + // Check that only the first batch was appended + assertEquals(previousEndOffset + numberOfRecords, log.endOffset().offset()) + // Check that the last fetched epoch matches the first batch + assertEquals(epoch, log.lastFetchedEpoch()) + } + @Test def testCreateSnapshot(): Unit = { val numberOfRecords = 10 @@ -1061,4 +1152,4 @@ object KafkaMetadataLogTest { } dir } -} \ No newline at end of file +} diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala index 9aaaa4a64da..1bb32f1d6c2 100644 --- a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala +++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala @@ -428,6 +428,7 @@ class PartitionTest extends AbstractPartitionTest { def testMakeFollowerWithWithFollowerAppendRecords(): Unit = { val appendSemaphore = new Semaphore(0) val mockTime = new MockTime() + val prevLeaderEpoch = 0 partition = new Partition( topicPartition, @@ -480,24 +481,38 @@ class PartitionTest extends AbstractPartitionTest { } partition.createLogIfNotExists(isNew = true, isFutureReplica = false, offsetCheckpoints, None) + var partitionState = new LeaderAndIsrRequest.PartitionState() + .setControllerEpoch(0) + .setLeader(2) + .setLeaderEpoch(prevLeaderEpoch) + .setIsr(List[Integer](0, 1, 2, brokerId).asJava) + .setPartitionEpoch(1) + .setReplicas(List[Integer](0, 1, 2, brokerId).asJava) + .setIsNew(false) + assertTrue(partition.makeFollower(partitionState, offsetCheckpoints, None)) val appendThread = new Thread { override def run(): Unit = { - val records = createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes), - new SimpleRecord("k2".getBytes, "v2".getBytes)), - baseOffset = 0) - partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false) + val records = createRecords( + List( + new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes) + ), + baseOffset = 0, + partitionLeaderEpoch = prevLeaderEpoch + ) + partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false, prevLeaderEpoch) } } appendThread.start() TestUtils.waitUntilTrue(() => appendSemaphore.hasQueuedThreads, "follower log append is not called.") - val 
partitionState = new LeaderAndIsrRequest.PartitionState() + partitionState = new LeaderAndIsrRequest.PartitionState() .setControllerEpoch(0) .setLeader(2) - .setLeaderEpoch(1) + .setLeaderEpoch(prevLeaderEpoch + 1) .setIsr(List[Integer](0, 1, 2, brokerId).asJava) - .setPartitionEpoch(1) + .setPartitionEpoch(2) .setReplicas(List[Integer](0, 1, 2, brokerId).asJava) .setIsNew(false) assertTrue(partition.makeFollower(partitionState, offsetCheckpoints, None)) @@ -537,15 +552,22 @@ class PartitionTest extends AbstractPartitionTest { // Write to the future replica as if the log had been compacted, and do not roll the segment val buffer = ByteBuffer.allocate(1024) - val builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.CREATE_TIME, 0L, RecordBatch.NO_TIMESTAMP, 0) + val builder = MemoryRecords.builder( + buffer, + RecordBatch.CURRENT_MAGIC_VALUE, + Compression.NONE, + TimestampType.CREATE_TIME, + 0L, // baseOffset + RecordBatch.NO_TIMESTAMP, + 0 // partitionLeaderEpoch + ) builder.appendWithOffset(2L, new SimpleRecord("k1".getBytes, "v3".getBytes)) builder.appendWithOffset(5L, new SimpleRecord("k2".getBytes, "v6".getBytes)) builder.appendWithOffset(6L, new SimpleRecord("k3".getBytes, "v7".getBytes)) builder.appendWithOffset(7L, new SimpleRecord("k4".getBytes, "v8".getBytes)) val futureLog = partition.futureLocalLogOrException - futureLog.appendAsFollower(builder.build()) + futureLog.appendAsFollower(builder.build(), 0) assertTrue(partition.maybeReplaceCurrentWithFutureReplica()) } @@ -955,6 +977,18 @@ class PartitionTest extends AbstractPartitionTest { def testAppendRecordsAsFollowerBelowLogStartOffset(): Unit = { partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) val log = partition.localLogOrException + val epoch = 1 + + // Start off as follower + val partitionState = new LeaderAndIsrRequest.PartitionState() + .setControllerEpoch(0) + .setLeader(1) + .setLeaderEpoch(epoch) + .setIsr(List[Integer](0, 1, 2, brokerId).asJava) + .setPartitionEpoch(1) + .setReplicas(List[Integer](0, 1, 2, brokerId).asJava) + .setIsNew(false) + partition.makeFollower(partitionState, offsetCheckpoints, None) val initialLogStartOffset = 5L partition.truncateFullyAndStartAt(initialLogStartOffset, isFuture = false) @@ -964,9 +998,14 @@ class PartitionTest extends AbstractPartitionTest { s"Log start offset after truncate fully and start at $initialLogStartOffset:") // verify that we cannot append records that do not contain log start offset even if the log is empty - assertThrows(classOf[UnexpectedAppendOffsetException], () => + assertThrows( + classOf[UnexpectedAppendOffsetException], // append one record with offset = 3 - partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 3L), isFuture = false) + () => partition.appendRecordsToFollowerOrFutureReplica( + createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 3L), + isFuture = false, + partitionLeaderEpoch = epoch + ) ) assertEquals(initialLogStartOffset, log.logEndOffset, s"Log end offset should not change after failure to append") @@ -978,12 +1017,16 @@ class PartitionTest extends AbstractPartitionTest { new SimpleRecord("k2".getBytes, "v2".getBytes), new SimpleRecord("k3".getBytes, "v3".getBytes)), baseOffset = newLogStartOffset) - partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false) + partition.appendRecordsToFollowerOrFutureReplica(records, 
isFuture = false, partitionLeaderEpoch = epoch) assertEquals(7L, log.logEndOffset, s"Log end offset after append of 3 records with base offset $newLogStartOffset:") assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset after append of 3 records with base offset $newLogStartOffset:") // and we can append more records after that - partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 7L), isFuture = false) + partition.appendRecordsToFollowerOrFutureReplica( + createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 7L), + isFuture = false, + partitionLeaderEpoch = epoch + ) assertEquals(8L, log.logEndOffset, s"Log end offset after append of 1 record at offset 7:") assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset not expected to change:") @@ -991,11 +1034,18 @@ class PartitionTest extends AbstractPartitionTest { val records2 = createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes), new SimpleRecord("k2".getBytes, "v2".getBytes)), baseOffset = 3L) - assertThrows(classOf[UnexpectedAppendOffsetException], () => partition.appendRecordsToFollowerOrFutureReplica(records2, isFuture = false)) + assertThrows( + classOf[UnexpectedAppendOffsetException], + () => partition.appendRecordsToFollowerOrFutureReplica(records2, isFuture = false, partitionLeaderEpoch = epoch) + ) assertEquals(8L, log.logEndOffset, s"Log end offset should not change after failure to append") // we still can append to next offset - partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 8L), isFuture = false) + partition.appendRecordsToFollowerOrFutureReplica( + createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 8L), + isFuture = false, + partitionLeaderEpoch = epoch + ) assertEquals(9L, log.logEndOffset, s"Log end offset after append of 1 record at offset 8:") assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset not expected to change:") } @@ -1078,9 +1128,13 @@ class PartitionTest extends AbstractPartitionTest { @Test def testAppendRecordsToFollowerWithNoReplicaThrowsException(): Unit = { - assertThrows(classOf[NotLeaderOrFollowerException], () => - partition.appendRecordsToFollowerOrFutureReplica( - createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 0L), isFuture = false) + assertThrows( + classOf[NotLeaderOrFollowerException], + () => partition.appendRecordsToFollowerOrFutureReplica( + createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 0L), + isFuture = false, + partitionLeaderEpoch = 0 + ) ) } @@ -3457,12 +3511,13 @@ class PartitionTest extends AbstractPartitionTest { val replicas = Seq(brokerId, brokerId + 1) val isr = replicas + val epoch = 0 addBrokerEpochToMockMetadataCache(metadataCache, replicas.toList) partition.makeLeader( new LeaderAndIsrRequest.PartitionState() .setControllerEpoch(0) .setLeader(brokerId) - .setLeaderEpoch(0) + .setLeaderEpoch(epoch) .setIsr(isr.map(Int.box).asJava) .setReplicas(replicas.map(Int.box).asJava) .setPartitionEpoch(1) @@ -3495,7 +3550,8 @@ class PartitionTest extends AbstractPartitionTest { partition.appendRecordsToFollowerOrFutureReplica( records = records, - isFuture = true + isFuture = true, + partitionLeaderEpoch = epoch ) listener.verify() @@ -3640,9 +3696,9 @@ class PartitionTest extends AbstractPartitionTest { producerStateManager, _topicId = topicId) { - 
override def appendAsFollower(records: MemoryRecords): LogAppendInfo = { + override def appendAsFollower(records: MemoryRecords, epoch: Int): LogAppendInfo = { appendSemaphore.acquire() - val appendInfo = super.appendAsFollower(records) + val appendInfo = super.appendAsFollower(records, epoch) appendInfo } } diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala index 895d5d64363..716731b48d3 100644 --- a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala +++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala @@ -1457,7 +1457,7 @@ class LogCleanerTest extends Logging { log.appendAsLeader(TestUtils.singletonRecords(value = v, key = k), leaderEpoch = 0) //0 to Int.MaxValue is Int.MaxValue+1 message, -1 will be the last message of i-th segment val records = messageWithOffset(k, v, (i + 1L) * (Int.MaxValue + 1L) -1 ) - log.appendAsFollower(records) + log.appendAsFollower(records, Int.MaxValue) assertEquals(i + 1, log.numberOfSegments) } @@ -1511,7 +1511,7 @@ class LogCleanerTest extends Logging { // forward offset and append message to next segment at offset Int.MaxValue val records = messageWithOffset("hello".getBytes, "hello".getBytes, Int.MaxValue - 1) - log.appendAsFollower(records) + log.appendAsFollower(records, Int.MaxValue) log.appendAsLeader(TestUtils.singletonRecords(value = "hello".getBytes, key = "hello".getBytes), leaderEpoch = 0) assertEquals(Int.MaxValue, log.activeSegment.offsetIndex.lastOffset) @@ -1560,14 +1560,14 @@ class LogCleanerTest extends Logging { val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) val record1 = messageWithOffset("hello".getBytes, "hello".getBytes, 0) - log.appendAsFollower(record1) + log.appendAsFollower(record1, Int.MaxValue) val record2 = messageWithOffset("hello".getBytes, "hello".getBytes, 1) - log.appendAsFollower(record2) + log.appendAsFollower(record2, Int.MaxValue) log.roll(Some(Int.MaxValue/2)) // starting a new log segment at offset Int.MaxValue/2 val record3 = messageWithOffset("hello".getBytes, "hello".getBytes, Int.MaxValue/2) - log.appendAsFollower(record3) + log.appendAsFollower(record3, Int.MaxValue) val record4 = messageWithOffset("hello".getBytes, "hello".getBytes, Int.MaxValue.toLong + 1) - log.appendAsFollower(record4) + log.appendAsFollower(record4, Int.MaxValue) assertTrue(log.logEndOffset - 1 - log.logStartOffset > Int.MaxValue, "Actual offset range should be > Int.MaxValue") assertTrue(log.logSegments.asScala.last.offsetIndex.lastOffset - log.logStartOffset <= Int.MaxValue, @@ -1881,8 +1881,8 @@ class LogCleanerTest extends Logging { val noDupSetOffset = 50 val noDupSet = noDupSetKeys zip (noDupSetOffset until noDupSetOffset + noDupSetKeys.size) - log.appendAsFollower(invalidCleanedMessage(dupSetOffset, dupSet, codec)) - log.appendAsFollower(invalidCleanedMessage(noDupSetOffset, noDupSet, codec)) + log.appendAsFollower(invalidCleanedMessage(dupSetOffset, dupSet, codec), Int.MaxValue) + log.appendAsFollower(invalidCleanedMessage(noDupSetOffset, noDupSet, codec), Int.MaxValue) log.roll() @@ -1968,7 +1968,7 @@ class LogCleanerTest extends Logging { log.roll(Some(11L)) // active segment record - log.appendAsFollower(messageWithOffset(1015, 1015, 11L)) + log.appendAsFollower(messageWithOffset(1015, 1015, 11L), Int.MaxValue) val (nextDirtyOffset, _) = cleaner.clean(LogToClean(log.topicPartition, log, 0L, log.activeSegment.baseOffset, needCompactionNow = true)) assertEquals(log.activeSegment.baseOffset, nextDirtyOffset, @@ 
-1987,7 +1987,7 @@ class LogCleanerTest extends Logging { log.roll(Some(30L)) // active segment record - log.appendAsFollower(messageWithOffset(1015, 1015, 30L)) + log.appendAsFollower(messageWithOffset(1015, 1015, 30L), Int.MaxValue) val (nextDirtyOffset, _) = cleaner.clean(LogToClean(log.topicPartition, log, 0L, log.activeSegment.baseOffset, needCompactionNow = true)) assertEquals(log.activeSegment.baseOffset, nextDirtyOffset, @@ -2204,7 +2204,7 @@ class LogCleanerTest extends Logging { private def writeToLog(log: UnifiedLog, keysAndValues: Iterable[(Int, Int)], offsetSeq: Iterable[Long]): Iterable[Long] = { for (((key, value), offset) <- keysAndValues.zip(offsetSeq)) - yield log.appendAsFollower(messageWithOffset(key, value, offset)).lastOffset + yield log.appendAsFollower(messageWithOffset(key, value, offset), Int.MaxValue).lastOffset } private def invalidCleanedMessage(initialOffset: Long, diff --git a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala index 342ef145b6d..72ad7a718d1 100644 --- a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala @@ -126,9 +126,14 @@ class LogConcurrencyTest { log.appendAsLeader(TestUtils.records(records), leaderEpoch) log.maybeIncrementHighWatermark(logEndOffsetMetadata) } else { - log.appendAsFollower(TestUtils.records(records, - baseOffset = logEndOffset, - partitionLeaderEpoch = leaderEpoch)) + log.appendAsFollower( + TestUtils.records( + records, + baseOffset = logEndOffset, + partitionLeaderEpoch = leaderEpoch + ), + Int.MaxValue + ) log.updateHighWatermark(logEndOffset) } diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala index 3bdf8a9436c..11c2b620058 100644 --- a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala +++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala @@ -923,17 +923,17 @@ class LogLoaderTest { val set3 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 3, Compression.NONE, 0, new SimpleRecord("v4".getBytes(), "k4".getBytes())) val set4 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 4, Compression.NONE, 0, new SimpleRecord("v5".getBytes(), "k5".getBytes())) //Writes into an empty log with baseOffset 0 - log.appendAsFollower(set1) + log.appendAsFollower(set1, Int.MaxValue) assertEquals(0L, log.activeSegment.baseOffset) //This write will roll the segment, yielding a new segment with base offset = max(1, Integer.MAX_VALUE+2) = Integer.MAX_VALUE+2 - log.appendAsFollower(set2) + log.appendAsFollower(set2, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) assertTrue(LogFileUtils.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 2).exists) //This will go into the existing log - log.appendAsFollower(set3) + log.appendAsFollower(set3, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) //This will go into the existing log - log.appendAsFollower(set4) + log.appendAsFollower(set4, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) log.close() val indexFiles = logDir.listFiles.filter(file => file.getName.contains(".index")) @@ -962,17 +962,17 @@ class LogLoaderTest { new SimpleRecord("v7".getBytes(), "k7".getBytes()), new SimpleRecord("v8".getBytes(), "k8".getBytes())) //Writes into an empty log with baseOffset 0 - log.appendAsFollower(set1) + log.appendAsFollower(set1, Int.MaxValue) 
assertEquals(0L, log.activeSegment.baseOffset) //This write will roll the segment, yielding a new segment with base offset = max(1, Integer.MAX_VALUE+2) = Integer.MAX_VALUE+2 - log.appendAsFollower(set2) + log.appendAsFollower(set2, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) assertTrue(LogFileUtils.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 2).exists) //This will go into the existing log - log.appendAsFollower(set3) + log.appendAsFollower(set3, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) //This will go into the existing log - log.appendAsFollower(set4) + log.appendAsFollower(set4, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) log.close() val indexFiles = logDir.listFiles.filter(file => file.getName.contains(".index")) @@ -1002,18 +1002,18 @@ class LogLoaderTest { new SimpleRecord("v7".getBytes(), "k7".getBytes()), new SimpleRecord("v8".getBytes(), "k8".getBytes())) //Writes into an empty log with baseOffset 0 - log.appendAsFollower(set1) + log.appendAsFollower(set1, Int.MaxValue) assertEquals(0L, log.activeSegment.baseOffset) //This write will roll the segment, yielding a new segment with base offset = max(1, 3) = 3 - log.appendAsFollower(set2) + log.appendAsFollower(set2, Int.MaxValue) assertEquals(3, log.activeSegment.baseOffset) assertTrue(LogFileUtils.producerSnapshotFile(logDir, 3).exists) //This will also roll the segment, yielding a new segment with base offset = max(5, Integer.MAX_VALUE+4) = Integer.MAX_VALUE+4 - log.appendAsFollower(set3) + log.appendAsFollower(set3, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 4, log.activeSegment.baseOffset) assertTrue(LogFileUtils.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 4).exists) //This will go into the existing log - log.appendAsFollower(set4) + log.appendAsFollower(set4, Int.MaxValue) assertEquals(Integer.MAX_VALUE.toLong + 4, log.activeSegment.baseOffset) log.close() val indexFiles = logDir.listFiles.filter(file => file.getName.contains(".index")) @@ -1203,16 +1203,16 @@ class LogLoaderTest { val log = createLog(logDir, new LogConfig(new Properties)) val leaderEpochCache = log.leaderEpochCache val firstBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 1, offset = 0) - log.appendAsFollower(records = firstBatch) + log.appendAsFollower(records = firstBatch, Int.MaxValue) val secondBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 2, offset = 1) - log.appendAsFollower(records = secondBatch) + log.appendAsFollower(records = secondBatch, Int.MaxValue) val thirdBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 2, offset = 2) - log.appendAsFollower(records = thirdBatch) + log.appendAsFollower(records = thirdBatch, Int.MaxValue) val fourthBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 3, offset = 3) - log.appendAsFollower(records = fourthBatch) + log.appendAsFollower(records = fourthBatch, Int.MaxValue) assertEquals(java.util.Arrays.asList(new EpochEntry(1, 0), new EpochEntry(2, 1), new EpochEntry(3, 3)), leaderEpochCache.epochEntries) diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala index 2ac71abd7be..f1e53014d72 100755 --- a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala +++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala @@ -48,11 +48,16 @@ import 
org.apache.kafka.storage.log.metrics.{BrokerTopicMetrics, BrokerTopicStat import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource import org.junit.jupiter.params.provider.{EnumSource, ValueSource} import org.mockito.ArgumentMatchers import org.mockito.ArgumentMatchers.{any, anyLong} import org.mockito.Mockito.{doAnswer, doThrow, spy} +import net.jqwik.api.AfterFailureMode +import net.jqwik.api.ForAll +import net.jqwik.api.Property + import java.io._ import java.nio.ByteBuffer import java.nio.file.Files @@ -304,7 +309,7 @@ class UnifiedLogTest { assertHighWatermark(3L) // Update high watermark as follower - log.appendAsFollower(records(3L)) + log.appendAsFollower(records(3L), leaderEpoch) log.updateHighWatermark(6L) assertHighWatermark(6L) @@ -582,6 +587,7 @@ class UnifiedLogTest { @Test def testRollSegmentThatAlreadyExists(): Unit = { val logConfig = LogTestUtils.createLogConfig(segmentMs = 1 * 60 * 60L) + val partitionLeaderEpoch = 0 // create a log val log = createLog(logDir, logConfig) @@ -594,16 +600,16 @@ class UnifiedLogTest { // should be able to append records to active segment val records = TestUtils.records( List(new SimpleRecord(mockTime.milliseconds, "k1".getBytes, "v1".getBytes)), - baseOffset = 0L, partitionLeaderEpoch = 0) - log.appendAsFollower(records) + baseOffset = 0L, partitionLeaderEpoch = partitionLeaderEpoch) + log.appendAsFollower(records, partitionLeaderEpoch) assertEquals(1, log.numberOfSegments, "Expect one segment.") assertEquals(0L, log.activeSegment.baseOffset) // make sure we can append more records val records2 = TestUtils.records( List(new SimpleRecord(mockTime.milliseconds + 10, "k2".getBytes, "v2".getBytes)), - baseOffset = 1L, partitionLeaderEpoch = 0) - log.appendAsFollower(records2) + baseOffset = 1L, partitionLeaderEpoch = partitionLeaderEpoch) + log.appendAsFollower(records2, partitionLeaderEpoch) assertEquals(2, log.logEndOffset, "Expect two records in the log") assertEquals(0, LogTestUtils.readLog(log, 0, 1).records.batches.iterator.next().lastOffset) @@ -618,8 +624,8 @@ class UnifiedLogTest { log.activeSegment.offsetIndex.resize(0) val records3 = TestUtils.records( List(new SimpleRecord(mockTime.milliseconds + 12, "k3".getBytes, "v3".getBytes)), - baseOffset = 2L, partitionLeaderEpoch = 0) - log.appendAsFollower(records3) + baseOffset = 2L, partitionLeaderEpoch = partitionLeaderEpoch) + log.appendAsFollower(records3, partitionLeaderEpoch) assertTrue(log.activeSegment.offsetIndex.maxEntries > 1) assertEquals(2, LogTestUtils.readLog(log, 2, 1).records.batches.iterator.next().lastOffset) assertEquals(2, log.numberOfSegments, "Expect two segments.") @@ -793,17 +799,25 @@ class UnifiedLogTest { val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) val log = createLog(logDir, logConfig) val pid = 1L - val epoch = 0.toShort + val producerEpoch = 0.toShort + val partitionLeaderEpoch = 0 val seq = 0 val baseOffset = 23L // create a batch with a couple gaps to simulate compaction - val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List( - new SimpleRecord(mockTime.milliseconds(), "a".getBytes), - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes), - new SimpleRecord(mockTime.milliseconds(), "c".getBytes), - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes))) - 
records.batches.forEach(_.setPartitionLeaderEpoch(0)) + val records = TestUtils.records( + producerId = pid, + producerEpoch = producerEpoch, + sequence = seq, + baseOffset = baseOffset, + records = List( + new SimpleRecord(mockTime.milliseconds(), "a".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes), + new SimpleRecord(mockTime.milliseconds(), "c".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes) + ) + ) + records.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch)) val filtered = ByteBuffer.allocate(2048) records.filterTo(new RecordFilter(0, 0) { @@ -814,14 +828,18 @@ class UnifiedLogTest { filtered.flip() val filteredRecords = MemoryRecords.readableRecords(filtered) - log.appendAsFollower(filteredRecords) + log.appendAsFollower(filteredRecords, partitionLeaderEpoch) // append some more data and then truncate to force rebuilding of the PID map - val moreRecords = TestUtils.records(baseOffset = baseOffset + 4, records = List( - new SimpleRecord(mockTime.milliseconds(), "e".getBytes), - new SimpleRecord(mockTime.milliseconds(), "f".getBytes))) - moreRecords.batches.forEach(_.setPartitionLeaderEpoch(0)) - log.appendAsFollower(moreRecords) + val moreRecords = TestUtils.records( + baseOffset = baseOffset + 4, + records = List( + new SimpleRecord(mockTime.milliseconds(), "e".getBytes), + new SimpleRecord(mockTime.milliseconds(), "f".getBytes) + ) + ) + moreRecords.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch)) + log.appendAsFollower(moreRecords, partitionLeaderEpoch) log.truncateTo(baseOffset + 4) @@ -837,15 +855,23 @@ class UnifiedLogTest { val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) val log = createLog(logDir, logConfig) val pid = 1L - val epoch = 0.toShort + val producerEpoch = 0.toShort + val partitionLeaderEpoch = 0 val seq = 0 val baseOffset = 23L // create an empty batch - val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List( - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "a".getBytes), - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes))) - records.batches.forEach(_.setPartitionLeaderEpoch(0)) + val records = TestUtils.records( + producerId = pid, + producerEpoch = producerEpoch, + sequence = seq, + baseOffset = baseOffset, + records = List( + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "a".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes) + ) + ) + records.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch)) val filtered = ByteBuffer.allocate(2048) records.filterTo(new RecordFilter(0, 0) { @@ -856,14 +882,18 @@ class UnifiedLogTest { filtered.flip() val filteredRecords = MemoryRecords.readableRecords(filtered) - log.appendAsFollower(filteredRecords) + log.appendAsFollower(filteredRecords, partitionLeaderEpoch) // append some more data and then truncate to force rebuilding of the PID map - val moreRecords = TestUtils.records(baseOffset = baseOffset + 2, records = List( - new SimpleRecord(mockTime.milliseconds(), "e".getBytes), - new SimpleRecord(mockTime.milliseconds(), "f".getBytes))) - moreRecords.batches.forEach(_.setPartitionLeaderEpoch(0)) - log.appendAsFollower(moreRecords) + val moreRecords = TestUtils.records( + baseOffset = baseOffset + 2, + records = List( + new SimpleRecord(mockTime.milliseconds(), "e".getBytes), + new SimpleRecord(mockTime.milliseconds(), "f".getBytes) + ) + ) + 
moreRecords.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch)) + log.appendAsFollower(moreRecords, partitionLeaderEpoch) log.truncateTo(baseOffset + 2) @@ -879,17 +909,25 @@ class UnifiedLogTest { val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) val log = createLog(logDir, logConfig) val pid = 1L - val epoch = 0.toShort + val producerEpoch = 0.toShort + val partitionLeaderEpoch = 0 val seq = 0 val baseOffset = 23L // create a batch with a couple gaps to simulate compaction - val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List( - new SimpleRecord(mockTime.milliseconds(), "a".getBytes), - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes), - new SimpleRecord(mockTime.milliseconds(), "c".getBytes), - new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes))) - records.batches.forEach(_.setPartitionLeaderEpoch(0)) + val records = TestUtils.records( + producerId = pid, + producerEpoch = producerEpoch, + sequence = seq, + baseOffset = baseOffset, + records = List( + new SimpleRecord(mockTime.milliseconds(), "a".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes), + new SimpleRecord(mockTime.milliseconds(), "c".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes) + ) + ) + records.batches.forEach(_.setPartitionLeaderEpoch(partitionLeaderEpoch)) val filtered = ByteBuffer.allocate(2048) records.filterTo(new RecordFilter(0, 0) { @@ -900,7 +938,7 @@ class UnifiedLogTest { filtered.flip() val filteredRecords = MemoryRecords.readableRecords(filtered) - log.appendAsFollower(filteredRecords) + log.appendAsFollower(filteredRecords, partitionLeaderEpoch) val activeProducers = log.activeProducersWithLastSequence assertTrue(activeProducers.contains(pid)) @@ -1330,33 +1368,44 @@ class UnifiedLogTest { // create a log val log = createLog(logDir, new LogConfig(new Properties)) - val epoch: Short = 0 + val producerEpoch: Short = 0 + val partitionLeaderEpoch = 0 val buffer = ByteBuffer.allocate(512) - var builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), 1L, epoch, 0, false, 0) + var builder = MemoryRecords.builder( + buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, + TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), 1L, producerEpoch, 0, false, + partitionLeaderEpoch + ) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), 2L, epoch, 0, false, 0) + TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), 2L, producerEpoch, 0, false, + partitionLeaderEpoch) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() - builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), 3L, epoch, 0, false, 0) + builder = MemoryRecords.builder( + buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, + TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), 3L, producerEpoch, 0, false, + partitionLeaderEpoch + ) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() - builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 
3L, mockTime.milliseconds(), 4L, epoch, 0, false, 0) + builder = MemoryRecords.builder( + buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, + TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), 4L, producerEpoch, 0, false, + partitionLeaderEpoch + ) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() buffer.flip() val memoryRecords = MemoryRecords.readableRecords(buffer) - log.appendAsFollower(memoryRecords) + log.appendAsFollower(memoryRecords, partitionLeaderEpoch) log.flush(false) val fetchedData = LogTestUtils.readLog(log, 0, Int.MaxValue) @@ -1375,7 +1424,7 @@ class UnifiedLogTest { def testDuplicateAppendToFollower(): Unit = { val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) val log = createLog(logDir, logConfig) - val epoch: Short = 0 + val producerEpoch: Short = 0 val pid = 1L val baseSequence = 0 val partitionLeaderEpoch = 0 @@ -1383,10 +1432,32 @@ class UnifiedLogTest { // this is a bit contrived. to trigger the duplicate case for a follower append, we have to append // a batch with matching sequence numbers, but valid increasing offsets assertEquals(0L, log.logEndOffset) - log.appendAsFollower(MemoryRecords.withIdempotentRecords(0L, Compression.NONE, pid, epoch, baseSequence, - partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) - log.appendAsFollower(MemoryRecords.withIdempotentRecords(2L, Compression.NONE, pid, epoch, baseSequence, - partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + log.appendAsFollower( + MemoryRecords.withIdempotentRecords( + 0L, + Compression.NONE, + pid, + producerEpoch, + baseSequence, + partitionLeaderEpoch, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes) + ), + partitionLeaderEpoch + ) + log.appendAsFollower( + MemoryRecords.withIdempotentRecords( + 2L, + Compression.NONE, + pid, + producerEpoch, + baseSequence, + partitionLeaderEpoch, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes) + ), + partitionLeaderEpoch + ) // Ensure that even the duplicate sequences are accepted on the follower. 
assertEquals(4L, log.logEndOffset) @@ -1399,48 +1470,49 @@ class UnifiedLogTest { val pid1 = 1L val pid2 = 2L - val epoch: Short = 0 + val producerEpoch: Short = 0 val buffer = ByteBuffer.allocate(512) // pid1 seq = 0 var builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), pid1, epoch, 0) + TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), pid1, producerEpoch, 0) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() // pid2 seq = 0 builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), pid2, epoch, 0) + TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), pid2, producerEpoch, 0) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() // pid1 seq = 1 builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), pid1, epoch, 1) + TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), pid1, producerEpoch, 1) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() // pid2 seq = 1 builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), pid2, epoch, 1) + TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), pid2, producerEpoch, 1) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() // // pid1 seq = 1 (duplicate) builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, Compression.NONE, - TimestampType.LOG_APPEND_TIME, 4L, mockTime.milliseconds(), pid1, epoch, 1) + TimestampType.LOG_APPEND_TIME, 4L, mockTime.milliseconds(), pid1, producerEpoch, 1) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() buffer.flip() + val epoch = 0 val records = MemoryRecords.readableRecords(buffer) - records.batches.forEach(_.setPartitionLeaderEpoch(0)) + records.batches.forEach(_.setPartitionLeaderEpoch(epoch)) // Ensure that batches with duplicates are accepted on the follower. 
assertEquals(0L, log.logEndOffset) - log.appendAsFollower(records) + log.appendAsFollower(records, epoch) assertEquals(5L, log.logEndOffset) } @@ -1582,8 +1654,12 @@ class UnifiedLogTest { val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes)) // now test the case that we give the offsets and use non-sequential offsets - for (i <- records.indices) - log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i))) + for (i <- records.indices) { + log.appendAsFollower( + MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i)), + Int.MaxValue + ) + } for (i <- 50 until messageIds.max) { val idx = messageIds.indexWhere(_ >= i) val read = LogTestUtils.readLog(log, i, 100).records.records.iterator.next() @@ -1630,8 +1706,12 @@ class UnifiedLogTest { val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes)) // now test the case that we give the offsets and use non-sequential offsets - for (i <- records.indices) - log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i))) + for (i <- records.indices) { + log.appendAsFollower( + MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i)), + Int.MaxValue + ) + } for (i <- 50 until messageIds.max) { val idx = messageIds.indexWhere(_ >= i) @@ -1655,8 +1735,12 @@ class UnifiedLogTest { val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes)) // now test the case that we give the offsets and use non-sequential offsets - for (i <- records.indices) - log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i))) + for (i <- records.indices) { + log.appendAsFollower( + MemoryRecords.withRecords(messageIds(i), Compression.NONE, 0, records(i)), + Int.MaxValue + ) + } for (i <- 50 until messageIds.max) { assertEquals(MemoryRecords.EMPTY, LogTestUtils.readLog(log, i, maxLength = 0, minOneMessage = false).records) @@ -1904,9 +1988,94 @@ class UnifiedLogTest { val log = createLog(logDir, LogTestUtils.createLogConfig(maxMessageBytes = second.sizeInBytes - 1)) - log.appendAsFollower(first) + log.appendAsFollower(first, Int.MaxValue) // the second record is larger then limit but appendAsFollower does not validate the size. 
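Aside (not part of the patch): the follower-append tests above build MemoryRecords with pre-assigned offsets, and the duplicate-append test shows that batches reusing the same producer sequence numbers are still accepted as long as their base offsets keep increasing. The following minimal, self-contained Java sketch only illustrates that input shape using APIs that appear in the hunks above; the class name and local variables are invented for illustration. The hunk then continues below with the oversized second append that the preceding comment refers to.

import org.apache.kafka.common.compress.Compression;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.common.record.SimpleRecord;

import java.util.List;

public final class DuplicateFollowerBatchesSketch {
    public static void main(String[] args) {
        long producerId = 1L;
        short producerEpoch = 0;
        int baseSequence = 0;
        int partitionLeaderEpoch = 0;

        // Same producer id, epoch and sequence in both batches, but increasing base offsets:
        // this is the shape the duplicate-append-to-follower test feeds into the log.
        MemoryRecords first = MemoryRecords.withIdempotentRecords(
            0L, Compression.NONE, producerId, producerEpoch, baseSequence, partitionLeaderEpoch,
            new SimpleRecord("a".getBytes()), new SimpleRecord("b".getBytes()));
        MemoryRecords second = MemoryRecords.withIdempotentRecords(
            2L, Compression.NONE, producerId, producerEpoch, baseSequence, partitionLeaderEpoch,
            new SimpleRecord("a".getBytes()), new SimpleRecord("b".getBytes()));

        for (MemoryRecords records : List.of(first, second)) {
            for (RecordBatch batch : records.batches()) {
                System.out.printf("baseOffset=%d baseSequence=%d lastSequence=%d%n",
                    batch.baseOffset(), batch.baseSequence(), batch.lastSequence());
            }
        }
    }
}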
- log.appendAsFollower(second) + log.appendAsFollower(second, Int.MaxValue) + } + + @ParameterizedTest + @ArgumentsSource(classOf[InvalidMemoryRecordsProvider]) + def testInvalidMemoryRecords(records: MemoryRecords, expectedException: Optional[Class[Exception]]): Unit = { + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig) + val previousEndOffset = log.logEndOffsetMetadata.messageOffset + + if (expectedException.isPresent()) { + assertThrows( + expectedException.get(), + () => log.appendAsFollower(records, Int.MaxValue) + ) + } else { + log.appendAsFollower(records, Int.MaxValue) + } + + assertEquals(previousEndOffset, log.logEndOffsetMetadata.messageOffset) + } + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + def testRandomRecords( + @ForAll(supplier = classOf[ArbitraryMemoryRecords]) records: MemoryRecords + ): Unit = { + val tempDir = TestUtils.tempDir() + val logDir = TestUtils.randomPartitionLogDir(tempDir) + try { + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig) + val previousEndOffset = log.logEndOffsetMetadata.messageOffset + + // Depending on the corruption, unified log sometimes throws and sometimes returns an + // empty set of batches + assertThrows( + classOf[CorruptRecordException], + () => { + val info = log.appendAsFollower(records, Int.MaxValue) + if (info.firstOffset == JUnifiedLog.UNKNOWN_OFFSET) { + throw new CorruptRecordException("Unknown offset is test") + } + } + ) + + assertEquals(previousEndOffset, log.logEndOffsetMetadata.messageOffset) + } finally { + Utils.delete(tempDir) + } + } + + @Test + def testInvalidLeaderEpoch(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig) + val previousEndOffset = log.logEndOffsetMetadata.messageOffset + val epoch = log.latestEpoch.getOrElse(0) + 1 + val numberOfRecords = 10 + + val batchWithValidEpoch = MemoryRecords.withRecords( + previousEndOffset, + Compression.NONE, + epoch, + (0 until numberOfRecords).map(number => new SimpleRecord(number.toString.getBytes)): _* + ) + + val batchWithInvalidEpoch = MemoryRecords.withRecords( + previousEndOffset + numberOfRecords, + Compression.NONE, + epoch + 1, + (0 until numberOfRecords).map(number => new SimpleRecord(number.toString.getBytes)): _* + ) + + val buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + batchWithInvalidEpoch.sizeInBytes()) + buffer.put(batchWithValidEpoch.buffer()) + buffer.put(batchWithInvalidEpoch.buffer()) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + + log.appendAsFollower(records, epoch) + + // Check that only the first batch was appended + assertEquals(previousEndOffset + numberOfRecords, log.logEndOffsetMetadata.messageOffset) + // Check that the last fetched epoch matches the first batch + assertEquals(epoch, log.latestEpoch.get) } @Test @@ -1987,7 +2156,7 @@ class UnifiedLogTest { val messages = (0 until numMessages).map { i => MemoryRecords.withRecords(100 + i, Compression.NONE, 0, new SimpleRecord(mockTime.milliseconds + i, i.toString.getBytes())) } - messages.foreach(log.appendAsFollower) + messages.foreach(message => log.appendAsFollower(message, Int.MaxValue)) val timeIndexEntries = log.logSegments.asScala.foldLeft(0) { (entries, segment) => entries + segment.timeIndex.entries } assertEquals(numMessages - 1, timeIndexEntries, s"There should be ${numMessages - 1} time index entries") assertEquals(mockTime.milliseconds + numMessages - 1, 
log.activeSegment.timeIndex.lastEntry.timestamp, @@ -2131,7 +2300,7 @@ class UnifiedLogTest { // The cache can be updated directly after a leader change. // The new latest offset should reflect the updated epoch. log.assignEpochStartOffset(2, 2L) - + assertEquals(new OffsetResultHolder(new TimestampAndOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, 2L, Optional.of(2))), log.fetchOffsetByTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, Optional.of(remoteLogManager))) } @@ -2399,20 +2568,22 @@ class UnifiedLogTest { def testAppendWithOutOfOrderOffsetsThrowsException(): Unit = { val log = createLog(logDir, new LogConfig(new Properties)) + val epoch = 0 val appendOffsets = Seq(0L, 1L, 3L, 2L, 4L) val buffer = ByteBuffer.allocate(512) for (offset <- appendOffsets) { val builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, Compression.NONE, TimestampType.LOG_APPEND_TIME, offset, mockTime.milliseconds(), - 1L, 0, 0, false, 0) + 1L, 0, 0, false, epoch) builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) builder.close() } buffer.flip() val memoryRecords = MemoryRecords.readableRecords(buffer) - assertThrows(classOf[OffsetsOutOfOrderException], () => - log.appendAsFollower(memoryRecords) + assertThrows( + classOf[OffsetsOutOfOrderException], + () => log.appendAsFollower(memoryRecords, epoch) ) } @@ -2427,9 +2598,11 @@ class UnifiedLogTest { for (magic <- magicVals; compressionType <- compressionTypes) { val compression = Compression.of(compressionType).build() val invalidRecord = MemoryRecords.withRecords(magic, compression, new SimpleRecord(1.toString.getBytes)) - assertThrows(classOf[UnexpectedAppendOffsetException], - () => log.appendAsFollower(invalidRecord), - () => s"Magic=$magic, compressionType=$compressionType") + assertThrows( + classOf[UnexpectedAppendOffsetException], + () => log.appendAsFollower(invalidRecord, Int.MaxValue), + () => s"Magic=$magic, compressionType=$compressionType" + ) } } @@ -2450,7 +2623,10 @@ class UnifiedLogTest { magicValue = magic, codec = Compression.of(compressionType).build(), baseOffset = firstOffset) - val exception = assertThrows(classOf[UnexpectedAppendOffsetException], () => log.appendAsFollower(records = batch)) + val exception = assertThrows( + classOf[UnexpectedAppendOffsetException], + () => log.appendAsFollower(records = batch, Int.MaxValue) + ) assertEquals(firstOffset, exception.firstOffset, s"Magic=$magic, compressionType=$compressionType, UnexpectedAppendOffsetException#firstOffset") assertEquals(firstOffset + 2, exception.lastOffset, s"Magic=$magic, compressionType=$compressionType, UnexpectedAppendOffsetException#lastOffset") } @@ -2549,9 +2725,16 @@ class UnifiedLogTest { log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 5) assertEquals(OptionalInt.of(5), log.leaderEpochCache.latestEpoch) - log.appendAsFollower(TestUtils.records(List(new SimpleRecord("foo".getBytes())), - baseOffset = 1L, - magicValue = RecordVersion.V1.value)) + log.appendAsFollower( + TestUtils.records( + List( + new SimpleRecord("foo".getBytes()) + ), + baseOffset = 1L, + magicValue = RecordVersion.V1.value + ), + 5 + ) assertEquals(OptionalInt.empty, log.leaderEpochCache.latestEpoch) } @@ -2907,7 +3090,7 @@ class UnifiedLogTest { //When appending as follower (assignOffsets = false) for (i <- records.indices) - log.appendAsFollower(recordsForEpoch(i)) + log.appendAsFollower(recordsForEpoch(i), i) assertEquals(Some(42), log.latestEpoch) } @@ -2975,7 +3158,7 @@ class UnifiedLogTest { def append(epoch: Int, 
startOffset: Long, count: Int): Unit = { for (i <- 0 until count) - log.appendAsFollower(createRecords(startOffset + i, epoch)) + log.appendAsFollower(createRecords(startOffset + i, epoch), epoch) } //Given 2 segments, 10 messages per segment @@ -3209,7 +3392,7 @@ class UnifiedLogTest { buffer.flip() - appendAsFollower(log, MemoryRecords.readableRecords(buffer)) + appendAsFollower(log, MemoryRecords.readableRecords(buffer), epoch) val abortedTransactions = LogTestUtils.allAbortedTransactions(log) val expectedTransactions = List( @@ -3293,7 +3476,7 @@ class UnifiedLogTest { appendEndTxnMarkerToBuffer(buffer, pid, epoch, 10L, ControlRecordType.COMMIT, leaderEpoch = 1) buffer.flip() - log.appendAsFollower(MemoryRecords.readableRecords(buffer)) + log.appendAsFollower(MemoryRecords.readableRecords(buffer), epoch) LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 2, leaderEpoch = 1) LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 2, leaderEpoch = 1) @@ -3414,10 +3597,16 @@ class UnifiedLogTest { val log = createLog(logDir, logConfig) // append a few records - appendAsFollower(log, MemoryRecords.withRecords(Compression.NONE, - new SimpleRecord("a".getBytes), - new SimpleRecord("b".getBytes), - new SimpleRecord("c".getBytes)), 5) + appendAsFollower( + log, + MemoryRecords.withRecords( + Compression.NONE, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes), + new SimpleRecord("c".getBytes) + ), + 5 + ) log.updateHighWatermark(3L) @@ -4484,9 +4673,9 @@ class UnifiedLogTest { builder.close() } - private def appendAsFollower(log: UnifiedLog, records: MemoryRecords, leaderEpoch: Int = 0): Unit = { + private def appendAsFollower(log: UnifiedLog, records: MemoryRecords, leaderEpoch: Int): Unit = { records.batches.forEach(_.setPartitionLeaderEpoch(leaderEpoch)) - log.appendAsFollower(records) + log.appendAsFollower(records, leaderEpoch) } private def createLog(dir: File, diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala index 7528eefc420..d1a05e7d915 100644 --- a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala @@ -328,9 +328,12 @@ class AbstractFetcherManagerTest { fetchBackOffMs = 0, brokerTopicStats = new BrokerTopicStats) { - override protected def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = { - None - } + override protected def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = None override protected def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {} diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala index 0e38e9dfcb0..19856e1cdd2 100644 --- a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala +++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala @@ -630,6 +630,7 @@ class AbstractFetcherThreadTest { @Test def testFollowerFetchOutOfRangeLow(): Unit = { + val leaderEpoch = 4 val partition = new TopicPartition("topic", 0) val mockLeaderEndpoint = new 
MockLeaderEndPoint(version = version) val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint) @@ -639,14 +640,19 @@ class AbstractFetcherThreadTest { val replicaLog = Seq( mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes))) - val replicaState = PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L) + val replicaState = PartitionState(replicaLog, leaderEpoch = leaderEpoch, highWatermark = 0L) fetcher.setReplicaState(partition, replicaState) - fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 0))) + fetcher.addPartitions( + Map( + partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = leaderEpoch) + ) + ) val leaderLog = Seq( - mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + mkBatch(baseOffset = 2, leaderEpoch = leaderEpoch, new SimpleRecord("c".getBytes)) + ) - val leaderState = PartitionState(leaderLog, leaderEpoch = 0, highWatermark = 2L) + val leaderState = PartitionState(leaderLog, leaderEpoch = leaderEpoch, highWatermark = 2L) fetcher.mockLeader.setLeaderState(partition, leaderState) fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState) @@ -671,6 +677,7 @@ class AbstractFetcherThreadTest { @Test def testRetryAfterUnknownLeaderEpochInLatestOffsetFetch(): Unit = { + val leaderEpoch = 4 val partition = new TopicPartition("topic", 0) val mockLeaderEndPoint = new MockLeaderEndPoint(version = version) { val tries = new AtomicInteger(0) @@ -685,16 +692,18 @@ class AbstractFetcherThreadTest { // The follower begins from an offset which is behind the leader's log start offset val replicaLog = Seq( - mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes))) + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)) + ) - val replicaState = PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L) + val replicaState = PartitionState(replicaLog, leaderEpoch = leaderEpoch, highWatermark = 0L) fetcher.setReplicaState(partition, replicaState) - fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 0))) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = leaderEpoch))) val leaderLog = Seq( - mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)) + ) - val leaderState = PartitionState(leaderLog, leaderEpoch = 0, highWatermark = 2L) + val leaderState = PartitionState(leaderLog, leaderEpoch = leaderEpoch, highWatermark = 2L) fetcher.mockLeader.setLeaderState(partition, leaderState) fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState) @@ -712,6 +721,46 @@ class AbstractFetcherThreadTest { assertEquals(leaderState.highWatermark, replicaState.highWatermark) } + @Test + def testReplicateBatchesUpToLeaderEpoch(): Unit = { + val leaderEpoch = 4 + val partition = new TopicPartition("topic", 0) + val mockLeaderEndpoint = new MockLeaderEndPoint(version = version) + val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint) + val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine, failedPartitions = failedPartitions) + + val replicaState = PartitionState(Seq(), leaderEpoch = leaderEpoch, highWatermark = 0L) + fetcher.setReplicaState(partition, replicaState) + fetcher.addPartitions( + Map( + partition -> 
initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = leaderEpoch) + ) + ) + + val leaderLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = leaderEpoch - 1, new SimpleRecord("c".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = leaderEpoch, new SimpleRecord("d".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = leaderEpoch + 1, new SimpleRecord("e".getBytes)) + ) + + val leaderState = PartitionState(leaderLog, leaderEpoch = leaderEpoch, highWatermark = 0L) + fetcher.mockLeader.setLeaderState(partition, leaderState) + fetcher.mockLeader.setReplicaPartitionStateCallback(fetcher.replicaPartitionState) + + assertEquals(Option(Fetching), fetcher.fetchState(partition).map(_.state)) + assertEquals(0, replicaState.logStartOffset) + assertEquals(List(), replicaState.log.toList) + + TestUtils.waitUntilTrue(() => { + fetcher.doWork() + fetcher.replicaPartitionState(partition).log == fetcher.mockLeader.leaderPartitionState(partition).log.dropRight(1) + }, "Failed to reconcile leader and follower logs up to the leader epoch") + + assertEquals(leaderState.logStartOffset, replicaState.logStartOffset) + assertEquals(leaderState.logEndOffset - 1, replicaState.logEndOffset) + assertEquals(leaderState.highWatermark, replicaState.highWatermark) + } + @Test def testCorruptMessage(): Unit = { val partition = new TopicPartition("topic", 0) @@ -897,11 +946,16 @@ class AbstractFetcherThreadTest { val mockLeaderEndpoint = new MockLeaderEndPoint(version = version) val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint) val fetcherForAppend = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine, failedPartitions = failedPartitions) { - override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = { if (topicPartition == partition1) { throw new KafkaException() } else { - super.processPartitionData(topicPartition, fetchOffset, partitionData) + super.processPartitionData(topicPartition, fetchOffset, partitionLeaderEpoch, partitionData) } } } @@ -1003,9 +1057,14 @@ class AbstractFetcherThreadTest { val mockLeaderEndpoint = new MockLeaderEndPoint(version = version) val mockTierStateMachine = new MockTierStateMachine(mockLeaderEndpoint) val fetcher = new MockFetcherThread(mockLeaderEndpoint, mockTierStateMachine) { - override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = { processPartitionDataCalls += 1 - super.processPartitionData(topicPartition, fetchOffset, partitionData) + super.processPartitionData(topicPartition, fetchOffset, partitionLeaderEpoch, partitionData) } override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = { diff --git a/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala b/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala index 5d50de04095..ff1e9196568 100644 --- a/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala +++ b/core/src/test/scala/unit/kafka/server/MockFetcherThread.scala @@ -66,9 +66,12 @@ class MockFetcherThread(val mockLeader: MockLeaderEndPoint, partitions } - 
override def processPartitionData(topicPartition: TopicPartition, - fetchOffset: Long, - partitionData: FetchData): Option[LogAppendInfo] = { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + leaderEpochForReplica: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = { val state = replicaPartitionState(topicPartition) if (leader.isTruncationOnFetchSupported && FetchResponse.isDivergingEpoch(partitionData)) { @@ -87,17 +90,24 @@ class MockFetcherThread(val mockLeader: MockLeaderEndPoint, var shallowOffsetOfMaxTimestamp = -1L var lastOffset = state.logEndOffset var lastEpoch: OptionalInt = OptionalInt.empty() + var skipRemainingBatches = false for (batch <- batches) { batch.ensureValid() - if (batch.maxTimestamp > maxTimestamp) { - maxTimestamp = batch.maxTimestamp - shallowOffsetOfMaxTimestamp = batch.baseOffset + + skipRemainingBatches = skipRemainingBatches || hasHigherPartitionLeaderEpoch(batch, leaderEpochForReplica); + if (skipRemainingBatches) { + info(s"Skipping batch $batch because leader epoch is $leaderEpochForReplica") + } else { + if (batch.maxTimestamp > maxTimestamp) { + maxTimestamp = batch.maxTimestamp + shallowOffsetOfMaxTimestamp = batch.baseOffset + } + state.log.append(batch) + state.logEndOffset = batch.nextOffset + lastOffset = batch.lastOffset + lastEpoch = OptionalInt.of(batch.partitionLeaderEpoch) } - state.log.append(batch) - state.logEndOffset = batch.nextOffset - lastOffset = batch.lastOffset - lastEpoch = OptionalInt.of(batch.partitionLeaderEpoch) } state.logStartOffset = partitionData.logStartOffset @@ -115,6 +125,11 @@ class MockFetcherThread(val mockLeader: MockLeaderEndPoint, batches.headOption.map(_.lastOffset).getOrElse(-1))) } + private def hasHigherPartitionLeaderEpoch(batch: RecordBatch, leaderEpoch: Int): Boolean = { + batch.partitionLeaderEpoch() != RecordBatch.NO_PARTITION_LEADER_EPOCH && + batch.partitionLeaderEpoch() > leaderEpoch + } + override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = { val state = replicaPartitionState(topicPartition) state.log = state.log.takeWhile { batch => diff --git a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala index 6526d6628c3..b0ee5a2d148 100644 --- a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala +++ b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala @@ -281,9 +281,22 @@ class ReplicaFetcherThreadTest { val fetchSessionHandler = new FetchSessionHandler(logContext, brokerEndPoint.id) val leader = new RemoteLeaderEndPoint(logContext.logPrefix, mockNetwork, fetchSessionHandler, config, replicaManager, quota, () => MetadataVersion.MINIMUM_VERSION, () => 1) - val thread = new ReplicaFetcherThread("bob", leader, config, failedPartitions, - replicaManager, quota, logContext.logPrefix, () => MetadataVersion.MINIMUM_VERSION) { - override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = None + val thread = new ReplicaFetcherThread( + "bob", + leader, + config, + failedPartitions, + replicaManager, + quota, + logContext.logPrefix, + () => MetadataVersion.MINIMUM_VERSION + ) { + override def processPartitionData( + topicPartition: TopicPartition, + fetchOffset: Long, + partitionLeaderEpoch: Int, + partitionData: FetchData + ): Option[LogAppendInfo] = None } thread.addPartitions(Map(t1p0 -> 
initialFetchState(Some(topicId1), initialLEO), t1p1 -> initialFetchState(Some(topicId1), initialLEO))) val partitions = Set(t1p0, t1p1) @@ -379,7 +392,7 @@ class ReplicaFetcherThreadTest { when(replicaManager.getPartitionOrException(t1p0)).thenReturn(partition) when(partition.localLogOrException).thenReturn(log) - when(partition.appendRecordsToFollowerOrFutureReplica(any(), any())).thenReturn(None) + when(partition.appendRecordsToFollowerOrFutureReplica(any(), any(), any())).thenReturn(None) val logContext = new LogContext(s"[ReplicaFetcher replicaId=${config.brokerId}, leaderId=${brokerEndPoint.id}, fetcherId=0] ") @@ -460,7 +473,7 @@ class ReplicaFetcherThreadTest { when(replicaManager.brokerTopicStats).thenReturn(mock(classOf[BrokerTopicStats])) when(partition.localLogOrException).thenReturn(log) - when(partition.appendRecordsToFollowerOrFutureReplica(any(), any())).thenReturn(Some(new LogAppendInfo( + when(partition.appendRecordsToFollowerOrFutureReplica(any(), any(), any())).thenReturn(Some(new LogAppendInfo( -1, 0, OptionalInt.empty, @@ -679,7 +692,7 @@ class ReplicaFetcherThreadTest { val partition: Partition = mock(classOf[Partition]) when(partition.localLogOrException).thenReturn(log) - when(partition.appendRecordsToFollowerOrFutureReplica(any[MemoryRecords], any[Boolean])).thenReturn(appendInfo) + when(partition.appendRecordsToFollowerOrFutureReplica(any[MemoryRecords], any[Boolean], any[Int])).thenReturn(appendInfo) // Capture the argument at the time of invocation. val completeDelayedFetchRequestsArgument = mutable.Buffer.empty[TopicPartition] @@ -710,8 +723,8 @@ class ReplicaFetcherThreadTest { .setRecords(records) .setHighWatermark(highWatermarkReceivedFromLeader) - thread.processPartitionData(tp0, 0, partitionData.setPartitionIndex(0)) - thread.processPartitionData(tp1, 0, partitionData.setPartitionIndex(1)) + thread.processPartitionData(tp0, 0, Int.MaxValue, partitionData.setPartitionIndex(0)) + thread.processPartitionData(tp1, 0, Int.MaxValue, partitionData.setPartitionIndex(1)) verify(replicaManager, times(0)).completeDelayedFetchRequests(any[Seq[TopicPartition]]) thread.doWork() @@ -761,7 +774,7 @@ class ReplicaFetcherThreadTest { when(partition.localLogOrException).thenReturn(log) when(partition.isReassigning).thenReturn(isReassigning) when(partition.isAddingLocalReplica).thenReturn(isReassigning) - when(partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false)).thenReturn(None) + when(partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false, Int.MaxValue)).thenReturn(None) val replicaManager: ReplicaManager = mock(classOf[ReplicaManager]) when(replicaManager.getPartitionOrException(any[TopicPartition])).thenReturn(partition) @@ -785,7 +798,7 @@ class ReplicaFetcherThreadTest { .setLastStableOffset(0) .setLogStartOffset(0) .setRecords(records) - thread.processPartitionData(t1p0, 0, partitionData) + thread.processPartitionData(t1p0, 0, Int.MaxValue, partitionData) if (isReassigning) assertEquals(records.sizeInBytes(), brokerTopicStats.allTopicsStats.reassignmentBytesInPerSec.get.count()) diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala index 6a27776babc..59d9b4b1a63 100644 --- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala @@ -5253,9 +5253,12 @@ class ReplicaManagerTest { replicaManager.getPartition(topicPartition) match { case 
HostedPartition.Online(partition) => partition.appendRecordsToFollowerOrFutureReplica( - records = MemoryRecords.withRecords(Compression.NONE, 0, - new SimpleRecord("first message".getBytes)), - isFuture = false + records = MemoryRecords.withRecords( + Compression.NONE, 0, + new SimpleRecord("first message".getBytes) + ), + isFuture = false, + partitionLeaderEpoch = 0 ) case _ => diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java index d6bec0c8016..d9091bd5b57 100644 --- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java @@ -335,8 +335,12 @@ public class ReplicaFetcherThreadBenchmark { } @Override - public Option processPartitionData(TopicPartition topicPartition, long fetchOffset, - FetchResponseData.PartitionData partitionData) { + public Option processPartitionData( + TopicPartition topicPartition, + long fetchOffset, + int partitionLeaderEpoch, + FetchResponseData.PartitionData partitionData + ) { return Option.empty(); } } diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java index feb8c985904..a345f3907b8 100644 --- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java @@ -134,7 +134,7 @@ public class PartitionMakeFollowerBenchmark { int initialOffSet = 0; while (true) { MemoryRecords memoryRecords = MemoryRecords.withRecords(initialOffSet, Compression.NONE, 0, simpleRecords); - partition.appendRecordsToFollowerOrFutureReplica(memoryRecords, false); + partition.appendRecordsToFollowerOrFutureReplica(memoryRecords, false, Integer.MAX_VALUE); initialOffSet = initialOffSet + 2; } }); diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java index 538eec64e91..34b5770cf70 100644 --- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java +++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java @@ -16,6 +16,7 @@ */ package org.apache.kafka.raft; +import org.apache.kafka.common.InvalidRecordException; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.Node; import org.apache.kafka.common.TopicPartition; @@ -23,6 +24,7 @@ import org.apache.kafka.common.Uuid; import org.apache.kafka.common.compress.Compression; import org.apache.kafka.common.config.ConfigException; import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.CorruptRecordException; import org.apache.kafka.common.errors.NotLeaderOrFollowerException; import org.apache.kafka.common.feature.SupportedVersionRange; import org.apache.kafka.common.memory.MemoryPool; @@ -50,6 +52,7 @@ import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.DefaultRecordBatch; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.Records; import 
org.apache.kafka.common.record.UnalignedMemoryRecords; @@ -93,8 +96,10 @@ import org.apache.kafka.snapshot.SnapshotWriter; import org.slf4j.Logger; import java.net.InetSocketAddress; +import java.nio.ByteBuffer; import java.util.Collection; import java.util.Collections; +import java.util.HexFormat; import java.util.IdentityHashMap; import java.util.Iterator; import java.util.List; @@ -1785,10 +1790,7 @@ public final class KafkaRaftClient implements RaftClient { } } } else { - Records records = FetchResponse.recordsOrFail(partitionResponse); - if (records.sizeInBytes() > 0) { - appendAsFollower(records); - } + appendAsFollower(FetchResponse.recordsOrFail(partitionResponse)); OptionalLong highWatermark = partitionResponse.highWatermark() < 0 ? OptionalLong.empty() : OptionalLong.of(partitionResponse.highWatermark()); @@ -1802,10 +1804,31 @@ public final class KafkaRaftClient implements RaftClient { } } - private void appendAsFollower( - Records records - ) { - LogAppendInfo info = log.appendAsFollower(records); + private static String convertToHexadecimal(Records records) { + ByteBuffer buffer = ((MemoryRecords) records).buffer(); + byte[] bytes = new byte[Math.min(buffer.remaining(), DefaultRecordBatch.RECORD_BATCH_OVERHEAD)]; + buffer.get(bytes); + + return HexFormat.of().formatHex(bytes); + } + + private void appendAsFollower(Records records) { + if (records.sizeInBytes() == 0) { + // Nothing to do if there are no bytes in the response + return; + } + + try { + var info = log.appendAsFollower(records, quorum.epoch()); + kafkaRaftMetrics.updateFetchedRecords(info.lastOffset - info.firstOffset + 1); + } catch (CorruptRecordException | InvalidRecordException e) { + logger.info( + "Failed to append the records with the batch header '{}' to the log", + convertToHexadecimal(records), + e + ); + } + if (quorum.isVoter() || followersAlwaysFlush) { // the leader only requires that voters have flushed their log before sending a Fetch // request. Because of reconfiguration some observers (that are getting added to the @@ -1817,14 +1840,11 @@ public final class KafkaRaftClient implements RaftClient { partitionState.updateState(); OffsetAndEpoch endOffset = endOffset(); - kafkaRaftMetrics.updateFetchedRecords(info.lastOffset - info.firstOffset + 1); kafkaRaftMetrics.updateLogEnd(endOffset); logger.trace("Follower end offset updated to {} after append", endOffset); } - private LogAppendInfo appendAsLeader( - Records records - ) { + private LogAppendInfo appendAsLeader(Records records) { LogAppendInfo info = log.appendAsLeader(records, quorum.epoch()); partitionState.updateState(); @@ -3475,6 +3495,10 @@ public final class KafkaRaftClient implements RaftClient { () -> new NotLeaderException("Append failed because the replica is not the current leader") ); + if (records.isEmpty()) { + throw new IllegalArgumentException("Append failed because there are no records"); + } + BatchAccumulator accumulator = leaderState.accumulator(); boolean isFirstAppend = accumulator.isEmpty(); final long offset = accumulator.append(epoch, records, true); diff --git a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java index a22f7fd73cd..8f5ba31a45d 100644 --- a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java +++ b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java @@ -31,6 +31,8 @@ public interface ReplicatedLog extends AutoCloseable { * be written atomically in a single batch or the call will fail and raise an * exception. 
* + * @param records record batches to append + * @param epoch the epoch of the replica + * @return the metadata information of the appended batch + * @throws IllegalArgumentException if the record set is empty + * @throws RuntimeException if the batch base offset doesn't match the log end offset @@ -42,11 +44,16 @@ * difference from appendAsLeader is that we do not need to assign the epoch * or do additional validation. * + * The log will append record batches up to and including batches that have a partition + * leader epoch less than or equal to the passed epoch. + * + * @param records record batches to append + * @param epoch the epoch of the replica + * @return the metadata information of the appended batch * @throws IllegalArgumentException if the record set is empty * @throws RuntimeException if the batch base offset doesn't match the log end offset */ - LogAppendInfo appendAsFollower(Records records); + LogAppendInfo appendAsFollower(Records records, int epoch); /** * Read a set of records within a range of offsets. diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientFetchTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientFetchTest.java new file mode 100644 index 00000000000..ade509d8051 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientFetchTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.compress.Compression; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.ArbitraryMemoryRecords; +import org.apache.kafka.common.record.InvalidMemoryRecordsProvider; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.server.common.KRaftVersion; + +import net.jqwik.api.AfterFailureMode; +import net.jqwik.api.ForAll; +import net.jqwik.api.Property; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public final class KafkaRaftClientFetchTest { + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + void testRandomRecords( + @ForAll(supplier = ArbitraryMemoryRecords.class) MemoryRecords memoryRecords + ) throws Exception { + testFetchResponseWithInvalidRecord(memoryRecords, Integer.MAX_VALUE); + } + + @ParameterizedTest + @ArgumentsSource(InvalidMemoryRecordsProvider.class) + void testInvalidMemoryRecords(MemoryRecords records, Optional> expectedException) throws Exception { + // CorruptRecordException are handled by the KafkaRaftClient so ignore the expected exception + testFetchResponseWithInvalidRecord(records, Integer.MAX_VALUE); + } + + private static void testFetchResponseWithInvalidRecord(MemoryRecords records, int epoch) throws Exception { + int localId = KafkaRaftClientTest.randomReplicaId(); + ReplicaKey local = KafkaRaftClientTest.replicaKey(localId, true); + ReplicaKey electedLeader = KafkaRaftClientTest.replicaKey(localId + 1, true); + + RaftClientTestContext context = new RaftClientTestContext.Builder( + local.id(), + local.directoryId().get() + ) + .withStartingVoters( + VoterSetTest.voterSet(Stream.of(local, electedLeader)), KRaftVersion.KRAFT_VERSION_1 + ) + .withElectedLeader(epoch, electedLeader.id()) + .withRaftProtocol(RaftClientTestContext.RaftProtocol.KIP_996_PROTOCOL) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + long oldLogEndOffset = context.log.endOffset().offset(); + + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, electedLeader.id(), records, 0L, Errors.NONE) + ); + + context.client.poll(); + + assertEquals(oldLogEndOffset, context.log.endOffset().offset()); + } + + @Test + void testReplicationOfHigherPartitionLeaderEpoch() throws Exception { + int epoch = 2; + int localId = KafkaRaftClientTest.randomReplicaId(); + ReplicaKey local = KafkaRaftClientTest.replicaKey(localId, true); + ReplicaKey electedLeader = KafkaRaftClientTest.replicaKey(localId + 1, true); + + RaftClientTestContext context = new RaftClientTestContext.Builder( + local.id(), + local.directoryId().get() + ) + .withStartingVoters( + VoterSetTest.voterSet(Stream.of(local, electedLeader)), KRaftVersion.KRAFT_VERSION_1 + ) + .withElectedLeader(epoch, electedLeader.id()) + .withRaftProtocol(RaftClientTestContext.RaftProtocol.KIP_996_PROTOCOL) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 
0L, 0); + + long oldLogEndOffset = context.log.endOffset().offset(); + int numberOfRecords = 10; + MemoryRecords batchWithValidEpoch = MemoryRecords.withRecords( + oldLogEndOffset, + Compression.NONE, + epoch, + IntStream + .range(0, numberOfRecords) + .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes())) + .toArray(SimpleRecord[]::new) + ); + + MemoryRecords batchWithInvalidEpoch = MemoryRecords.withRecords( + oldLogEndOffset + numberOfRecords, + Compression.NONE, + epoch + 1, + IntStream + .range(0, numberOfRecords) + .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes())) + .toArray(SimpleRecord[]::new) + ); + + var buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + batchWithInvalidEpoch.sizeInBytes()); + buffer.put(batchWithValidEpoch.buffer()); + buffer.put(batchWithInvalidEpoch.buffer()); + buffer.flip(); + + MemoryRecords records = MemoryRecords.readableRecords(buffer); + + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, electedLeader.id(), records, 0L, Errors.NONE) + ); + + context.client.poll(); + + // Check that only the first batch was appended because the second batch has a greater epoch + assertEquals(oldLogEndOffset + numberOfRecords, context.log.endOffset().offset()); + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLog.java b/raft/src/test/java/org/apache/kafka/raft/MockLog.java index a7a8e89a88c..9fb4724cc0c 100644 --- a/raft/src/test/java/org/apache/kafka/raft/MockLog.java +++ b/raft/src/test/java/org/apache/kafka/raft/MockLog.java @@ -19,6 +19,7 @@ package org.apache.kafka.raft; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.compress.Compression; +import org.apache.kafka.common.errors.CorruptRecordException; import org.apache.kafka.common.errors.OffsetOutOfRangeException; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.MemoryRecordsBuilder; @@ -279,7 +280,7 @@ public class MockLog implements ReplicatedLog { @Override public LogAppendInfo appendAsLeader(Records records, int epoch) { - return append(records, OptionalInt.of(epoch)); + return append(records, epoch, true); } private long appendBatch(LogBatch batch) { @@ -292,16 +293,18 @@ public class MockLog implements ReplicatedLog { } @Override - public LogAppendInfo appendAsFollower(Records records) { - return append(records, OptionalInt.empty()); + public LogAppendInfo appendAsFollower(Records records, int epoch) { + return append(records, epoch, false); } - private LogAppendInfo append(Records records, OptionalInt epoch) { - if (records.sizeInBytes() == 0) + private LogAppendInfo append(Records records, int epoch, boolean isLeader) { + if (records.sizeInBytes() == 0) { throw new IllegalArgumentException("Attempt to append an empty record set"); + } long baseOffset = endOffset().offset(); long lastOffset = baseOffset; + boolean hasBatches = false; for (RecordBatch batch : records.batches()) { if (batch.baseOffset() != endOffset().offset()) { /* KafkaMetadataLog throws an kafka.common.UnexpectedAppendOffsetException this is the @@ -314,26 +317,47 @@ public class MockLog implements ReplicatedLog { endOffset().offset() ) ); + } else if (isLeader && epoch != batch.partitionLeaderEpoch()) { + // the partition leader epoch is set and does not match the one set in the batch + throw new RuntimeException( + String.format( + "Epoch %s doesn't match batch leader epoch %s", + epoch, + 
batch.partitionLeaderEpoch() + ) + ); + } else if (!isLeader && batch.partitionLeaderEpoch() > epoch) { + /* To avoid inconsistent log replication, follower should only append record + * batches with an epoch less than or equal to the leader epoch. There are more + * details on this issue and scenario in KAFKA-18723. + */ + break; } + hasBatches = true; LogBatch logBatch = new LogBatch( - epoch.orElseGet(batch::partitionLeaderEpoch), + batch.partitionLeaderEpoch(), batch.isControlBatch(), buildEntries(batch, Record::offset) ); if (logger.isDebugEnabled()) { - String nodeState = "Follower"; - if (epoch.isPresent()) { - nodeState = "Leader"; - } - logger.debug("{} appending to the log {}", nodeState, logBatch); + logger.debug( + "{} appending to the log {}", + isLeader ? "Leader" : "Follower", + logBatch + ); } appendBatch(logBatch); lastOffset = logBatch.last().offset; } + if (!hasBatches) { + // This emulates the default handling when the records don't have enough bytes for a batch + throw new CorruptRecordException("Append failed unexpectedly"); + } + return new LogAppendInfo(baseOffset, lastOffset); } diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java index 8306e103258..eca0fe5d3de 100644 --- a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java @@ -19,9 +19,12 @@ package org.apache.kafka.raft; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.compress.Compression; +import org.apache.kafka.common.errors.CorruptRecordException; import org.apache.kafka.common.errors.OffsetOutOfRangeException; import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.record.ArbitraryMemoryRecords; import org.apache.kafka.common.record.ControlRecordUtils; +import org.apache.kafka.common.record.InvalidMemoryRecordsProvider; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.Record; import org.apache.kafka.common.record.RecordBatch; @@ -32,9 +35,16 @@ import org.apache.kafka.common.utils.Utils; import org.apache.kafka.snapshot.RawSnapshotReader; import org.apache.kafka.snapshot.RawSnapshotWriter; +import net.jqwik.api.AfterFailureMode; +import net.jqwik.api.ForAll; +import net.jqwik.api.Property; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ArgumentsSource; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -44,6 +54,7 @@ import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.stream.IntStream; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -169,14 +180,17 @@ public class MockLogTest { assertThrows( RuntimeException.class, () -> log.appendAsLeader( - MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo), - currentEpoch) + MemoryRecords.withRecords(initialOffset, Compression.NONE, currentEpoch, recordFoo), + currentEpoch + ) ); assertThrows( RuntimeException.class, () -> log.appendAsFollower( - MemoryRecords.withRecords(initialOffset, Compression.NONE,
+                currentEpoch
+            )
         );
     }
 
@@ -187,7 +201,13 @@ public class MockLogTest {
         LeaderChangeMessage messageData = new LeaderChangeMessage().setLeaderId(0);
         ByteBuffer buffer = ByteBuffer.allocate(256);
         log.appendAsLeader(
-            MemoryRecords.withLeaderChangeMessage(initialOffset, 0L, 2, buffer, messageData),
+            MemoryRecords.withLeaderChangeMessage(
+                initialOffset,
+                0L,
+                currentEpoch,
+                buffer,
+                messageData
+            ),
             currentEpoch
         );
 
@@ -221,7 +241,10 @@ public class MockLogTest {
         }
         log.truncateToLatestSnapshot();
 
-        log.appendAsFollower(MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo));
+        log.appendAsFollower(
+            MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo),
+            epoch
+        );
 
         assertEquals(initialOffset, log.startOffset());
         assertEquals(initialOffset + 1, log.endOffset().offset());
@@ -368,10 +391,82 @@ public class MockLogTest {
 
     @Test
     public void testEmptyAppendNotAllowed() {
-        assertThrows(IllegalArgumentException.class, () -> log.appendAsFollower(MemoryRecords.EMPTY));
+        assertThrows(IllegalArgumentException.class, () -> log.appendAsFollower(MemoryRecords.EMPTY, 1));
         assertThrows(IllegalArgumentException.class, () -> log.appendAsLeader(MemoryRecords.EMPTY, 1));
     }
 
+    @ParameterizedTest
+    @ArgumentsSource(InvalidMemoryRecordsProvider.class)
+    void testInvalidMemoryRecords(MemoryRecords records, Optional<Class<Exception>> expectedException) {
+        long previousEndOffset = log.endOffset().offset();
+
+        Executable action = () -> log.appendAsFollower(records, Integer.MAX_VALUE);
+        if (expectedException.isPresent()) {
+            assertThrows(expectedException.get(), action);
+        } else {
+            assertThrows(CorruptRecordException.class, action);
+        }
+
+        assertEquals(previousEndOffset, log.endOffset().offset());
+    }
+
+    @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY)
+    void testRandomRecords(
+        @ForAll(supplier = ArbitraryMemoryRecords.class) MemoryRecords records
+    ) {
+        try (MockLog log = new MockLog(topicPartition, topicId, new LogContext())) {
+            long previousEndOffset = log.endOffset().offset();
+
+            assertThrows(
+                CorruptRecordException.class,
+                () -> log.appendAsFollower(records, Integer.MAX_VALUE)
+            );
+
+            assertEquals(previousEndOffset, log.endOffset().offset());
+        }
+    }
+
+    @Test
+    void testInvalidLeaderEpoch() {
+        var previousEndOffset = log.endOffset().offset();
+        var epoch = log.lastFetchedEpoch() + 1;
+        var numberOfRecords = 10;
+
+        MemoryRecords batchWithValidEpoch = MemoryRecords.withRecords(
+            previousEndOffset,
+            Compression.NONE,
+            epoch,
+            IntStream
+                .range(0, numberOfRecords)
+                .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes()))
+                .toArray(SimpleRecord[]::new)
+        );
+
+        MemoryRecords batchWithInvalidEpoch = MemoryRecords.withRecords(
+            previousEndOffset + numberOfRecords,
+            Compression.NONE,
+            epoch + 1,
+            IntStream
+                .range(0, numberOfRecords)
+                .mapToObj(number -> new SimpleRecord(Integer.toString(number).getBytes()))
+                .toArray(SimpleRecord[]::new)
+        );
+
+        var buffer = ByteBuffer.allocate(batchWithValidEpoch.sizeInBytes() + batchWithInvalidEpoch.sizeInBytes());
+        buffer.put(batchWithValidEpoch.buffer());
+        buffer.put(batchWithInvalidEpoch.buffer());
+        buffer.flip();
+
+        var records = MemoryRecords.readableRecords(buffer);
+
+        log.appendAsFollower(records, epoch);
+
+        // Check that only the first batch was appended
+        assertEquals(previousEndOffset + numberOfRecords, log.endOffset().offset());
+        // Check that the last fetched epoch matches the first batch
+        assertEquals(epoch, log.lastFetchedEpoch());
+    }
+
     @Test
     public void testReadOutOfRangeOffset() {
         final long initialOffset = 5L;
@@ -383,12 +478,19 @@ public class MockLogTest {
         }
         log.truncateToLatestSnapshot();
 
-        log.appendAsFollower(MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo));
+        log.appendAsFollower(
+            MemoryRecords.withRecords(initialOffset, Compression.NONE, epoch, recordFoo),
+            epoch
+        );
 
-        assertThrows(OffsetOutOfRangeException.class, () -> log.read(log.startOffset() - 1,
-            Isolation.UNCOMMITTED));
-        assertThrows(OffsetOutOfRangeException.class, () -> log.read(log.endOffset().offset() + 1,
-            Isolation.UNCOMMITTED));
+        assertThrows(
+            OffsetOutOfRangeException.class,
+            () -> log.read(log.startOffset() - 1, Isolation.UNCOMMITTED)
+        );
+        assertThrows(
+            OffsetOutOfRangeException.class,
+            () -> log.read(log.endOffset().offset() + 1, Isolation.UNCOMMITTED)
+        );
     }
 
     @Test
@@ -958,6 +1060,7 @@ public class MockLogTest {
             MemoryRecords.withRecords(
                 log.endOffset().offset(),
                 Compression.NONE,
+                epoch,
                 records.toArray(new SimpleRecord[records.size()])
             ),
             epoch