diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java b/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java index d902f049fae..a91f534a8dd 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java @@ -112,4 +112,25 @@ public class ByteBufferAccessor implements Readable, Writable { public void writeByteBuffer(ByteBuffer src) { buf.put(src); } + + @Override + public void writeVarint(int i) { + ByteUtils.writeVarint(i, buf); + } + + @Override + public void writeVarlong(long i) { + ByteUtils.writeVarlong(i, buf); + } + + @Override + public int readVarint() { + return ByteUtils.readVarint(buf); + } + + @Override + public long readVarlong() { + return ByteUtils.readVarlong(buf); + } + } diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/DataOutputStreamWritable.java b/clients/src/main/java/org/apache/kafka/common/protocol/DataOutputStreamWritable.java new file mode 100644 index 00000000000..f484016a530 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/DataOutputStreamWritable.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.common.protocol;
+
+import org.apache.kafka.common.utils.ByteUtils;
+import org.apache.kafka.common.utils.Utils;
+
+import java.io.Closeable;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+public class DataOutputStreamWritable implements Writable, Closeable {
+    protected final DataOutputStream out;
+
+    public DataOutputStreamWritable(DataOutputStream out) {
+        this.out = out;
+    }
+
+    @Override
+    public void writeByte(byte val) {
+        try {
+            out.writeByte(val);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void writeShort(short val) {
+        try {
+            out.writeShort(val);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void writeInt(int val) {
+        try {
+            out.writeInt(val);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void writeLong(long val) {
+        try {
+            out.writeLong(val);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void writeDouble(double val) {
+        try {
+            out.writeDouble(val);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void writeByteArray(byte[] arr) {
+        try {
+            out.write(arr);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void writeUnsignedVarint(int i) {
+        try {
+            ByteUtils.writeUnsignedVarint(i, out);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void writeByteBuffer(ByteBuffer buf) {
+        try {
+            if (buf.hasArray()) {
+                out.write(buf.array(), buf.arrayOffset() + buf.position(), buf.remaining());
+            } else {
+                byte[] bytes = Utils.toArray(buf);
+                out.write(bytes);
+            }
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void writeVarint(int i) {
+        try {
+            ByteUtils.writeVarint(i, out);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void writeVarlong(long i) {
+        try {
+            ByteUtils.writeVarlong(i, out);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public void flush() {
+        try {
+            out.flush();
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void close() {
+        try {
+            out.close();
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java b/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java
index 9ceb4c1a025..653f88c7673 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java
@@ -35,6 +35,8 @@ public interface Readable {
     void readArray(byte[] arr);
     int readUnsignedVarint();
     ByteBuffer readByteBuffer(int length);
+    int readVarint();
+    long readVarlong();
 
     default String readString(int length) {
         byte[] arr = new byte[length];
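
Both `Readable` and `Writable` gain signed varint support by delegating to `ByteUtils`, which uses the zig-zag varint encoding already used by the record format (small negative values stay short on the wire). For intuition only, a minimal standalone sketch of that scheme — this is not the `ByteUtils` implementation itself:

```java
import java.nio.ByteBuffer;

public class ZigZagVarintSketch {
    // Zig-zag maps signed to unsigned: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
    static void writeVarint(int value, ByteBuffer buf) {
        int v = (value << 1) ^ (value >> 31);     // zig-zag encode
        while ((v & 0xffffff80) != 0) {
            buf.put((byte) ((v & 0x7f) | 0x80));  // low 7 bits, continuation bit set
            v >>>= 7;
        }
        buf.put((byte) v);                        // final byte, continuation bit clear
    }

    static int readVarint(ByteBuffer buf) {
        int value = 0;
        int shift = 0;
        byte b;
        do {
            b = buf.get();
            value |= (b & 0x7f) << shift;         // accumulate 7 bits at a time
            shift += 7;
        } while ((b & 0x80) != 0);
        return (value >>> 1) ^ -(value & 1);      // zig-zag decode
    }
}
```

A varlong works the same way with 64-bit shifts and up to ten bytes instead of five.
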
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/RecordsReadable.java b/clients/src/main/java/org/apache/kafka/common/protocol/RecordsReadable.java
index 5967731a208..6458fe18307 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/RecordsReadable.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/RecordsReadable.java
@@ -80,6 +80,16 @@ public class RecordsReadable implements Readable {
         return res;
     }
 
+    @Override
+    public int readVarint() {
+        return ByteUtils.readVarint(buf);
+    }
+
+    @Override
+    public long readVarlong() {
+        return ByteUtils.readVarlong(buf);
+    }
+
     public BaseRecords readRecords(int length) {
         if (length < 0) {
             // no records
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/RecordsWritable.java b/clients/src/main/java/org/apache/kafka/common/protocol/RecordsWritable.java
index 61f3ee1a452..9d49129f00e 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/RecordsWritable.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/RecordsWritable.java
@@ -105,6 +105,16 @@ public class RecordsWritable implements Writable {
         buffer.put(src);
     }
 
+    @Override
+    public void writeVarint(int i) {
+        ByteUtils.writeVarint(i, buffer);
+    }
+
+    @Override
+    public void writeVarlong(long i) {
+        ByteUtils.writeVarlong(i, buffer);
+    }
+
     public void writeRecords(BaseRecords records) {
         flush();
         sendConsumer.accept(records.toSend(dest));
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java b/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java
index 230898dcb2a..fd3e2b85c8b 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java
@@ -30,6 +30,8 @@ public interface Writable {
     void writeByteArray(byte[] arr);
    void writeUnsignedVarint(int i);
     void writeByteBuffer(ByteBuffer buf);
+    void writeVarint(int i);
+    void writeVarlong(long i);
 
     default void writeUUID(UUID uuid) {
         writeLong(uuid.getMostSignificantBits());
diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
index 665ad54a6ce..d85f1000bc1 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
@@ -576,17 +576,16 @@ public class DefaultRecord implements Record {
                                          ByteBuffer key,
                                          ByteBuffer value,
                                          Header[] headers) {
-
        int keySize = key == null ? -1 : key.remaining();
        int valueSize = value == null ? 
-1 : value.remaining(); return sizeOfBodyInBytes(offsetDelta, timestampDelta, keySize, valueSize, headers); } - private static int sizeOfBodyInBytes(int offsetDelta, - long timestampDelta, - int keySize, - int valueSize, - Header[] headers) { + public static int sizeOfBodyInBytes(int offsetDelta, + long timestampDelta, + int keySize, + int valueSize, + Header[] headers) { int size = 1; // always one byte for attributes size += ByteUtils.sizeOfVarint(offsetDelta); size += ByteUtils.sizeOfVarlong(timestampDelta); diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java index b49f2fd9d0f..28f271dedf8 100644 --- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java +++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java @@ -440,22 +440,22 @@ public class DefaultRecordBatch extends AbstractRecordBatch implements MutableRe producerEpoch, baseSequence, isTransactional, isControlRecord, partitionLeaderEpoch, 0); } - static void writeHeader(ByteBuffer buffer, - long baseOffset, - int lastOffsetDelta, - int sizeInBytes, - byte magic, - CompressionType compressionType, - TimestampType timestampType, - long firstTimestamp, - long maxTimestamp, - long producerId, - short epoch, - int sequence, - boolean isTransactional, - boolean isControlBatch, - int partitionLeaderEpoch, - int numRecords) { + public static void writeHeader(ByteBuffer buffer, + long baseOffset, + int lastOffsetDelta, + int sizeInBytes, + byte magic, + CompressionType compressionType, + TimestampType timestampType, + long firstTimestamp, + long maxTimestamp, + long producerId, + short epoch, + int sequence, + boolean isTransactional, + boolean isControlBatch, + int partitionLeaderEpoch, + int numRecords) { if (magic < RecordBatch.CURRENT_MAGIC_VALUE) throw new IllegalArgumentException("Invalid magic value " + magic); if (firstTimestamp < 0 && firstTimestamp != NO_TIMESTAMP) diff --git a/core/src/main/scala/kafka/common/RecordValidationException.scala b/core/src/main/scala/kafka/common/RecordValidationException.scala index 29ff53bbbf9..baa7d725576 100644 --- a/core/src/main/scala/kafka/common/RecordValidationException.scala +++ b/core/src/main/scala/kafka/common/RecordValidationException.scala @@ -23,5 +23,6 @@ import org.apache.kafka.common.requests.ProduceResponse.RecordError import scala.collection.Seq class RecordValidationException(val invalidException: ApiException, - val recordErrors: Seq[RecordError]) extends RuntimeException { + val recordErrors: Seq[RecordError]) + extends RuntimeException(invalidException) { } diff --git a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala index cac21a48250..7faa2f3fe2b 100644 --- a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala +++ b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala @@ -185,7 +185,18 @@ class KafkaNetworkChannel(time: Time, new RaftResponse.Inbound(header.correlationId, data, response.destination.toInt) } - private def pollInboundResponses(timeoutMs: Long): util.List[RaftMessage] = { + private def pollInboundResponses(timeoutMs: Long, inboundMessages: util.List[RaftMessage]): Unit = { + val responses = client.poll(timeoutMs, time.milliseconds()) + for (response <- responses.asScala) { + inboundMessages.add(buildInboundRaftResponse(response)) + } + } + + private def drainInboundRequests(inboundMessages: util.List[RaftMessage]): Unit = 
{ + undelivered.drainTo(inboundMessages) + } + + private def pollInboundMessages(timeoutMs: Long): util.List[RaftMessage] = { val pollTimeoutMs = if (!undelivered.isEmpty) { 0L } else if (!pendingOutbound.isEmpty) { @@ -193,18 +204,15 @@ class KafkaNetworkChannel(time: Time, } else { timeoutMs } - val responses = client.poll(pollTimeoutMs, time.milliseconds()) val messages = new util.ArrayList[RaftMessage] - for (response <- responses.asScala) { - messages.add(buildInboundRaftResponse(response)) - } - undelivered.drainTo(messages) + pollInboundResponses(pollTimeoutMs, messages) + drainInboundRequests(messages) messages } override def receive(timeoutMs: Long): util.List[RaftMessage] = { sendOutboundRequests(time.milliseconds()) - pollInboundResponses(timeoutMs) + pollInboundMessages(timeoutMs) } override def wakeup(): Unit = { @@ -216,11 +224,9 @@ class KafkaNetworkChannel(time: Time, endpoints.put(id, node) } - def postInboundRequest(header: RequestHeader, - request: AbstractRequest, - onResponseReceived: ResponseHandler): Unit = { + def postInboundRequest(request: AbstractRequest, onResponseReceived: ResponseHandler): Unit = { val data = requestData(request) - val correlationId = header.correlationId + val correlationId = newCorrelationId() val req = new RaftRequest.Inbound(correlationId, data, time.milliseconds()) pendingInbound.put(correlationId, onResponseReceived) if (!undelivered.offer(req)) diff --git a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala index 194c22ffbec..675e66a6b78 100644 --- a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala +++ b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala @@ -17,21 +17,14 @@ package kafka.tools -import java.util.Collections - import kafka.network.RequestChannel import kafka.raft.KafkaNetworkChannel import kafka.server.ApiRequestHandler import kafka.utils.Logging -import org.apache.kafka.common.TopicPartition -import org.apache.kafka.common.feature.Features import org.apache.kafka.common.internals.FatalExitError -import org.apache.kafka.common.message.MetadataResponseData -import org.apache.kafka.common.message.MetadataResponseData.{MetadataResponsePartition, MetadataResponseTopic} import org.apache.kafka.common.protocol.{ApiKeys, Errors} -import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, ApiVersionsResponse, MetadataRequest, MetadataResponse, ProduceRequest, ProduceResponse} +import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse} import org.apache.kafka.common.utils.Time -import org.apache.kafka.raft.{AckMode, RaftClient} import scala.jdk.CollectionConverters._ @@ -42,8 +35,6 @@ class TestRaftRequestHandler( networkChannel: KafkaNetworkChannel, requestChannel: RequestChannel, time: Time, - client: RaftClient, - metadataPartition: TopicPartition ) extends ApiRequestHandler with Logging { override def handle(request: RequestChannel.Request): Unit = { @@ -56,73 +47,8 @@ class TestRaftRequestHandler( | ApiKeys.END_QUORUM_EPOCH | ApiKeys.FETCH => val requestBody = request.body[AbstractRequest] - networkChannel.postInboundRequest( - request.header, - requestBody, - response => sendResponse(request, Some(response))) - - case ApiKeys.API_VERSIONS => - sendResponse(request, Option(ApiVersionsResponse.apiVersionsResponse(0, 2, - Features.emptySupportedFeatures()))) - - case ApiKeys.METADATA => - val metadataRequest = request.body[MetadataRequest] - val topics = new 
MetadataResponseData.MetadataResponseTopicCollection - - if (!metadataRequest.data.topics.isEmpty) { - val leaderAndEpoch = client.currentLeaderAndEpoch() - - if (metadataRequest.data.topics.size != 1 - || !metadataRequest.data.topics.get(0).name().equals(metadataPartition.topic)) { - throw new IllegalArgumentException(s"Should only handle metadata request querying for " + - s"`${metadataPartition.topic}, but found ${metadataRequest.data.topics}") - } - - topics.add(new MetadataResponseTopic() - .setErrorCode(Errors.NONE.code) - .setName(metadataPartition.topic) - .setIsInternal(true) - .setPartitions(Collections.singletonList(new MetadataResponsePartition() - .setErrorCode(Errors.NONE.code) - .setPartitionIndex(metadataPartition.partition) - .setLeaderId(leaderAndEpoch.leaderId.orElse(-1))))) - } - - val brokers = new MetadataResponseData.MetadataResponseBrokerCollection - networkChannel.allConnections().foreach { connection => - brokers.add(new MetadataResponseData.MetadataResponseBroker() - .setNodeId(connection.id) - .setHost(connection.host) - .setPort(connection.port)) - } - - sendResponse(request, Option(new MetadataResponse( - new MetadataResponseData() - .setTopics(topics) - .setBrokers(brokers)))) - - case ApiKeys.PRODUCE => - val produceRequest = request.body[ProduceRequest] - val records = produceRequest.partitionRecordsOrFail().get(metadataPartition) - - val ackMode = produceRequest.acks match { - case 1 => AckMode.LEADER - case -1 => AckMode.QUORUM - case _ => throw new IllegalArgumentException(s"Unsupported ack mode ${produceRequest.acks} " + - s"in Produce request (the only supported modes are acks=1 and acks=-1)") - } - - client.append(records, ackMode, produceRequest.timeout) - .whenComplete { (_, exception) => - val error = if (exception == null) - Errors.NONE - else - Errors.forException(exception) - - sendResponse(request, Option(new ProduceResponse( - Collections.singletonMap(metadataPartition, - new ProduceResponse.PartitionResponse(error))))) - } + networkChannel.postInboundRequest(requestBody, response => + sendResponse(request, Some(response))) case _ => throw new IllegalArgumentException(s"Unsupported api key: ${request.header.apiKey}") } diff --git a/core/src/main/scala/kafka/tools/TestRaftServer.scala b/core/src/main/scala/kafka/tools/TestRaftServer.scala index 6c170d71d53..f8ae7af8e0c 100644 --- a/core/src/main/scala/kafka/tools/TestRaftServer.scala +++ b/core/src/main/scala/kafka/tools/TestRaftServer.scala @@ -19,37 +19,44 @@ package kafka.tools import java.io.File import java.nio.file.Files -import java.util.concurrent.CountDownLatch -import java.util.{Properties, Random} +import java.util +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.{Collections, OptionalInt, Random} -import joptsimple.OptionParser +import com.yammer.metrics.core.MetricName +import joptsimple.OptionException import kafka.log.{Log, LogConfig, LogManager} import kafka.network.SocketServer import kafka.raft.{KafkaFuturePurgatory, KafkaMetadataLog, KafkaNetworkChannel} import kafka.security.CredentialProvider import kafka.server.{BrokerTopicStats, KafkaConfig, KafkaRequestHandlerPool, KafkaServer, LogDirFailureChannel} import kafka.utils.timer.SystemTimer -import kafka.utils.{CommandLineUtils, CoreUtils, Exit, KafkaScheduler, Logging, ShutdownableThread} +import kafka.utils.{CommandDefaultOptions, CommandLineUtils, CoreUtils, Exit, KafkaScheduler, Logging, ShutdownableThread} import org.apache.kafka.clients.{ApiVersions, ClientDnsLookup, ManualMetadataUpdater, 
NetworkClient}
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.config.ConfigException
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.network.{ChannelBuilders, NetworkReceive, Selectable, Selector}
+import org.apache.kafka.common.protocol.Writable
 import org.apache.kafka.common.security.JaasContext
 import org.apache.kafka.common.security.scram.internals.ScramMechanism
 import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache
 import org.apache.kafka.common.utils.{LogContext, Time, Utils}
 import org.apache.kafka.raft.internals.LogOffset
-import org.apache.kafka.raft.{FileBasedStateStore, KafkaRaftClient, QuorumState, RaftConfig}
+import org.apache.kafka.raft.{FileBasedStateStore, KafkaRaftClient, LeaderAndEpoch, QuorumState, RaftClient, RaftConfig, RecordSerde}
 
 import scala.jdk.CollectionConverters._
 
 /**
- * This is an experimental Raft server which is intended for testing purposes only.
- * It can really only be used for performance testing using the producer performance
- * tool with a hard-coded `__cluster_metadata` topic.
+ * This is an experimental server which is intended for testing the performance
+ * of the Raft implementation. It uses a hard-coded `__cluster_metadata` topic.
  */
-class TestRaftServer(val config: KafkaConfig) extends Logging {
+class TestRaftServer(
+  val config: KafkaConfig,
+  val throughput: Int,
+  val recordSize: Int
+) extends Logging {
+  import kafka.tools.TestRaftServer._
+
   private val partition = new TopicPartition("__cluster_metadata", 0)
   private val time = Time.SYSTEM
@@ -61,9 +68,10 @@ class TestRaftServer(val config: KafkaConfig) extends Logging {
   var dataPlaneRequestHandlerPool: KafkaRequestHandlerPool = _
   var scheduler: KafkaScheduler = _
   var metrics: Metrics = _
-  var raftIoThread: RaftIoThread = _
+  var ioThread: RaftIoThread = _
   var networkChannel: KafkaNetworkChannel = _
   var metadataLog: KafkaMetadataLog = _
+  var workloadGenerator: RaftWorkloadGenerator = _
 
   def startup(): Unit = {
     val logContext = new LogContext(s"[Raft id=${config.brokerId}] ")
@@ -94,14 +102,20 @@ class TestRaftServer(
       logDir
     )
 
-    raftClient.initialize()
+    workloadGenerator = new RaftWorkloadGenerator(
+      raftClient,
+      time,
+      config.brokerId,
+      recordsPerSec = throughput,
+      recordSize = recordSize
+    )
+
+    raftClient.initialize(workloadGenerator)
 
     val requestHandler = new TestRaftRequestHandler(
       networkChannel,
       socketServer.dataPlaneRequestChannel,
-      time,
-      raftClient,
-      partition
+      time
     )
 
     dataPlaneRequestHandlerPool = new KafkaRequestHandlerPool(
@@ -115,14 +129,16 @@ class TestRaftServer(
     )
 
     socketServer.startProcessingRequests(Map.empty)
-    raftIoThread = new RaftIoThread(raftClient)
-    raftIoThread.start()
-
+    ioThread = new RaftIoThread(raftClient)
+    ioThread.start()
+    workloadGenerator.start()
   }
 
   def shutdown(): Unit = {
-    if (raftIoThread != null)
-      CoreUtils.swallow(raftIoThread.shutdown(), this)
+    if (workloadGenerator != null)
+      CoreUtils.swallow(workloadGenerator.shutdown(), this)
+    if (ioThread != null)
+      CoreUtils.swallow(ioThread.shutdown(), this)
     if (dataPlaneRequestHandlerPool != null)
       CoreUtils.swallow(dataPlaneRequestHandlerPool.shutdown(), this)
     if (socketServer != null)
@@ -186,7 +202,7 @@ class TestRaftServer(
                                metadataLog: KafkaMetadataLog,
                                networkChannel: KafkaNetworkChannel,
                                logContext: LogContext,
-                               logDir: File): KafkaRaftClient = {
+                               logDir: File): KafkaRaftClient[Array[Byte]] = {
     val quorumState = new QuorumState(
       config.brokerId,
       raftConfig.quorumVoterIds,
@@ -206,8 +222,11 @@ class TestRaftServer(
       config.brokerId,
       new SystemTimer("raft-append-purgatory-reaper"))
 
+    val serde = new ByteArraySerde
+
     new KafkaRaftClient(
       raftConfig,
+      serde,
       networkChannel,
       metadataLog,
       quorumState,
@@ -272,7 +291,79 @@ class TestRaftServer(
     )
   }
 
-  class RaftIoThread(client: KafkaRaftClient) extends ShutdownableThread("raft-io-thread") {
+  class RaftWorkloadGenerator(
+    client: KafkaRaftClient[Array[Byte]],
+    time: Time,
+    brokerId: Int,
+    recordsPerSec: Int,
+    recordSize: Int
+  ) extends ShutdownableThread(name = "raft-workload-generator") with RaftClient.Listener[Array[Byte]] {
+
+    private val stats = new WriteStats(time, printIntervalMs = 5000)
+    private val payload = new Array[Byte](recordSize)
+    private val pendingAppends = new util.ArrayDeque[PendingAppend]()
+
+    private var latestLeaderAndEpoch = new LeaderAndEpoch(OptionalInt.empty, 0)
+    private var isLeader = false
+    private var throttler: ThroughputThrottler = _
+    private var recordCount = 0
+
+    override def doWork(): Unit = {
+      if (latestLeaderAndEpoch != client.currentLeaderAndEpoch()) {
+        latestLeaderAndEpoch = client.currentLeaderAndEpoch()
+        isLeader = latestLeaderAndEpoch.leaderId.orElse(-1) == brokerId
+        if (isLeader) {
+          pendingAppends.clear()
+          throttler = new ThroughputThrottler(time, recordsPerSec)
+          recordCount = 0
+        }
+      }
+
+      if (isLeader) {
+        recordCount += 1
+
+        val startTimeMs = time.milliseconds()
+        val sendTimeMs = if (throttler.maybeThrottle(recordCount, startTimeMs)) {
+          time.milliseconds()
+        } else {
+          startTimeMs
+        }
+
+        val offset = client.scheduleAppend(latestLeaderAndEpoch.epoch, Collections.singletonList(payload))
+        if (offset == null) {
+          time.sleep(10)
+        } else {
+          pendingAppends.offer(PendingAppend(latestLeaderAndEpoch.epoch, offset, sendTimeMs))
+        }
+      } else {
+        time.sleep(500)
+      }
+    }
+
+    override def handleCommit(epoch: Int, lastOffset: Long, records: util.List[Array[Byte]]): Unit = {
+      var offset = lastOffset - records.size() + 1
+      val currentTimeMs = time.milliseconds()
+
+      for (record <- records.asScala) {
+        val pendingAppend = pendingAppends.poll()
+        if (pendingAppend.epoch != epoch || pendingAppend.offset != offset) {
+          throw new IllegalStateException(s"Committed record $record from `handleCommit` does not " +
+            s"match the next expected append $pendingAppend")
+        } else {
+          val latencyMs = math.max(0, currentTimeMs - pendingAppend.appendTimeMs)
+          stats.record(latencyMs, record.length, currentTimeMs)
+        }
+        offset += 1
+      }
+    }
+  }
+
+  class RaftIoThread(
+    client: KafkaRaftClient[Array[Byte]]
+  ) extends ShutdownableThread(
+    name = "raft-io-thread",
+    isInterruptible = false
+  ) {
     override def doWork(): Unit = {
       client.poll()
     }
@@ -281,9 +372,9 @@ class TestRaftServer(
       if (super.initiateShutdown()) {
         client.shutdown(5000).whenComplete { (_, exception) =>
           if (exception != null) {
-            logger.error("Shutdown of RaftClient failed", exception)
+            error("Shutdown of RaftClient failed", exception)
           } else {
-            logger.info("Completed shutdown of RaftClient")
+            info("Completed shutdown of RaftClient")
           }
         }
         true
@@ -300,57 +391,140 @@ class TestRaftServer(
 }
 
 object TestRaftServer extends Logging {
-  import kafka.utils.Implicits._
 
-  def getPropsFromArgs(args: Array[String]): Properties = {
-    val optionParser = new OptionParser(false)
-    val overrideOpt = optionParser.accepts("override", "Optional property that should override values set in server.properties file")
-      .withRequiredArg()
-      .ofType(classOf[String])
-    // This is just to make the parameter show up in the help output, we are not actually using this due the
-    // fact that this class ignores the first parameter which is interpreted as positional and mandatory
-    // but would not be mandatory if --version is specified
-    // This is a bit of an ugly crutch till we get a chance to rework the entire command line parsing
-    optionParser.accepts("version", "Print version information and exit.")
-
-    if (args.length == 0 || args.contains("--help")) {
-      CommandLineUtils.printUsageAndDie(optionParser, "USAGE: java [options] %s server.properties [--override property=value]*".format(classOf[TestRaftServer].getSimpleName()))
-    }
-
-    if (args.contains("--version")) {
-      CommandLineUtils.printVersionAndDie()
-    }
-
-    val props = Utils.loadProps(args(0))
-
-    if (args.length > 1) {
-      val options = optionParser.parse(args.slice(1, args.length): _*)
-
-      if (options.nonOptionArguments().size() > 0) {
-        CommandLineUtils.printUsageAndDie(optionParser, "Found non argument parameters: " + options.nonOptionArguments().toArray.mkString(","))
-      }
-
-      props ++= CommandLineUtils.parseKeyValueArgs(options.valuesOf(overrideOpt).asScala)
-    }
-    props
+  case class PendingAppend(
+    epoch: Int,
+    offset: Long,
+    appendTimeMs: Long
+  ) {
+    override def toString: String = {
+      s"PendingAppend(epoch=$epoch, offset=$offset, appendTimeMs=$appendTimeMs)"
+    }
+  }
+
+  private class ByteArraySerde extends RecordSerde[Array[Byte]] {
+    override def newWriteContext(): AnyRef = null
+
+    override def recordSize(data: Array[Byte], context: Any): Int = {
+      data.length
+    }
+
+    override def write(data: Array[Byte], context: Any, out: Writable): Unit = {
+      out.writeByteArray(data)
+    }
+  }
+
+  private class ThroughputThrottler(
+    time: Time,
+    targetRecordsPerSec: Int
+  ) {
+    private val startTimeMs = time.milliseconds()
+
+    require(targetRecordsPerSec > 0)
+
+    def maybeThrottle(
+      currentCount: Int,
+      currentTimeMs: Long
+    ): Boolean = {
+      val targetDurationMs = math.round(currentCount / targetRecordsPerSec.toDouble * 1000)
+      if (targetDurationMs > 0) {
+        val targetDeadlineMs = startTimeMs + targetDurationMs
+        if (targetDeadlineMs > currentTimeMs) {
+          val sleepDurationMs = targetDeadlineMs - currentTimeMs
+          time.sleep(sleepDurationMs)
+          return true
+        }
+      }
+      false
+    }
+  }
+
+  private class WriteStats(
+    time: Time,
+    printIntervalMs: Long
+  ) {
+    private var lastReportTimeMs = time.milliseconds()
+    private val latency = com.yammer.metrics.Metrics.newHistogram(
+      new MetricName("kafka.raft", "write", "latency")
+    )
+    private val throughput = com.yammer.metrics.Metrics.newMeter(
+      new MetricName("kafka.raft", "write", "throughput"),
+      "bytes",
+      TimeUnit.SECONDS
+    )
+
+    def record(
+      latencyMs: Long,
+      bytes: Int,
+      currentTimeMs: Long
+    ): Unit = {
+      throughput.mark(bytes)
+      latency.update(latencyMs)
+
+      if (currentTimeMs - lastReportTimeMs >= printIntervalMs) {
+        printSummary()
+        this.lastReportTimeMs = currentTimeMs
+      }
+    }
+
+    private def printSummary(): Unit = {
+      val latencies = latency.getSnapshot
+      println("Throughput (bytes/second): %.2f, Latency (ms): %.1f p50 %.1f p99 %.1f p999".format(
+        throughput.oneMinuteRate,
+        latencies.getMedian,
+        latencies.get99thPercentile,
+        latencies.get999thPercentile
+      ))
+    }
+  }
+
+  class TestRaftServerOptions(args: Array[String]) extends CommandDefaultOptions(args) {
parser.accepts("config", "Required configured file") + .withRequiredArg + .describedAs("filename") + .ofType(classOf[String]) + + val throughputOpt = parser.accepts("throughput", + "The number of records per second the leader will write to the metadata topic") + .withRequiredArg + .describedAs("records/sec") + .ofType(classOf[Int]) + .defaultsTo(5000) + + val recordSizeOpt = parser.accepts("record-size", "The size of each record") + .withRequiredArg + .describedAs("size in bytes") + .ofType(classOf[Int]) + .defaultsTo(256) + + options = parser.parse(args : _*) } def main(args: Array[String]): Unit = { + val opts = new TestRaftServerOptions(args) try { - val serverProps = getPropsFromArgs(args) - val config = KafkaConfig.fromProps(serverProps, false) - val server = new TestRaftServer(config) + CommandLineUtils.printHelpAndExitIfNeeded(opts, + "Standalone raft server for performance testing") + + val configFile = opts.options.valueOf(opts.configOpt) + val serverProps = Utils.loadProps(configFile) + val config = KafkaConfig.fromProps(serverProps, doLog = false) + val throughput = opts.options.valueOf(opts.throughputOpt) + val recordSize = opts.options.valueOf(opts.recordSizeOpt) + val server = new TestRaftServer(config, throughput, recordSize) Exit.addShutdownHook("raft-shutdown-hook", server.shutdown()) server.startup() server.awaitShutdown() - } - catch { + Exit.exit(0) + } catch { + case e: OptionException => + CommandLineUtils.printUsageAndDie(opts.parser, e.getMessage) case e: Throwable => fatal("Exiting Kafka due to fatal exception", e) Exit.exit(1) } - Exit.exit(0) } + } diff --git a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala index c1a60b0f6fd..b95b00b9429 100644 --- a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala +++ b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala @@ -21,11 +21,11 @@ import java.util import java.util.Collections import java.util.concurrent.atomic.AtomicReference -import org.apache.kafka.clients.{ApiVersion, MockClient, NodeApiVersions} import org.apache.kafka.clients.MockClient.MockMetadataUpdater +import org.apache.kafka.clients.{ApiVersion, MockClient, NodeApiVersions} import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData} import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} -import org.apache.kafka.common.requests.{AbstractResponse, BeginQuorumEpochRequest, EndQuorumEpochRequest, RequestHeader, VoteRequest, VoteResponse} +import org.apache.kafka.common.requests.{AbstractResponse, BeginQuorumEpochRequest, EndQuorumEpochRequest, VoteRequest, VoteResponse} import org.apache.kafka.common.utils.{MockTime, Time} import org.apache.kafka.common.{Node, TopicPartition} import org.apache.kafka.raft.{RaftRequest, RaftResponse, RaftUtil} @@ -120,18 +120,14 @@ class KafkaNetworkChannelTest { for (apiKey <- RaftApis) { val request = KafkaNetworkChannel.buildRequest(buildTestRequest(apiKey)).build() val responseRef = new AtomicReference[AbstractResponse]() - val correlationId = 15 - val header = new RequestHeader(apiKey, request.version, "clientId", correlationId) - channel.postInboundRequest(header, request, responseRef.set) + channel.postInboundRequest(request, responseRef.set) val inbound = channel.receive(1000).asScala assertEquals(1, inbound.size) val inboundRequest = inbound.head.asInstanceOf[RaftRequest.Inbound] - assertEquals(correlationId, 
diff --git a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
index c1a60b0f6fd..b95b00b9429 100644
--- a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
+++ b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
@@ -21,11 +21,11 @@ import java.util
 import java.util.Collections
 import java.util.concurrent.atomic.AtomicReference
 
-import org.apache.kafka.clients.{ApiVersion, MockClient, NodeApiVersions}
 import org.apache.kafka.clients.MockClient.MockMetadataUpdater
+import org.apache.kafka.clients.{ApiVersion, MockClient, NodeApiVersions}
 import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData}
 import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
-import org.apache.kafka.common.requests.{AbstractResponse, BeginQuorumEpochRequest, EndQuorumEpochRequest, RequestHeader, VoteRequest, VoteResponse}
+import org.apache.kafka.common.requests.{AbstractResponse, BeginQuorumEpochRequest, EndQuorumEpochRequest, VoteRequest, VoteResponse}
 import org.apache.kafka.common.utils.{MockTime, Time}
 import org.apache.kafka.common.{Node, TopicPartition}
 import org.apache.kafka.raft.{RaftRequest, RaftResponse, RaftUtil}
@@ -120,18 +120,14 @@ class KafkaNetworkChannelTest {
     for (apiKey <- RaftApis) {
       val request = KafkaNetworkChannel.buildRequest(buildTestRequest(apiKey)).build()
       val responseRef = new AtomicReference[AbstractResponse]()
-      val correlationId = 15
-      val header = new RequestHeader(apiKey, request.version, "clientId", correlationId)
-      channel.postInboundRequest(header, request, responseRef.set)
+      channel.postInboundRequest(request, responseRef.set)
 
       val inbound = channel.receive(1000).asScala
       assertEquals(1, inbound.size)
 
       val inboundRequest = inbound.head.asInstanceOf[RaftRequest.Inbound]
-      assertEquals(correlationId, inboundRequest.correlationId)
-
       val errorResponse = buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)
-      val outboundResponse = new RaftResponse.Outbound(correlationId, errorResponse)
+      val outboundResponse = new RaftResponse.Outbound(inboundRequest.correlationId, errorResponse)
       channel.send(outboundResponse)
       channel.receive(1000)
diff --git a/raft/README.md b/raft/README.md
index 8dde0b90f8f..f337ccacd96 100644
--- a/raft/README.md
+++ b/raft/README.md
@@ -12,7 +12,8 @@ Below we describe the details to set this up.
     bin/test-raft-server-start.sh config/raft.properties
 
 ### Run Multi Node Quorum ###
-Create 3 separate raft quorum properties as the following:
+Create 3 separate raft quorum properties as follows
+(note that the `zookeeper.connect` config is required, but unused):
 
 `cat << EOF >> config/raft-quorum-1.properties`
 
@@ -46,19 +47,11 @@ Create 3 separate raft quorum properties as the following:
 
 Open up 3 separate terminals, and run individual commands:
 
-    bin/test-raft-server-start.sh config/raft-quorum-1.properties
-    bin/test-raft-server-start.sh config/raft-quorum-2.properties
-    bin/test-raft-server-start.sh config/raft-quorum-3.properties
-
-This would setup a three node Raft quorum with node id 1,2,3 using different endpoints and log dirs.
-
-### Simulate a distributed state machine ###
-You need to use a `VerifiableProducer` to produce monolithic increasing records to the replicated state machine.
-
-    ./bin/kafka-run-class.sh org.apache.kafka.tools.VerifiableProducer --bootstrap-server http://localhost:9092 \
-    --topic __cluster_metadata --max-messages 2000 --throughput 1 --producer.config config/producer.properties
-
-### Run Performance Test ###
-Run the `ProducerPerformance` module using this example command:
-
-    ./bin/kafka-producer-perf-test.sh --topic __cluster_metadata --num-records 2000 --throughput -1 --record-size 10 --producer.config config/producer.properties
+    bin/test-raft-server-start.sh --config config/raft-quorum-1.properties
+    bin/test-raft-server-start.sh --config config/raft-quorum-2.properties
+    bin/test-raft-server-start.sh --config config/raft-quorum-3.properties
+
+Once a leader is elected, it will begin writing to an internal
+`__cluster_metadata` topic with a steady workload of random data.
+You can control the workload using the `--throughput` and `--record-size`
+arguments passed to `test-raft-server-start.sh`.
diff --git a/raft/bin/test-raft-server-start.sh b/raft/bin/test-raft-server-start.sh
index e9fe87463bb..95a2a2d169d 100755
--- a/raft/bin/test-raft-server-start.sh
+++ b/raft/bin/test-raft-server-start.sh
@@ -14,11 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-if [ $# -lt 1 ]; -then - echo "USAGE: $0 [-daemon] server.properties [--override property=value]*" - exit 1 -fi base_dir=$(dirname $0) if [ "x$KAFKA_LOG4J_OPTS" = "x" ]; then diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java index a85d894cf4d..993e2956a29 100644 --- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java +++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java @@ -19,6 +19,7 @@ package org.apache.kafka.raft; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.errors.ClusterAuthorizationException; import org.apache.kafka.common.errors.NotLeaderOrFollowerException; +import org.apache.kafka.common.memory.MemoryPool; import org.apache.kafka.common.message.BeginQuorumEpochRequestData; import org.apache.kafka.common.message.BeginQuorumEpochResponseData; import org.apache.kafka.common.message.DescribeQuorumRequestData; @@ -36,6 +37,7 @@ import org.apache.kafka.common.metrics.Metrics; import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.CompressionType; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.Records; import org.apache.kafka.common.requests.BeginQuorumEpochRequest; @@ -50,6 +52,8 @@ import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Timer; import org.apache.kafka.raft.RequestManager.ConnectionState; +import org.apache.kafka.raft.internals.BatchAccumulator; +import org.apache.kafka.raft.internals.BatchMemoryPool; import org.apache.kafka.raft.internals.KafkaRaftMetrics; import org.apache.kafka.raft.internals.LogOffset; import org.slf4j.Logger; @@ -57,6 +61,7 @@ import org.slf4j.Logger; import java.io.IOException; import java.net.InetSocketAddress; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -64,10 +69,8 @@ import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Random; import java.util.Set; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; @@ -110,14 +113,16 @@ import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition; * we also piggyback truncation detection on this API rather than through a separate truncation state. 
 *
 */
-public class KafkaRaftClient implements RaftClient {
-    private final static int RETRY_BACKOFF_BASE_MS = 100;
+public class KafkaRaftClient<T> implements RaftClient<T> {
+    private static final int RETRY_BACKOFF_BASE_MS = 100;
+    static final int MAX_BATCH_SIZE = 1024 * 1024;
 
     private final AtomicReference<GracefulShutdown> shutdown = new AtomicReference<>();
     private final Logger logger;
     private final Time time;
     private final int electionBackoffMaxMs;
     private final int fetchMaxWaitMs;
+    private final int appendLingerMs;
     private final KafkaRaftMetrics kafkaRaftMetrics;
     private final NetworkChannel channel;
     private final ReplicatedLog log;
@@ -126,19 +131,28 @@ public class KafkaRaftClient implements RaftClient {
     private final RequestManager requestManager;
     private final FuturePurgatory<LogOffset> appendPurgatory;
     private final FuturePurgatory<LogOffset> fetchPurgatory;
-    private final BlockingQueue<UnwrittenAppend> unwrittenAppends;
+    private final RecordSerde<T> serde;
+    private final MemoryPool memoryPool;
 
-    public KafkaRaftClient(RaftConfig raftConfig,
-                           NetworkChannel channel,
-                           ReplicatedLog log,
-                           QuorumState quorum,
-                           Time time,
-                           FuturePurgatory<LogOffset> fetchPurgatory,
-                           FuturePurgatory<LogOffset> appendPurgatory,
-                           LogContext logContext) {
-        this(channel,
+    private volatile BatchAccumulator<T> accumulator;
+    private Listener<T> listener;
+
+    public KafkaRaftClient(
+        RaftConfig raftConfig,
+        RecordSerde<T> serde,
+        NetworkChannel channel,
+        ReplicatedLog log,
+        QuorumState quorum,
+        Time time,
+        FuturePurgatory<LogOffset> fetchPurgatory,
+        FuturePurgatory<LogOffset> appendPurgatory,
+        LogContext logContext
+    ) {
+        this(serde,
+             channel,
              log,
             quorum,
+             new BatchMemoryPool(5, MAX_BATCH_SIZE),
              time,
              new Metrics(time),
              fetchPurgatory,
@@ -148,36 +162,44 @@ public class KafkaRaftClient implements RaftClient {
              raftConfig.retryBackoffMs(),
              raftConfig.requestTimeoutMs(),
              1000,
+             raftConfig.appendLingerMs(),
              logContext,
              new Random());
     }
 
-    public KafkaRaftClient(NetworkChannel channel,
-                           ReplicatedLog log,
-                           QuorumState quorum,
-                           Time time,
-                           Metrics metrics,
-                           FuturePurgatory<LogOffset> fetchPurgatory,
-                           FuturePurgatory<LogOffset> appendPurgatory,
-                           Map<Integer, InetSocketAddress> voterAddresses,
-                           int electionBackoffMaxMs,
-                           int retryBackoffMs,
-                           int requestTimeoutMs,
-                           int fetchMaxWaitMs,
-                           LogContext logContext,
-                           Random random) {
+    public KafkaRaftClient(
+        RecordSerde<T> serde,
+        NetworkChannel channel,
+        ReplicatedLog log,
+        QuorumState quorum,
+        MemoryPool memoryPool,
+        Time time,
+        Metrics metrics,
+        FuturePurgatory<LogOffset> fetchPurgatory,
+        FuturePurgatory<LogOffset> appendPurgatory,
+        Map<Integer, InetSocketAddress> voterAddresses,
+        int electionBackoffMaxMs,
+        int retryBackoffMs,
+        int requestTimeoutMs,
+        int fetchMaxWaitMs,
+        int appendLingerMs,
+        LogContext logContext,
+        Random random
+    ) {
+        this.serde = serde;
         this.channel = channel;
         this.log = log;
         this.quorum = quorum;
+        this.memoryPool = memoryPool;
         this.fetchPurgatory = fetchPurgatory;
         this.appendPurgatory = appendPurgatory;
         this.time = time;
         this.electionBackoffMaxMs = electionBackoffMaxMs;
         this.fetchMaxWaitMs = fetchMaxWaitMs;
+        this.appendLingerMs = appendLingerMs;
         this.logger = logContext.logger(KafkaRaftClient.class);
         this.random = random;
         this.requestManager = new RequestManager(voterAddresses.keySet(), retryBackoffMs, requestTimeoutMs, random);
-        this.unwrittenAppends = new LinkedBlockingQueue<>();
 
         this.kafkaRaftMetrics = new KafkaRaftMetrics(metrics, "raft", quorum);
         kafkaRaftMetrics.updateNumUnknownVoterConnections(quorum.remoteVoters().size());
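
`new BatchMemoryPool(5, MAX_BATCH_SIZE)` bounds the memory held by in-flight batches: with five buffers of `MAX_BATCH_SIZE` the leader can hold at most 5 MiB of undrained appends before allocation fails and `scheduleAppend` (below) starts returning null. The pool implementation lives in `org.apache.kafka.raft.internals` and is not part of this diff; a rough sketch of the idea it provides, with all names assumed:

```java
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.Deque;

// Sketch only: a fixed-capacity pool of equally sized, reusable buffers.
// tryAllocate() returning null once maxBatches buffers are outstanding is
// what lets the accumulator signal back-pressure to scheduleAppend.
public class SimpleBatchPool {
    private final Deque<ByteBuffer> free = new ArrayDeque<>();
    private final int maxBatches;
    private final int batchSize;
    private int outstanding = 0;

    public SimpleBatchPool(int maxBatches, int batchSize) {
        this.maxBatches = maxBatches;
        this.batchSize = batchSize;
    }

    public synchronized ByteBuffer tryAllocate() {
        if (outstanding == maxBatches)
            return null;                       // pool exhausted: caller must retry later
        outstanding++;
        ByteBuffer buffer = free.pollFirst();
        return buffer != null ? buffer : ByteBuffer.allocate(batchSize);
    }

    public synchronized void release(ByteBuffer buffer) {
        buffer.clear();                        // recycle for the next batch
        outstanding--;
        free.addLast(buffer);
    }
}
```
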
@@ -233,7 +255,8 @@ public class KafkaRaftClient implements RaftClient {
     }
 
     @Override
-    public void initialize() throws IOException {
+    public void initialize(Listener<T> listener) throws IOException {
+        this.listener = listener;
         quorum.initialize(new OffsetAndEpoch(log.endOffset().offset, log.lastFetchedEpoch()));
 
         long currentTimeMs = time.milliseconds();
@@ -276,6 +299,17 @@ public class KafkaRaftClient implements RaftClient {
         resetConnections();
 
         kafkaRaftMetrics.maybeUpdateElectionLatency(currentTimeMs);
+
+        accumulator = new BatchAccumulator<>(
+            quorum.epoch(),
+            log.endOffset().offset,
+            appendLingerMs,
+            MAX_BATCH_SIZE,
+            memoryPool,
+            time,
+            CompressionType.NONE,
+            serde
+        );
     }
 
     private void appendLeaderChangeMessage(LeaderState state, long currentTimeMs) {
@@ -318,18 +352,28 @@ public class KafkaRaftClient implements RaftClient {
         }
     }
 
+    private void maybeCloseAccumulator() {
+        if (accumulator != null) {
+            accumulator.close();
+            accumulator = null;
+        }
+    }
+
     private void transitionToCandidate(long currentTimeMs) throws IOException {
         quorum.transitionToCandidate();
+        maybeCloseAccumulator();
         onBecomeCandidate(currentTimeMs);
     }
 
     private void transitionToUnattached(int epoch) throws IOException {
         quorum.transitionToUnattached(epoch);
+        maybeCloseAccumulator();
         resetConnections();
     }
 
     private void transitionToVoted(int candidateId, int epoch) throws IOException {
         quorum.transitionToVoted(epoch, candidateId);
+        maybeCloseAccumulator();
         resetConnections();
     }
 
@@ -346,9 +390,6 @@ public class KafkaRaftClient implements RaftClient {
         // Clearing the append purgatory should complete all future exceptionally since this node is no longer the leader
         appendPurgatory.completeAllExceptionally(new NotLeaderOrFollowerException(
             "Failed to receive sufficient acknowledgments for this append before leader change."));
-
-        failPendingAppends(new NotLeaderOrFollowerException(
-            "Append refused since this node is no longer the leader"));
     }
 
     private void transitionToFollower(
@@ -357,6 +398,7 @@ public class KafkaRaftClient implements RaftClient {
         long currentTimeMs
     ) throws IOException {
         quorum.transitionToFollower(epoch, leaderId);
+        maybeCloseAccumulator();
         onBecomeFollower(currentTimeMs);
     }
 
@@ -1443,15 +1485,79 @@ public class KafkaRaftClient implements RaftClient {
         }
     }
 
+    private void appendBatch(
+        LeaderState state,
+        BatchAccumulator.CompletedBatch<T> batch,
+        long appendTimeMs
+    ) {
+        try {
+            List<T> records = batch.records;
+            int epoch = state.epoch();
+
+            LogAppendInfo info = appendAsLeader(batch.data);
+            OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(info.lastOffset, epoch);
+            CompletableFuture<Long> future = appendPurgatory.await(
+                LogOffset.awaitCommitted(offsetAndEpoch.offset),
+                Integer.MAX_VALUE
+            );
+
+            future.whenComplete((commitTimeMs, exception) -> {
+                int numRecords = batch.records.size();
+                if (exception != null) {
+                    logger.debug("Failed to commit {} records at {}", numRecords, offsetAndEpoch, exception);
+                } else {
+                    long elapsedTime = Math.max(0, commitTimeMs - appendTimeMs);
+                    double elapsedTimePerRecord = (double) elapsedTime / numRecords;
+                    kafkaRaftMetrics.updateCommitLatency(elapsedTimePerRecord, appendTimeMs);
+                    logger.debug("Completed commit of {} records at {}", numRecords, offsetAndEpoch);
+                    listener.handleCommit(epoch, info.lastOffset, records);
+                }
+            });
+        } finally {
+            batch.release();
+        }
+    }
+
+    private long maybeAppendBatches(
+        LeaderState state,
+        long currentTimeMs
+    ) {
+        long timeUntilFlush = accumulator.timeUntilDrain(currentTimeMs);
+        if (timeUntilFlush <= 0) {
+            List<BatchAccumulator.CompletedBatch<T>> batches = accumulator.drain();
+            Iterator<BatchAccumulator.CompletedBatch<T>> iterator = batches.iterator();
+
+            try {
+                while (iterator.hasNext()) {
+                    BatchAccumulator.CompletedBatch<T> batch = iterator.next();
+                    appendBatch(state, batch, currentTimeMs);
+                }
+                flushLeaderLog(state, currentTimeMs);
+            } finally {
+                // Release and discard any batches which failed to be appended
+                while (iterator.hasNext()) {
+                    iterator.next().release();
+                }
+            }
+        }
+        return timeUntilFlush;
+    }
+
     private long pollLeader(long currentTimeMs) {
         LeaderState state = quorum.leaderStateOrThrow();
-        pollPendingAppends(state, currentTimeMs);
 
-        return maybeSendRequests(
+        long timeUntilFlush = maybeAppendBatches(
+            state,
+            currentTimeMs
+        );
+
+        long timeUntilSend = maybeSendRequests(
             currentTimeMs,
             state.nonEndorsingFollowers(),
             this::buildBeginQuorumEpochRequest
         );
+
+        return Math.min(timeUntilFlush, timeUntilSend);
     }
 
     private long pollCandidate(long currentTimeMs) throws IOException {
@@ -1492,9 +1598,6 @@ public class KafkaRaftClient implements RaftClient {
     }
 
     private long pollFollowerAsVoter(FollowerState state, long currentTimeMs) throws IOException {
-        failPendingAppends(new NotLeaderOrFollowerException("Failing append " +
-            "since this node is not the current leader"));
-
         if (state.hasFetchTimeoutExpired(currentTimeMs)) {
             logger.info("Become candidate due to fetch timeout");
             transitionToCandidate(currentTimeMs);
@@ -1605,100 +1708,18 @@ public class KafkaRaftClient implements RaftClient {
         }
     }
 
-    private void failPendingAppends(KafkaException exception) {
-        for (UnwrittenAppend unwrittenAppend : unwrittenAppends) {
-            unwrittenAppend.fail(exception);
-        }
-        unwrittenAppends.clear();
-    }
-
-    private void pollPendingAppends(LeaderState state, long currentTimeMs) {
-        int numAppends = 0;
-        int maxNumAppends = unwrittenAppends.size();
-
-        while (!unwrittenAppends.isEmpty() && numAppends < maxNumAppends) {
-            final UnwrittenAppend unwrittenAppend = unwrittenAppends.poll();
-
-            if (unwrittenAppend.future.isDone())
-                continue;
-
-            if (unwrittenAppend.isTimedOut(currentTimeMs)) {
-                unwrittenAppend.fail(new TimeoutException("Request timeout " + unwrittenAppend.requestTimeoutMs
-                    + " expired before the records could be appended to the log"));
-            } else {
-                int epoch = quorum.epoch();
-                LogAppendInfo info = appendAsLeader(unwrittenAppend.records);
-                OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(info.lastOffset, epoch);
-                long numRecords = info.lastOffset - info.firstOffset + 1;
-                logger.debug("Completed write of {} records at {}", numRecords, offsetAndEpoch);
-
-                if (unwrittenAppend.ackMode == AckMode.LEADER) {
-                    unwrittenAppend.complete(offsetAndEpoch);
-                } else if (unwrittenAppend.ackMode == AckMode.QUORUM) {
-                    CompletableFuture<Long> future = appendPurgatory.await(
-                        LogOffset.awaitCommitted(offsetAndEpoch.offset),
-                        unwrittenAppend.requestTimeoutMs);
-
-                    future.whenComplete((completionTimeMs, exception) -> {
-                        if (exception != null) {
-                            logger.error("Failed to commit append at {} due to {}", offsetAndEpoch, exception);
-
-                            unwrittenAppend.fail(exception);
-                        } else {
-                            long elapsedTime = Math.max(0, completionTimeMs - currentTimeMs);
-                            double elapsedTimePerRecord = (double) elapsedTime / numRecords;
-                            kafkaRaftMetrics.updateCommitLatency(elapsedTimePerRecord, currentTimeMs);
-                            unwrittenAppend.complete(offsetAndEpoch);
-
-                            logger.debug("Completed commit of {} records at {}", numRecords, offsetAndEpoch);
-                        }
-                    });
-                }
-            }
-
-            numAppends++;
-        }
-
-        if (numAppends > 0) {
-            flushLeaderLog(state, currentTimeMs);
-        }
-    }
-
-    /**
-     * Append a set of records to the log. Successful completion of the future indicates a success of
-     * the append, with the uncommitted base offset and epoch.
-     *
-     * @param records The records to write to the log
-     * @param ackMode The commit mode for the appended records
-     * @param timeoutMs The maximum time to wait for the append operation to complete (including
-     *                  any time needed for replication)
-     * @return The uncommitted base offset and epoch of the appended records
-     */
     @Override
-    public CompletableFuture<OffsetAndEpoch> append(
-        Records records,
-        AckMode ackMode,
-        long timeoutMs
-    ) {
-        if (records.sizeInBytes() == 0)
-            throw new IllegalArgumentException("Attempt to append empty record set");
-
-        if (shutdown.get() != null)
-            throw new IllegalStateException("Cannot append records while we are shutting down");
-
-        if (quorum.isObserver())
-            throw new IllegalStateException("Illegal attempt to write to an observer");
-
-        CompletableFuture<OffsetAndEpoch> future = new CompletableFuture<>();
-        UnwrittenAppend unwrittenAppend = new UnwrittenAppend(
-            records, time.milliseconds(), timeoutMs, ackMode, future);
-
-        if (!unwrittenAppends.offer(unwrittenAppend)) {
-            future.completeExceptionally(new KafkaException("Failed to append records since the unsent " +
-                "append queue is full"));
+    public Long scheduleAppend(int epoch, List<T> records) {
+        BatchAccumulator<T> accumulator = this.accumulator;
+        if (accumulator == null) {
+            return Long.MAX_VALUE;
         }
-        channel.wakeup();
-        return future;
+
+        Long offset = accumulator.append(epoch, records);
+        if (accumulator.needsDrain(time.milliseconds())) {
+            channel.wakeup();
+        }
+        return offset;
     }
 
     @Override
@@ -1757,35 +1778,4 @@ public class KafkaRaftClient implements RaftClient {
         }
     }
 
-    private static class UnwrittenAppend {
-        private final Records records;
-        private final long createTimeMs;
-        private final long requestTimeoutMs;
-        private final AckMode ackMode;
-        private final CompletableFuture<OffsetAndEpoch> future;
-
-        private UnwrittenAppend(Records records,
-                                long createTimeMs,
-                                long requestTimeoutMs,
-                                AckMode ackMode,
-                                CompletableFuture<OffsetAndEpoch> future) {
-            this.future = future;
-            this.records = records;
-            this.ackMode = ackMode;
-            this.createTimeMs = createTimeMs;
-            this.requestTimeoutMs = requestTimeoutMs;
-        }
-
-        public void complete(OffsetAndEpoch offsetAndEpoch) {
-            future.complete(offsetAndEpoch);
-        }
-
-        public void fail(Throwable e) {
-            future.completeExceptionally(e);
-        }
-
-        public boolean isTimedOut(long currentTimeMs) {
-            return currentTimeMs > createTimeMs + requestTimeoutMs;
-        }
-    }
 }
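
Taken together, `scheduleAppend` plus the `Listener` registered in `initialize` replace the old future-based `append`. A sketch of the expected calling pattern from a state machine — the `client` variable and the String record type are illustrative, not part of this patch:

```java
import java.util.List;

import org.apache.kafka.raft.RaftClient;

public class AppendExample {
    // Schedule a write in the current epoch and distinguish the two
    // failure signals described in the scheduleAppend javadoc below.
    static void tryWrite(RaftClient<String> client, List<String> records) {
        int epoch = client.currentLeaderAndEpoch().epoch;
        Long offset = client.scheduleAppend(epoch, records);
        if (offset == null) {
            // Back-pressure: the accumulator could not allocate a batch; retry later.
        } else if (offset == Long.MAX_VALUE) {
            // Stale epoch: this node is not leader in `epoch`; discard
            // uncommitted state and rely on handleCommit callbacks.
        } else {
            // The records will be appended at offsets up to `offset` and the
            // listener's handleCommit will fire once they are committed.
        }
    }
}
```
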
diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftClient.java b/raft/src/main/java/org/apache/kafka/raft/RaftClient.java
index 71c5734c53c..d296c8f1aec 100644
--- a/raft/src/main/java/org/apache/kafka/raft/RaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/RaftClient.java
@@ -19,9 +19,29 @@ package org.apache.kafka.raft;
 import org.apache.kafka.common.record.Records;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.concurrent.CompletableFuture;
 
-public interface RaftClient {
+public interface RaftClient<T> {
+
+    interface Listener<T> {
+        /**
+         * Callback which is invoked when records written through {@link #scheduleAppend(int, List)}
+         * become committed.
+         *
+         * Note that there is not a one-to-one correspondence between writes through
+         * {@link #scheduleAppend(int, List)} and this callback. The Raft implementation
+         * is free to batch together the records from multiple append calls provided
+         * that batch boundaries are respected. This means that each batch specified
+         * through {@link #scheduleAppend(int, List)} is guaranteed to be a subset of
+         * a batch passed to {@link #handleCommit(int, long, List)}.
+         *
+         * @param epoch the epoch in which the write was accepted
+         * @param lastOffset the offset of the last record in the record list
+         * @param records the set of records that were committed
+         */
+        void handleCommit(int epoch, long lastOffset, List<T> records);
+    }
 
     /**
      * Initialize the client. This should only be called once and it must be
@@ -29,23 +49,27 @@ public interface RaftClient {
      *
      * @throws IOException For any IO errors during initialization
      */
-    void initialize() throws IOException;
+    void initialize(Listener<T> listener) throws IOException;
 
     /**
-     * Append a new entry to the log. The client must be in the leader state to
-     * accept an append: it is up to the state machine implementation
-     * to ensure this using {@link #currentLeaderAndEpoch()}.
+     * Append a list of records to the log. The write will be scheduled for some time
+     * in the future. There is no guarantee that appended records will be written to
+     * the log and eventually committed. However, it is guaranteed that if any of the
+     * records become committed, then all of them will be.
      *
-     * TODO: One improvement we can make here is to allow the caller to specify
-     * the current leader epoch in the record set. That would ensure that each
-     * leader change must be "observed" by the state machine before new appends
-     * are accepted.
+     * If the provided current leader epoch does not match the current epoch, which
+     * is possible when the state machine has yet to observe the epoch change, then
+     * this method will return {@link Long#MAX_VALUE} to indicate an offset which can
+     * never become committed. The state machine is expected to discard all
+     * uncommitted entries after observing an epoch change.
      *
-     * @param records The records to append to the log
-     * @param timeoutMs Maximum time to wait for the append to complete
-     * @return A future containing the last offset and epoch of the appended records (if successful)
+     * @param epoch the current leader epoch
+     * @param records the list of records to append
+     * @return the offset within the current epoch at which the log entries will be appended,
+     *         or null if the leader was unable to accept the write (e.g. due to the memory
+     *         limit being reached)
      */
-    CompletableFuture<OffsetAndEpoch> append(Records records, AckMode ackMode, long timeoutMs);
+    Long scheduleAppend(int epoch, List<T> records);
 
     /**
      * Read a set of records from the log. Note that it is the responsibility of the state machine
" + "This is used in the binary exponential backoff mechanism that helps prevent gridlocked elections"; + public static final String QUORUM_LINGER_MS_CONFIG = QUORUM_PREFIX + "append.linger.ms"; + private static final String QUORUM_LINGER_MS_DOC = "The duration in milliseconds that the leader will " + + "wait for writes to accumulate before flushing them to disk."; + private static final String QUORUM_REQUEST_TIMEOUT_MS_CONFIG = QUORUM_PREFIX + CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG; @@ -116,7 +120,13 @@ public class RaftConfig extends AbstractConfig { 15000, atLeast(0), ConfigDef.Importance.HIGH, - QUORUM_FETCH_TIMEOUT_MS_DOC); + QUORUM_FETCH_TIMEOUT_MS_DOC) + .define(QUORUM_LINGER_MS_CONFIG, + ConfigDef.Type.INT, + 25, + atLeast(0), + ConfigDef.Importance.MEDIUM, + QUORUM_LINGER_MS_DOC); } public RaftConfig(Properties props) { @@ -163,6 +173,10 @@ public class RaftConfig extends AbstractConfig { return getInt(QUORUM_FETCH_TIMEOUT_MS_CONFIG); } + public int appendLingerMs() { + return getInt(QUORUM_LINGER_MS_CONFIG); + } + public Set quorumVoterIds() { return quorumVoterConnections().keySet(); } diff --git a/raft/src/main/java/org/apache/kafka/raft/RecordSerde.java b/raft/src/main/java/org/apache/kafka/raft/RecordSerde.java new file mode 100644 index 00000000000..49d14071f18 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/RecordSerde.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.protocol.Writable; + +public interface RecordSerde { + /** + * Create a new context object for to be used when serializing a batch of records. + * This allows for state to be shared between {@link #recordSize(Object, Object)} + * and {@link #write(Object, Object, Writable)}, which is useful in order to avoid + * redundant work (see e.g. {@link org.apache.kafka.common.protocol.ObjectSerializationCache}). + * + * @return context object or null if none is needed + */ + default Object newWriteContext() { + return null; + } + + /** + * Get the size of a record. + * + * @param data the record that will be serialized + * @param context context object created by {@link #newWriteContext()} + * @return the size in bytes of the serialized record + */ + int recordSize(T data, Object context); + + + /** + * Write the record to the output stream. 
diff --git a/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java b/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java
index a901a07faef..659ed326c22 100644
--- a/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java
+++ b/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java
@@ -17,82 +17,40 @@ package org.apache.kafka.raft;
 
 import org.apache.kafka.common.KafkaException;
-import org.apache.kafka.common.protocol.types.Type;
-import org.apache.kafka.common.record.CompressionType;
-import org.apache.kafka.common.record.MemoryRecords;
-import org.apache.kafka.common.record.Record;
-import org.apache.kafka.common.record.RecordBatch;
-import org.apache.kafka.common.record.Records;
-import org.apache.kafka.common.record.SimpleRecord;
 import org.apache.kafka.common.utils.LogContext;
 import org.slf4j.Logger;
 
-import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.List;
 import java.util.OptionalInt;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.atomic.AtomicInteger;
 
-public class ReplicatedCounter {
+public class ReplicatedCounter implements RaftClient.Listener<Integer> {
     private final int localBrokerId;
     private final Logger log;
-    private final RaftClient client;
+    private final RaftClient<Integer> client;
 
-    private final AtomicInteger committed = new AtomicInteger(0);
-    private final AtomicInteger uncommitted = new AtomicInteger(0);
-    private OffsetAndEpoch position = new OffsetAndEpoch(0, 0);
+    private int committed;
+    private int uncommitted;
     private LeaderAndEpoch currentLeaderAndEpoch = new LeaderAndEpoch(OptionalInt.empty(), 0);
 
-    public ReplicatedCounter(int localBrokerId,
-                             RaftClient client,
-                             LogContext logContext) {
+    public ReplicatedCounter(
+        int localBrokerId,
+        RaftClient<Integer> client,
+        LogContext logContext
+    ) {
         this.localBrokerId = localBrokerId;
         this.client = client;
         this.log = logContext.logger(ReplicatedCounter.class);
     }
 
-    private Records tryRead(long durationMs) {
-        CompletableFuture<Records> future = client.read(position, Isolation.COMMITTED, durationMs);
-        try {
-            return future.get();
-        } catch (InterruptedException | ExecutionException e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    private void apply(Record record) {
-        int value = deserialize(record);
-        if (value != committed.get() + 1) {
-            log.debug("Ignoring non-sequential append at offset {}: {} -> {}",
-                record.offset(), committed.get(), value);
-            return;
-        }
-
-        log.debug("Applied increment at offset {}: {} -> {}", record.offset(), committed.get(), value);
-        committed.set(value);
-
-        if (value > uncommitted.get()) {
-            uncommitted.set(value);
-        }
-    }
-
-    public synchronized void poll(long durationMs) {
+    public synchronized void poll() {
         // Check for leader changes
         LeaderAndEpoch latestLeaderAndEpoch = client.currentLeaderAndEpoch();
         if (!currentLeaderAndEpoch.equals(latestLeaderAndEpoch)) {
-            if (localBrokerId == latestLeaderAndEpoch.leaderId.orElse(-1)) {
-                uncommitted.set(committed.get());
-            }
+            this.committed = 0;
+            this.uncommitted = 0;
            this.currentLeaderAndEpoch = latestLeaderAndEpoch;
        }
-
-        Records records = tryRead(durationMs);
-        for (RecordBatch batch : records.batches()) {
-            if (!batch.isControlBatch()) {
-                batch.forEach(this::apply);
-            }
-            this.position = new OffsetAndEpoch(batch.lastOffset() + 1, batch.partitionLeaderEpoch());
-        }
     }
 
     public synchronized boolean isWritable() {
@@ -104,29 +62,19 @@ public class ReplicatedCounter {
     public synchronized void increment() {
         if (!isWritable())
             throw new KafkaException("Counter is not currently writable");
 
-        int initialValue = uncommitted.get();
-        int incrementedValue = uncommitted.incrementAndGet();
-        Records records = MemoryRecords.withRecords(CompressionType.NONE, serialize(incrementedValue));
-        client.append(records, AckMode.LEADER, Integer.MAX_VALUE).whenComplete((offsetAndEpoch, throwable) -> {
-            if (offsetAndEpoch != null) {
-                log.debug("Appended increment at offset {}: {} -> {}",
-                    offsetAndEpoch.offset, initialValue, incrementedValue);
-            } else {
-                uncommitted.set(initialValue);
-                log.debug("Failed append of increment: {} -> {}", initialValue, incrementedValue, throwable);
-            }
-        });
+        uncommitted += 1;
+        Long offset = client.scheduleAppend(currentLeaderAndEpoch.epoch, Collections.singletonList(uncommitted));
+        if (offset != null) {
+            log.debug("Scheduled append of record {} with epoch {} at offset {}",
+                uncommitted, currentLeaderAndEpoch.epoch, offset);
+        }
     }
 
-    private SimpleRecord serialize(int value) {
-        ByteBuffer buffer = ByteBuffer.allocate(4);
-        Type.INT32.write(buffer, value);
-        buffer.flip();
-        return new SimpleRecord(buffer);
-    }
-
-    private int deserialize(Record record) {
-        return (int) Type.INT32.read(record.value());
+    @Override
+    public void handleCommit(int epoch, long lastOffset, List<Integer> records) {
+        log.debug("Received commit of records {} with epoch {} at last offset {}",
+            records, epoch, lastOffset);
+        this.committed = records.get(records.size() - 1);
     }
 }
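Purely illustrative (not in the patch): a sketch of how a host application might drive the rewritten counter, given that commits now arrive via the `handleCommit` callback rather than an explicit read. The driver class and its loop are hypothetical.

```java
import org.apache.kafka.raft.ReplicatedCounter;

// Hypothetical driver: poll for leader changes, increment while writable.
public final class CounterDriver implements Runnable {
    private final ReplicatedCounter counter;

    public CounterDriver(ReplicatedCounter counter) {
        this.counter = counter;
    }

    @Override
    public void run() {
        while (!Thread.currentThread().isInterrupted()) {
            counter.poll(); // observe leader changes and reset local state
            if (counter.isWritable()) {
                counter.increment(); // schedules an append; the commit arrives via handleCommit()
            }
        }
    }
}
```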
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java b/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java
new file mode 100644
index 00000000000..4fdeb948a50
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft.internals;
+
+import org.apache.kafka.common.memory.MemoryPool;
+import org.apache.kafka.common.record.CompressionType;
+import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.raft.RecordSerde;
+
+import java.io.Closeable;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantLock;
+
+public class BatchAccumulator<T> implements Closeable {
+    private final int epoch;
+    private final Time time;
+    private final SimpleTimer lingerTimer;
+    private final int lingerMs;
+    private final int maxBatchSize;
+    private final CompressionType compressionType;
+    private final MemoryPool memoryPool;
+    private final ReentrantLock appendLock;
+    private final RecordSerde<T> serde;
+
+    private final ConcurrentLinkedQueue<CompletedBatch<T>> completed;
+    private volatile DrainStatus drainStatus;
+
+    // These fields are protected by the append lock
+    private long nextOffset;
+    private BatchBuilder<T> currentBatch;
+
+    private enum DrainStatus {
+        STARTED, FINISHED, NONE
+    }
+
+    public BatchAccumulator(
+        int epoch,
+        long baseOffset,
+        int lingerMs,
+        int maxBatchSize,
+        MemoryPool memoryPool,
+        Time time,
+        CompressionType compressionType,
+        RecordSerde<T> serde
+    ) {
+        this.epoch = epoch;
+        this.lingerMs = lingerMs;
+        this.maxBatchSize = maxBatchSize;
+        this.memoryPool = memoryPool;
+        this.time = time;
+        this.lingerTimer = new SimpleTimer();
+        this.compressionType = compressionType;
+        this.serde = serde;
+        this.nextOffset = baseOffset;
+        this.drainStatus = DrainStatus.NONE;
+        this.completed = new ConcurrentLinkedQueue<>();
+        this.appendLock = new ReentrantLock();
+    }
+
+    /**
+     * Append a list of records into an atomic batch. We guarantee all records
+     * are included in the same underlying record batch so that either all of
+     * the records become committed or none of them do.
+     *
+     * @param epoch the expected leader epoch. If this does not match, then
+     *              {@link Long#MAX_VALUE} will be returned as an offset which
+     *              cannot become committed.
+     * @param records the list of records to include in a batch
+     * @return the expected offset of the last record (which will be
+     *         {@link Long#MAX_VALUE} if the epoch does not match), or null if
+     *         no memory could be allocated for the batch at this time
+     */
+    public Long append(int epoch, List<T> records) {
+        if (epoch != this.epoch) {
+            // If the epoch does not match, then the state machine probably
+            // has not gotten the notification about the latest epoch change.
+            // In this case, ignore the append and return a large offset value
+            // which will never be committed.
+            return Long.MAX_VALUE;
+        }
+
+        Object serdeContext = serde.newWriteContext();
+        int batchSize = 0;
+        for (T record : records) {
+            batchSize += serde.recordSize(record, serdeContext);
+        }
+
+        if (batchSize > maxBatchSize) {
+            throw new IllegalArgumentException("The total size of " + records + " is " + batchSize +
+                ", which exceeds the maximum allowed batch size of " + maxBatchSize);
+        }
+
+        appendLock.lock();
+        try {
+            maybeCompleteDrain();
+
+            BatchBuilder<T> batch = maybeAllocateBatch(batchSize);
+            if (batch == null) {
+                return null;
+            }
+
+            // Restart the linger timer if necessary
+            if (!lingerTimer.isRunning()) {
+                lingerTimer.reset(time.milliseconds() + lingerMs);
+            }
+
+            for (T record : records) {
+                batch.appendRecord(record, serdeContext);
+                nextOffset += 1;
+            }
+
+            return nextOffset - 1;
+        } finally {
+            appendLock.unlock();
+        }
+    }
+
+    private BatchBuilder<T> maybeAllocateBatch(int batchSize) {
+        if (currentBatch == null) {
+            startNewBatch();
+        } else if (!currentBatch.hasRoomFor(batchSize)) {
+            completeCurrentBatch();
+            // Start a new batch for the incoming records; without this, a full
+            // batch would cause the append to return null and be misreported
+            // as an allocation failure
+            startNewBatch();
+        }
+        return currentBatch;
+    }
+
+    private void completeCurrentBatch() {
+        MemoryRecords data = currentBatch.build();
+        completed.add(new CompletedBatch<>(
+            currentBatch.baseOffset(),
+            currentBatch.records(),
+            data,
+            memoryPool,
+            currentBatch.initialBuffer()
+        ));
+        currentBatch = null;
+    }
+
+    private void maybeCompleteDrain() {
+        if (drainStatus == DrainStatus.STARTED) {
+            if (currentBatch != null && currentBatch.nonEmpty()) {
+                completeCurrentBatch();
+            }
+            // Reset the timer to a large value. The linger clock will begin
+            // ticking after the next append.
+            lingerTimer.reset(Long.MAX_VALUE);
+            drainStatus = DrainStatus.FINISHED;
+        }
+    }
+
+    private void startNewBatch() {
+        ByteBuffer buffer = memoryPool.tryAllocate(maxBatchSize);
+        if (buffer != null) {
+            currentBatch = new BatchBuilder<>(
+                buffer,
+                serde,
+                compressionType,
+                nextOffset,
+                time.milliseconds(),
+                false,
+                RecordBatch.NO_PARTITION_LEADER_EPOCH,
+                maxBatchSize
+            );
+        }
+    }
+
+    /**
+     * Check whether there are any batches which need to be drained now.
+     *
+     * @param currentTimeMs current time in milliseconds
+     * @return true if there are batches ready to drain, false otherwise
+     */
+    public boolean needsDrain(long currentTimeMs) {
+        return timeUntilDrain(currentTimeMs) <= 0;
+    }
+
+    /**
+     * Check the time remaining until the next needed drain. If the accumulator
+     * is empty, then {@link Long#MAX_VALUE} will be returned.
+     *
+     * @param currentTimeMs current time in milliseconds
+     * @return the delay in milliseconds before the next expected drain
+     */
+    public long timeUntilDrain(long currentTimeMs) {
+        if (drainStatus == DrainStatus.FINISHED) {
+            return 0;
+        } else {
+            return lingerTimer.remainingMs(currentTimeMs);
+        }
+    }
+
+    /**
+     * Get the leader epoch, which is constant for each instance.
+     *
+     * @return the leader epoch
+     */
+    public int epoch() {
+        return epoch;
+    }
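Editor's note on intended usage (not in the patch): `needsDrain`/`timeUntilDrain` are designed for a single draining thread, e.g. the Raft IO thread, which can use the remaining linger time as its poll timeout. A hypothetical sketch follows; `flushToLog` is a stand-in for whatever writes drained batches to the log.

```java
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.raft.internals.BatchAccumulator;

// Hypothetical single-threaded drain step for the accumulator above.
public final class DrainStep {
    public static <T> void runOnce(BatchAccumulator<T> accumulator, Time time) {
        long now = time.milliseconds();
        if (accumulator.needsDrain(now)) {
            for (BatchAccumulator.CompletedBatch<T> batch : accumulator.drain()) {
                try {
                    // flushToLog(batch.data); // hand off the MemoryRecords
                } finally {
                    batch.release(); // return the buffer to the memory pool
                }
            }
        }
        // Otherwise the caller can sleep up to accumulator.timeUntilDrain(now).
    }
}
```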
+
+    /**
+     * Drain completed batches. The caller is expected to first check whether
+     * {@link #needsDrain(long)} returns true in order to avoid unnecessary draining.
+     *
+     * Note on thread-safety: this method is safe in the presence of concurrent
+     * appends, but it assumes a single thread is responsible for draining.
+     *
+     * This call will not block, but the drain may require multiple attempts before
+     * it can be completed if the thread responsible for appending is holding the
+     * append lock. In the worst case, the drain will be completed on the next
+     * call to {@link #append(int, List)} following the initial call to this method.
+     * The caller should respect the time to the next drain as indicated by
+     * {@link #timeUntilDrain(long)}.
+     *
+     * @return the list of completed batches
+     */
+    public List<CompletedBatch<T>> drain() {
+        // Start the drain if it has not been started already
+        if (drainStatus == DrainStatus.NONE) {
+            drainStatus = DrainStatus.STARTED;
+        }
+
+        // Complete the drain ourselves if we can acquire the lock
+        if (appendLock.tryLock()) {
+            try {
+                maybeCompleteDrain();
+            } finally {
+                appendLock.unlock();
+            }
+        }
+
+        // If the drain has finished, then all of the batches will be completed
+        if (drainStatus == DrainStatus.FINISHED) {
+            drainStatus = DrainStatus.NONE;
+            return drainCompleted();
+        } else {
+            return Collections.emptyList();
+        }
+    }
+
+    private List<CompletedBatch<T>> drainCompleted() {
+        List<CompletedBatch<T>> res = new ArrayList<>(completed.size());
+        while (true) {
+            CompletedBatch<T> batch = completed.poll();
+            if (batch == null) {
+                return res;
+            } else {
+                res.add(batch);
+            }
+        }
+    }
+
+    /**
+     * Get the number of batches including the one that is currently being
+     * written to (if it exists).
+     */
+    public int count() {
+        appendLock.lock();
+        try {
+            int count = completed.size();
+            if (currentBatch != null) {
+                return count + 1;
+            } else {
+                return count;
+            }
+        } finally {
+            appendLock.unlock();
+        }
+    }
+
+    @Override
+    public void close() {
+        List<CompletedBatch<T>> unwritten = drain();
+        unwritten.forEach(CompletedBatch::release);
+    }
+
+    public static class CompletedBatch<T> {
+        public final long baseOffset;
+        public final List<T> records;
+        public final MemoryRecords data;
+        private final MemoryPool pool;
+        private final ByteBuffer buffer;
+
+        private CompletedBatch(
+            long baseOffset,
+            List<T> records,
+            MemoryRecords data,
+            MemoryPool pool,
+            ByteBuffer buffer
+        ) {
+            this.baseOffset = baseOffset;
+            this.records = records;
+            this.data = data;
+            this.pool = pool;
+            this.buffer = buffer;
+        }
+
+        public int sizeInBytes() {
+            return data.sizeInBytes();
+        }
+
+        public void release() {
+            pool.release(buffer);
+        }
+    }
+
+    private static class SimpleTimer {
+        // We use an atomic long so that the Raft IO thread can query the linger
+        // time without any locking
+        private final AtomicLong deadlineMs = new AtomicLong(Long.MAX_VALUE);
+
+        boolean isRunning() {
+            return deadlineMs.get() != Long.MAX_VALUE;
+        }
+
+        void reset(long deadlineMs) {
+            this.deadlineMs.set(deadlineMs);
+        }
+
+        long remainingMs(long currentTimeMs) {
+            return Math.max(0, deadlineMs.get() - currentTimeMs);
+        }
+    }
+}
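To show how the pieces compose, here is a hypothetical end-to-end use of the accumulator with the `BatchMemoryPool` and `StringSerde` that appear later in this diff. The constructor arguments mirror the signature above; the concrete sizes are illustrative.

```java
import java.util.Arrays;

import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.raft.internals.BatchAccumulator;
import org.apache.kafka.raft.internals.BatchMemoryPool;
import org.apache.kafka.raft.internals.StringSerde;

public final class AccumulatorExample {
    public static void main(String[] args) {
        int epoch = 1;
        int lingerMs = 25;              // matches the quorum.append.linger.ms default
        int maxBatchSize = 1024 * 1024; // illustrative 1 MiB limit

        BatchAccumulator<String> accumulator = new BatchAccumulator<>(
            epoch,
            0L,                                   // base offset
            lingerMs,
            maxBatchSize,
            new BatchMemoryPool(5, maxBatchSize), // five pooled batch buffers
            Time.SYSTEM,
            CompressionType.NONE,
            new StringSerde()
        );

        // Both records land in the same batch; the returned offset is the
        // offset of the last record, here 1.
        Long lastOffset = accumulator.append(epoch, Arrays.asList("a", "b"));
        System.out.println("last offset = " + lastOffset);

        // Once lingerMs elapses, needsDrain() becomes true and drain() hands
        // the completed batches to the IO thread. close() drains and releases.
        accumulator.close();
    }
}
```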
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BatchBuilder.java b/raft/src/main/java/org/apache/kafka/raft/internals/BatchBuilder.java
new file mode 100644
index 00000000000..ea203f57730
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/BatchBuilder.java
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft.internals;
+
+import org.apache.kafka.common.protocol.DataOutputStreamWritable;
+import org.apache.kafka.common.protocol.Writable;
+import org.apache.kafka.common.record.AbstractRecords;
+import org.apache.kafka.common.record.CompressionType;
+import org.apache.kafka.common.record.DefaultRecord;
+import org.apache.kafka.common.record.DefaultRecordBatch;
+import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.record.TimestampType;
+import org.apache.kafka.common.utils.ByteBufferOutputStream;
+import org.apache.kafka.common.utils.ByteUtils;
+import org.apache.kafka.raft.RecordSerde;
+
+import java.io.DataOutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Collect a set of records into a single batch. New records are added
+ * through {@link #appendRecord(Object, Object)}, but the caller must first
+ * check whether there is room using {@link #hasRoomFor(int)}. Once the
+ * batch is ready, then {@link #build()} should be used to get the resulting
+ * {@link MemoryRecords} instance.
+ *
+ * @param <T> record type indicated by the {@link RecordSerde} passed in the constructor
+ */
+public class BatchBuilder<T> {
+    private final ByteBuffer initialBuffer;
+    private final CompressionType compressionType;
+    private final ByteBufferOutputStream batchOutput;
+    private final DataOutputStreamWritable recordOutput;
+    private final long baseOffset;
+    private final long logAppendTime;
+    private final boolean isControlBatch;
+    private final int leaderEpoch;
+    private final int initialPosition;
+    private final int maxBytes;
+    private final RecordSerde<T> serde;
+    private final List<T> records;
+
+    private long nextOffset;
+    private int unflushedBytes;
+    private boolean isOpenForAppends = true;
+
+    public BatchBuilder(
+        ByteBuffer buffer,
+        RecordSerde<T> serde,
+        CompressionType compressionType,
+        long baseOffset,
+        long logAppendTime,
+        boolean isControlBatch,
+        int leaderEpoch,
+        int maxBytes
+    ) {
+        this.initialBuffer = buffer;
+        this.batchOutput = new ByteBufferOutputStream(buffer);
+        this.serde = serde;
+        this.compressionType = compressionType;
+        this.baseOffset = baseOffset;
+        this.nextOffset = baseOffset;
+        this.logAppendTime = logAppendTime;
+        this.isControlBatch = isControlBatch;
+        this.initialPosition = batchOutput.position();
+        this.leaderEpoch = leaderEpoch;
+        this.maxBytes = maxBytes;
+        this.records = new ArrayList<>();
+
+        // Reserve space for the batch header; it is back-filled by build()
+        // once the final size and record count are known
+        int batchHeaderSizeInBytes = AbstractRecords.recordBatchHeaderSizeInBytes(
+            RecordBatch.MAGIC_VALUE_V2, compressionType);
+        batchOutput.position(initialPosition + batchHeaderSizeInBytes);
+
+        this.recordOutput = new DataOutputStreamWritable(new DataOutputStream(
+            compressionType.wrapForOutput(this.batchOutput, RecordBatch.MAGIC_VALUE_V2)));
+    }
+
+    /**
+     * Append a record to this batch. The caller must first verify there is room for the record
+     * using {@link #hasRoomFor(int)}.
+     *
+     * @param record the record to append
+     * @param serdeContext serialization context for use in {@link RecordSerde#write(Object, Object, Writable)}
+     * @return the offset of the appended record
+     */
+    public long appendRecord(T record, Object serdeContext) {
+        if (!isOpenForAppends) {
+            throw new IllegalArgumentException("Cannot append new records after the batch has been built");
+        }
+
+        if (nextOffset - baseOffset > Integer.MAX_VALUE) {
+            throw new IllegalArgumentException("Cannot include more than " + Integer.MAX_VALUE +
+                " records in a single batch");
+        }
+
+        long offset = nextOffset++;
+        int recordSizeInBytes = writeRecord(offset, record, serdeContext);
+        unflushedBytes += recordSizeInBytes;
+        records.add(record);
+        return offset;
+    }
+
+    /**
+     * Check whether the batch has enough room for a record of the given size in bytes.
+     *
+     * @param sizeInBytes the size of the record to be appended
+     * @return true if there is room for the record to be appended, false otherwise
+     */
+    public boolean hasRoomFor(int sizeInBytes) {
+        if (!isOpenForAppends) {
+            return false;
+        }
+
+        if (nextOffset - baseOffset >= Integer.MAX_VALUE) {
+            return false;
+        }
+
+        int recordSizeInBytes = DefaultRecord.sizeOfBodyInBytes(
+            (int) (nextOffset - baseOffset),
+            0,
+            -1,
+            sizeInBytes,
+            DefaultRecord.EMPTY_HEADERS
+        );
+
+        int unusedSizeInBytes = maxBytes - approximateSizeInBytes();
+        if (unusedSizeInBytes >= recordSizeInBytes) {
+            return true;
+        } else if (unflushedBytes > 0) {
+            // Flush the compression buffer to get an exact size before giving up
+            recordOutput.flush();
+            unflushedBytes = 0;
+            unusedSizeInBytes = maxBytes - flushedSizeInBytes();
+            return unusedSizeInBytes >= recordSizeInBytes;
+        } else {
+            return false;
+        }
+    }
+
+    private int flushedSizeInBytes() {
+        return batchOutput.position() - initialPosition;
+    }
+
+    /**
+     * Get an estimate of the current size of the appended data. This estimate
+     * is precise if no compression is in use.
+     *
+     * @return estimated size in bytes of the appended records
+     */
+    public int approximateSizeInBytes() {
+        return flushedSizeInBytes() + unflushedBytes;
+    }
+
+    /**
+     * Get the base offset of this batch. This is constant upon constructing
+     * the builder instance.
+     *
+     * @return the base offset
+     */
+    public long baseOffset() {
+        return baseOffset;
+    }
+
+    /**
+     * Return the offset of the last appended record. This is updated after
+     * every append and can be used after the batch has been built to obtain
+     * the last offset.
+     *
+     * @return the offset of the last appended record
+     */
+    public long lastOffset() {
+        return nextOffset - 1;
+    }
+
+    /**
+     * Get the number of records appended to the batch. This is updated after
+     * each append.
+     *
+     * @return the number of appended records
+     */
+    public int numRecords() {
+        return (int) (nextOffset - baseOffset);
+    }
+
+    /**
+     * Check whether there has been at least one record appended to the batch.
+     *
+     * @return true if one or more records have been appended
+     */
+    public boolean nonEmpty() {
+        return numRecords() > 0;
+    }
+
+    /**
+     * Return the reference to the initial buffer passed through the constructor.
+     * This is used in case the buffer needs to be returned to a pool (e.g. in
+     * {@link org.apache.kafka.common.memory.MemoryPool#release(ByteBuffer)}).
+     *
+     * @return the initial buffer passed to the constructor
+     */
+    public ByteBuffer initialBuffer() {
+        return initialBuffer;
+    }
+
+    /**
+     * Get a list of the records appended to the batch.
+     *
+     * @return a list of records
+     */
+    public List<T> records() {
+        return records;
+    }
+
+    private void writeDefaultBatchHeader() {
+        ByteBuffer buffer = batchOutput.buffer();
+        int lastPosition = buffer.position();
+
+        buffer.position(initialPosition);
+        int size = lastPosition - initialPosition;
+        int lastOffsetDelta = (int) (lastOffset() - baseOffset);
+
+        DefaultRecordBatch.writeHeader(
+            buffer,
+            baseOffset,
+            lastOffsetDelta,
+            size,
+            RecordBatch.MAGIC_VALUE_V2,
+            compressionType,
+            TimestampType.CREATE_TIME,
+            logAppendTime,
+            logAppendTime,
+            RecordBatch.NO_PRODUCER_ID,
+            RecordBatch.NO_PRODUCER_EPOCH,
+            RecordBatch.NO_SEQUENCE,
+            false,
+            isControlBatch,
+            leaderEpoch,
+            numRecords()
+        );
+
+        buffer.position(lastPosition);
+    }
+
+    public MemoryRecords build() {
+        recordOutput.close();
+        writeDefaultBatchHeader();
+        ByteBuffer buffer = batchOutput.buffer().duplicate();
+        buffer.flip();
+        buffer.position(initialPosition);
+        isOpenForAppends = false;
+        return MemoryRecords.readableRecords(buffer.slice());
+    }
+
+    private int writeRecord(
+        long offset,
+        T payload,
+        Object serdeContext
+    ) {
+        int offsetDelta = (int) (offset - baseOffset);
+        long timestampDelta = 0;
+
+        int payloadSize = serde.recordSize(payload, serdeContext);
+        int sizeInBytes = DefaultRecord.sizeOfBodyInBytes(
+            offsetDelta,
+            timestampDelta,
+            -1,
+            payloadSize,
+            DefaultRecord.EMPTY_HEADERS
+        );
+        recordOutput.writeVarint(sizeInBytes);
+
+        // Write attributes (currently unused)
+        recordOutput.writeByte((byte) 0);
+
+        // Write timestamp and offset
+        recordOutput.writeVarlong(timestampDelta);
+        recordOutput.writeVarint(offsetDelta);
+
+        // Write key, which is always null for controller messages
+        recordOutput.writeVarint(-1);
+
+        // Write value
+        recordOutput.writeVarint(payloadSize);
+        serde.write(payload, serdeContext, recordOutput);
+
+        // Write headers (currently unused)
+        recordOutput.writeVarint(0);
+        return ByteUtils.sizeOfVarint(sizeInBytes) + sizeInBytes;
+    }
+}
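A hypothetical sketch (not part of the patch) of the check-then-append protocol `BatchBuilder` expects from callers, using the `StringSerde` from later in this diff; the 1 KiB buffer is illustrative.

```java
import java.nio.ByteBuffer;

import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.raft.internals.BatchBuilder;
import org.apache.kafka.raft.internals.StringSerde;

public final class BatchBuilderExample {
    public static MemoryRecords buildSingletonBatch(String record) {
        StringSerde serde = new StringSerde();
        BatchBuilder<String> builder = new BatchBuilder<>(
            ByteBuffer.allocate(1024),      // illustrative buffer
            serde,
            CompressionType.NONE,
            0L,                             // base offset
            System.currentTimeMillis(),     // log append time
            false,                          // not a control batch
            RecordBatch.NO_PARTITION_LEADER_EPOCH,
            1024                            // max bytes
        );

        // The caller is responsible for the room check before each append
        if (builder.hasRoomFor(serde.recordSize(record))) {
            builder.appendRecord(record, null); // StringSerde needs no write context
        }
        return builder.build(); // back-fills the batch header and closes the builder
    }
}
```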
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BatchMemoryPool.java b/raft/src/main/java/org/apache/kafka/raft/internals/BatchMemoryPool.java
new file mode 100644
index 00000000000..5cd3e3316cb
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/BatchMemoryPool.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft.internals;
+
+import org.apache.kafka.common.memory.MemoryPool;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.concurrent.locks.ReentrantLock;
+
+/**
+ * Simple memory pool which maintains a limited number of fixed-size buffers.
+ */
+public class BatchMemoryPool implements MemoryPool {
+    private final ReentrantLock lock;
+    private final Deque<ByteBuffer> free;
+    private final int maxBatches;
+    private final int batchSize;
+
+    private int numAllocatedBatches = 0;
+
+    public BatchMemoryPool(int maxBatches, int batchSize) {
+        this.maxBatches = maxBatches;
+        this.batchSize = batchSize;
+        this.free = new ArrayDeque<>(maxBatches);
+        this.lock = new ReentrantLock();
+    }
+
+    @Override
+    public ByteBuffer tryAllocate(int sizeBytes) {
+        if (sizeBytes > batchSize) {
+            throw new IllegalArgumentException("Cannot allocate buffers larger than max " +
+                "batch size of " + batchSize);
+        }
+
+        lock.lock();
+        try {
+            ByteBuffer buffer = free.poll();
+            if (buffer == null && numAllocatedBatches < maxBatches) {
+                buffer = ByteBuffer.allocate(batchSize);
+                numAllocatedBatches += 1;
+            }
+            return buffer;
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    @Override
+    public void release(ByteBuffer previouslyAllocated) {
+        lock.lock();
+        try {
+            previouslyAllocated.clear();
+
+            if (previouslyAllocated.limit() != batchSize) {
+                throw new IllegalArgumentException("Released buffer with unexpected size " +
+                    previouslyAllocated.limit());
+            }
+
+            free.offer(previouslyAllocated);
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    @Override
+    public long size() {
+        lock.lock();
+        try {
+            return numAllocatedBatches * (long) batchSize;
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    @Override
+    public long availableMemory() {
+        lock.lock();
+        try {
+            int freeBatches = free.size() + (maxBatches - numAllocatedBatches);
+            return freeBatches * (long) batchSize;
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    @Override
+    public boolean isOutOfMemory() {
+        return availableMemory() == 0;
+    }
+}
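A small illustrative exercise (not in the patch) of the pool's contract: at most `maxBatches` buffers are outstanding, `tryAllocate` returns null instead of blocking when the pool is exhausted, and `release` makes a buffer reusable.

```java
import java.nio.ByteBuffer;

import org.apache.kafka.raft.internals.BatchMemoryPool;

public final class PoolExample {
    public static void main(String[] args) {
        BatchMemoryPool pool = new BatchMemoryPool(1, 1024); // one 1 KiB buffer

        ByteBuffer first = pool.tryAllocate(512);  // any request up to batchSize
        ByteBuffer second = pool.tryAllocate(512); // null: pool exhausted, no blocking
        System.out.println("second allocation: " + second);

        pool.release(first);
        System.out.println("after release: " + pool.tryAllocate(512)); // reused buffer
    }
}
```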
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java b/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java
new file mode 100644
index 00000000000..2a793167e42
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft.internals;
+
+import org.apache.kafka.common.protocol.Writable;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.raft.RecordSerde;
+
+public class StringSerde implements RecordSerde<String> {
+
+    @Override
+    public int recordSize(String data, Object context) {
+        return recordSize(data);
+    }
+
+    public int recordSize(String data) {
+        return Utils.utf8Length(data);
+    }
+
+    @Override
+    public void write(String data, Object context, Writable out) {
+        out.writeByteArray(Utils.utf8(data));
+    }
+}
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java index 80e92855961..40a4a760ec9 100644 --- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java @@ -17,7 +17,7 @@ package org.apache.kafka.raft; import org.apache.kafka.common.errors.ClusterAuthorizationException; -import org.apache.kafka.common.errors.NotLeaderOrFollowerException; +import org.apache.kafka.common.memory.MemoryPool; import org.apache.kafka.common.message.BeginQuorumEpochResponseData; import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState; import org.apache.kafka.common.message.EndQuorumEpochResponseData; @@ -40,6 +40,7 @@ import org.junit.jupiter.api.Test; import org.mockito.Mockito; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -50,7 +51,8 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeoutException; -import static org.apache.kafka.raft.RaftClientTestContext.Builder.ELECTION_TIMEOUT_MS; +import static java.util.Collections.singletonList; +import static org.apache.kafka.raft.RaftClientTestContext.Builder.DEFAULT_ELECTION_TIMEOUT_MS; import static org.apache.kafka.test.TestUtils.assertFutureThrows; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -61,6 +63,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class KafkaRaftClientTest { + @Test public void testInitializeSingleMemberQuorum() throws IOException { int localId = 0; @@ -94,12 +97,11 @@ public class KafkaRaftClientTest { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) .updateRandom(random -> { - Mockito.doReturn(0).when(random).nextInt(ELECTION_TIMEOUT_MS); + Mockito.doReturn(0).when(random).nextInt(DEFAULT_ELECTION_TIMEOUT_MS); }) .withElectedLeader(epoch, localId) .build(); - assertEquals(0L, context.log.endOffset().offset); context.assertUnknownLeader(epoch); @@ -162,7 +164,8 @@ public class KafkaRaftClientTest { Record record = batch.iterator().next(); assertEquals(electionTimestamp, record.timestamp()); - RaftClientTestContext.verifyLeaderChangeMessage(context.localId, Collections.singletonList(otherNodeId), record.key(), record.value()); + RaftClientTestContext.verifyLeaderChangeMessage(context.localId, + Collections.singletonList(otherNodeId), record.key(), record.value()); } @Test @@ -176,7 +179,6 @@ public class KafkaRaftClientTest { .withVotedCandidate(votedCandidateEpoch, otherNodeId) .build(); - context.deliverRequest(context.beginEpochRequest(votedCandidateEpoch, otherNodeId)); context.client.poll(); @@ -219,7 +221,8 @@ public class KafkaRaftClientTest { .withVotedCandidate(epoch,
localId) .build(); - context.deliverRequest(context.endEpochRequest(epoch, OptionalInt.empty(), otherNodeId, Collections.singletonList(context.localId))); + context.deliverRequest(context.endEpochRequest(epoch, OptionalInt.empty(), + otherNodeId, Collections.singletonList(context.localId))); context.client.poll(); context.assertSentEndQuorumEpochResponse(Errors.NONE, epoch, OptionalInt.empty()); @@ -274,7 +277,8 @@ public class KafkaRaftClientTest { .withVotedCandidate(epoch, otherNodeId) .build(); - context.deliverRequest(context.endEpochRequest(epoch, OptionalInt.empty(), otherNodeId, Collections.singletonList(context.localId))); + context.deliverRequest(context.endEpochRequest(epoch, OptionalInt.empty(), + otherNodeId, Collections.singletonList(context.localId))); context.client.poll(); context.assertSentEndQuorumEpochResponse(Errors.NONE, epoch, OptionalInt.empty()); @@ -311,7 +315,6 @@ public class KafkaRaftClientTest { int otherNodeId = 1; int epoch = 2; Set voters = Utils.mkSet(localId, otherNodeId); - RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); assertEquals(1L, context.log.endOffset().offset); @@ -323,20 +326,17 @@ public class KafkaRaftClientTest { assertEquals(OptionalLong.of(1L), context.client.highWatermark()); context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(context.localId)); - SimpleRecord[] records = new SimpleRecord[] { - new SimpleRecord("a".getBytes()), - new SimpleRecord("b".getBytes()) - }; - context.client.append(MemoryRecords.withRecords(CompressionType.NONE, records), AckMode.LEADER, Integer.MAX_VALUE); + String[] records = new String[] {"a", "b"}; + assertEquals(2L, context.client.scheduleAppend(epoch, Arrays.asList(records))); context.client.poll(); assertEquals(3L, context.log.endOffset().offset); assertEquals(3L, context.log.lastFlushedOffset()); assertEquals(OptionalLong.of(1L), context.client.highWatermark()); - context.validateLocalRead(new OffsetAndEpoch(1L, epoch), Isolation.COMMITTED, new SimpleRecord[0]); + context.validateLocalRead(new OffsetAndEpoch(1L, epoch), Isolation.COMMITTED, new String[0]); context.validateLocalRead(new OffsetAndEpoch(1L, epoch), Isolation.UNCOMMITTED, records); - context.validateLocalRead(new OffsetAndEpoch(3L, epoch), Isolation.COMMITTED, new SimpleRecord[0]); - context.validateLocalRead(new OffsetAndEpoch(3L, epoch), Isolation.UNCOMMITTED, new SimpleRecord[0]); + context.validateLocalRead(new OffsetAndEpoch(3L, epoch), Isolation.COMMITTED, new String[0]); + context.validateLocalRead(new OffsetAndEpoch(3L, epoch), Isolation.UNCOMMITTED, new String[0]); context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 3L, epoch, 0)); context.client.poll(); @@ -372,11 +372,8 @@ public class KafkaRaftClientTest { Isolation.COMMITTED, 500); assertFalse(logEndReadFuture.isDone()); - SimpleRecord[] records = new SimpleRecord[] { - new SimpleRecord("a".getBytes()), - new SimpleRecord("b".getBytes()) - }; - context.client.append(MemoryRecords.withRecords(CompressionType.NONE, records), AckMode.LEADER, Integer.MAX_VALUE); + String[] records = new String[] {"a", "b"}; + assertEquals(2L, context.client.scheduleAppend(epoch, Arrays.asList(records))); context.client.poll(); assertEquals(3L, context.log.endOffset().offset); assertEquals(OptionalLong.of(1L), context.client.highWatermark()); @@ -534,12 +531,8 @@ public class KafkaRaftClientTest { Set voters = Utils.mkSet(localId, otherNodeId); RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, 
voters, epoch); - SimpleRecord[] records = new SimpleRecord[] { - new SimpleRecord("a".getBytes()), - new SimpleRecord("b".getBytes()) - }; - context.client.append(MemoryRecords.withRecords(CompressionType.NONE, records), - AckMode.LEADER, Integer.MAX_VALUE); + String[] records = new String[] {"a", "b"}; + assertEquals(2L, context.client.scheduleAppend(epoch, Arrays.asList(records))); context.client.poll(); assertEquals(3L, context.log.endOffset().offset); @@ -565,6 +558,93 @@ public class KafkaRaftClientTest { assertFutureThrows(future, LogTruncationException.class); } + @Test + public void testAccumulatorClearedAfterBecomingFollower() throws Exception { + int localId = 0; + int otherNodeId = 1; + int lingerMs = 50; + Set voters = Utils.mkSet(localId, otherNodeId); + + MemoryPool memoryPool = Mockito.mock(MemoryPool.class); + ByteBuffer buffer = ByteBuffer.allocate(KafkaRaftClient.MAX_BATCH_SIZE); + Mockito.when(memoryPool.tryAllocate(KafkaRaftClient.MAX_BATCH_SIZE)) + .thenReturn(buffer); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withAppendLingerMs(lingerMs) + .withMemoryPool(memoryPool) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.client.currentLeaderAndEpoch().leaderId); + int epoch = context.client.currentLeaderAndEpoch().epoch; + + assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); + context.deliverRequest(context.beginEpochRequest(epoch + 1, otherNodeId)); + context.client.poll(); + + context.assertElectedLeader(epoch + 1, otherNodeId); + Mockito.verify(memoryPool).release(buffer); + } + + @Test + public void testAccumulatorClearedAfterBecomingVoted() throws Exception { + int localId = 0; + int otherNodeId = 1; + int lingerMs = 50; + Set voters = Utils.mkSet(localId, otherNodeId); + + MemoryPool memoryPool = Mockito.mock(MemoryPool.class); + ByteBuffer buffer = ByteBuffer.allocate(KafkaRaftClient.MAX_BATCH_SIZE); + Mockito.when(memoryPool.tryAllocate(KafkaRaftClient.MAX_BATCH_SIZE)) + .thenReturn(buffer); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withAppendLingerMs(lingerMs) + .withMemoryPool(memoryPool) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.client.currentLeaderAndEpoch().leaderId); + int epoch = context.client.currentLeaderAndEpoch().epoch; + + assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); + context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, + context.log.endOffset().offset)); + context.client.poll(); + + context.assertVotedCandidate(epoch + 1, otherNodeId); + Mockito.verify(memoryPool).release(buffer); + } + + @Test + public void testAccumulatorClearedAfterBecomingUnattached() throws Exception { + int localId = 0; + int otherNodeId = 1; + int lingerMs = 50; + Set voters = Utils.mkSet(localId, otherNodeId); + + MemoryPool memoryPool = Mockito.mock(MemoryPool.class); + ByteBuffer buffer = ByteBuffer.allocate(KafkaRaftClient.MAX_BATCH_SIZE); + Mockito.when(memoryPool.tryAllocate(KafkaRaftClient.MAX_BATCH_SIZE)) + .thenReturn(buffer); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withAppendLingerMs(lingerMs) + .withMemoryPool(memoryPool) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.client.currentLeaderAndEpoch().leaderId); + int epoch = context.client.currentLeaderAndEpoch().epoch; + + assertEquals(1L, 
context.client.scheduleAppend(epoch, singletonList("a"))); + context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 0L)); + context.client.poll(); + + context.assertUnknownLeader(epoch + 1); + Mockito.verify(memoryPool).release(buffer); + } @Test public void testHandleEndQuorumRequest() throws Exception { @@ -630,7 +710,6 @@ public class KafkaRaftClientTest { Set voters = Utils.mkSet(localId, otherNodeId); RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); - context.assertUnknownLeader(0); context.time.sleep(2 * context.electionTimeoutMs); @@ -773,7 +852,7 @@ public class KafkaRaftClientTest { } @Test - public void testStateMachineApplyCommittedRecords() throws Exception { + public void testListenerCommitCallbackAfterLeaderWrite() throws Exception { int localId = 0; int otherNodeId = 1; int epoch = 5; @@ -791,95 +870,22 @@ public class KafkaRaftClientTest { context.pollUntilSend(); assertEquals(OptionalLong.of(0L), context.client.highWatermark()); - // Append some records with leader commit mode - SimpleRecord[] appendRecords = new SimpleRecord[] { - new SimpleRecord("a".getBytes()), - new SimpleRecord("b".getBytes()), - new SimpleRecord("c".getBytes()) - }; - Records records = MemoryRecords.withRecords(0L, CompressionType.NONE, 1, appendRecords); - CompletableFuture future = context.client.append(records, - AckMode.LEADER, Integer.MAX_VALUE); - + List records = Arrays.asList("a", "b", "c"); + long offset = context.client.scheduleAppend(epoch, records); context.client.poll(); - assertTrue(future.isDone()); - assertEquals(new OffsetAndEpoch(3, epoch), future.get()); + assertTrue(context.listener.commits.isEmpty()); // Let follower send a fetch, it should advance the high watermark context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 1L, epoch, 500)); context.pollUntilSend(); assertEquals(OptionalLong.of(1L), context.client.highWatermark()); + assertTrue(context.listener.commits.isEmpty()); // Let the follower to send another fetch from offset 4 context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 4L, epoch, 500)); context.client.poll(); assertEquals(OptionalLong.of(4L), context.client.highWatermark()); - - // Append more records with quorum commit mode - appendRecords = new SimpleRecord[] { - new SimpleRecord("a".getBytes()), - new SimpleRecord("b".getBytes()), - new SimpleRecord("c".getBytes()) - }; - records = MemoryRecords.withRecords(0L, CompressionType.NONE, 1, appendRecords); - future = context.client.append(records, AckMode.QUORUM, Integer.MAX_VALUE); - - // Appending locally should not complete the future - context.client.poll(); - assertFalse(future.isDone()); - - // Let follower send a fetch, it should not yet advance the high watermark - context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 4L, epoch, 500)); - context.pollUntilSend(); - assertFalse(future.isDone()); - assertEquals(OptionalLong.of(4L), context.client.highWatermark()); - - // Let the follower to send another fetch at 7, which should not advance the high watermark and complete the future - context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 7L, epoch, 500)); - context.client.poll(); - assertEquals(OptionalLong.of(7L), context.client.highWatermark()); - - assertTrue(future.isDone()); - assertEquals(new OffsetAndEpoch(6, epoch), future.get()); - } - - @Test - public void testStateMachineExpireAppendedRecords() throws Exception { - int localId = 0; - int otherNodeId = 1; - int epoch = 5; - Set voters = 
Utils.mkSet(localId, otherNodeId); - - RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); - - // First poll has no high watermark advance - context.client.poll(); - assertEquals(OptionalLong.empty(), context.client.highWatermark()); - - // Let follower send a fetch to initialize the high watermark, - // note the offset 0 would be a control message for becoming the leader - context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 0L, epoch, 500)); - context.pollUntilSend(); - assertEquals(OptionalLong.of(0L), context.client.highWatermark()); - - // Append some records with quorum commit mode - SimpleRecord[] appendRecords = new SimpleRecord[] { - new SimpleRecord("a".getBytes()), - new SimpleRecord("b".getBytes()), - new SimpleRecord("c".getBytes()) - }; - - Records records = MemoryRecords.withRecords(0L, CompressionType.NONE, 1, appendRecords); - CompletableFuture future = context.client.append(records, AckMode.QUORUM, context.requestTimeoutMs); - - context.client.poll(); - assertFalse(future.isDone()); - - context.time.sleep(context.requestTimeoutMs - 1); - assertFalse(future.isDone()); - - context.time.sleep(1); - assertTrue(future.isCompletedExceptionally()); + assertEquals(records, context.listener.commits.get(new OffsetAndEpoch(offset, epoch))); } @Test @@ -915,7 +921,6 @@ public class KafkaRaftClientTest { }) .build(); - context.assertUnknownLeader(0); context.time.sleep(2 * context.electionTimeoutMs); @@ -978,7 +983,6 @@ public class KafkaRaftClientTest { }) .build(); - context.assertElectedLeader(epoch, otherNodeId); context.pollUntilSend(); @@ -999,7 +1003,6 @@ public class KafkaRaftClientTest { log.appendAsLeader(Collections.singleton(new SimpleRecord("foo".getBytes())), lastEpoch); }) .build(); - context.assertElectedLeader(epoch, otherNodeId); context.pollUntilSend(); @@ -1165,7 +1168,6 @@ public class KafkaRaftClientTest { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) .withElectedLeader(epoch, otherNodeId) .build(); - context.assertElectedLeader(epoch, otherNodeId); context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, 0, -5L)); @@ -1221,22 +1223,12 @@ public class KafkaRaftClientTest { assertEquals(0, context.channel.drainSendQueue().size()); // Append some records that can fulfill the Fetch request - SimpleRecord[] appendRecords = new SimpleRecord[] { - new SimpleRecord("a".getBytes()), - new SimpleRecord("b".getBytes()), - new SimpleRecord("c".getBytes()) - }; - Records records = MemoryRecords.withRecords(0L, CompressionType.NONE, 1, appendRecords); - CompletableFuture future = context.client.append(records, AckMode.LEADER, Integer.MAX_VALUE); + String[] appendRecords = new String[] {"a", "b", "c"}; + context.client.scheduleAppend(epoch, Arrays.asList(appendRecords)); context.client.poll(); - assertTrue(future.isDone()); - MemoryRecords fetchedRecords = context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(context.localId)); - List recordList = Utils.toList(fetchedRecords.records()); - assertEquals(appendRecords.length, recordList.size()); - for (int i = 0; i < appendRecords.length; i++) { - assertEquals(appendRecords[i], new SimpleRecord(recordList.get(i))); - } + MemoryRecords fetchedRecords = context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + RaftClientTestContext.assertMatchingRecords(appendRecords, fetchedRecords); } @Test @@ -1270,7 +1262,6 @@ public class KafkaRaftClientTest { assertEquals(0, fetchedRecords.sizeInBytes()); 
} - @Test public void testFetchResponseIgnoredAfterBecomingCandidate() throws Exception { int localId = 0; @@ -1282,7 +1273,6 @@ public class KafkaRaftClientTest { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) .withElectedLeader(epoch, otherNodeId) .build(); - context.assertElectedLeader(epoch, otherNodeId); // Wait until we have a Fetch inflight to the leader @@ -1318,7 +1308,6 @@ public class KafkaRaftClientTest { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) .withElectedLeader(epoch, voter2) .build(); - context.assertElectedLeader(epoch, voter2); // Wait until we have a Fetch inflight to the leader @@ -1354,8 +1343,6 @@ public class KafkaRaftClientTest { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) .withVotedCandidate(epoch, voter1) .build(); - - context.assertVotedCandidate(epoch, voter1); // Wait until the vote requests are inflight @@ -1558,7 +1545,7 @@ public class KafkaRaftClientTest { new ReplicaState() .setReplicaId(closeFollower) .setLogEndOffset(1L)), - Collections.singletonList( + singletonList( new ReplicaState() .setReplicaId(observerId) .setLogEndOffset(0L))); @@ -1606,7 +1593,6 @@ public class KafkaRaftClientTest { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) .withElectedLeader(epoch, otherNodeId) .build(); - context.assertElectedLeader(epoch, otherNodeId); context.client.poll(); @@ -1647,7 +1633,6 @@ public class KafkaRaftClientTest { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) .withElectedLeader(epoch, otherNodeId) .build(); - context.assertElectedLeader(epoch, otherNodeId); context.pollUntilSend(); @@ -1673,7 +1658,6 @@ public class KafkaRaftClientTest { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) .withElectedLeader(epoch, otherNodeId) .build(); - context.assertElectedLeader(epoch, otherNodeId); // Receive an empty fetch response @@ -1709,46 +1693,6 @@ public class KafkaRaftClientTest { assertEquals(OptionalLong.of(2L), context.client.highWatermark()); } - @Test - public void testAppendEmptyRecordSetNotAllowed() throws Exception { - int localId = 0; - int epoch = 5; - Set voters = Collections.singleton(localId); - - RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) - .withElectedLeader(epoch, localId) - .build(); - - assertThrows(IllegalArgumentException.class, () -> - context.client.append(MemoryRecords.EMPTY, AckMode.LEADER, Integer.MAX_VALUE)); - } - - @Test - public void testAppendToNonLeaderFails() throws IOException { - int localId = 0; - int otherNodeId = 1; - int epoch = 5; - Set voters = Utils.mkSet(localId, otherNodeId); - - RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) - .withElectedLeader(epoch, otherNodeId) - .build(); - - context.assertElectedLeader(epoch, otherNodeId); - - SimpleRecord[] appendRecords = new SimpleRecord[] { - new SimpleRecord("a".getBytes()), - new SimpleRecord("b".getBytes()), - new SimpleRecord("c".getBytes()) - }; - Records records = MemoryRecords.withRecords(0L, CompressionType.NONE, 1, appendRecords); - - CompletableFuture future = context.client.append(records, AckMode.LEADER, Integer.MAX_VALUE); - context.client.poll(); - - assertFutureThrows(future, NotLeaderOrFollowerException.class); - } - @Test public void testFetchShouldBeTreatedAsLeaderEndorsement() throws Exception { int localId = 0; @@ -1758,13 +1702,13 @@ public class 
KafkaRaftClientTest { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) .updateRandom(random -> { - Mockito.doReturn(0).when(random).nextInt(ELECTION_TIMEOUT_MS); + Mockito.doReturn(0).when(random).nextInt(DEFAULT_ELECTION_TIMEOUT_MS); }) .withUnknownLeader(epoch - 1) .build(); context.time.sleep(context.electionTimeoutMs); - context.expectLeaderElection(epoch); + context.expectAndGrantVotes(epoch); context.pollUntilSend(); @@ -1799,18 +1743,13 @@ public class KafkaRaftClientTest { // We still write the leader change message assertEquals(OptionalLong.of(1L), context.client.highWatermark()); - SimpleRecord[] appendRecords = new SimpleRecord[] { - new SimpleRecord("a".getBytes()), - new SimpleRecord("b".getBytes()), - new SimpleRecord("c".getBytes()) - }; - Records records = MemoryRecords.withRecords(1L, CompressionType.NONE, 1, appendRecords); + String[] appendRecords = new String[] {"a", "b", "c"}; // First poll has no high watermark advance context.client.poll(); assertEquals(OptionalLong.of(1L), context.client.highWatermark()); - context.client.append(records, AckMode.LEADER, Integer.MAX_VALUE); + context.client.scheduleAppend(context.client.currentLeaderAndEpoch().epoch, Arrays.asList(appendRecords)); // Then poll the appended data with leader change record context.client.poll(); @@ -1842,7 +1781,7 @@ public class KafkaRaftClientTest { assertEquals(3, readRecords.size()); for (int i = 0; i < appendRecords.length; i++) { - assertEquals(appendRecords[i].value(), readRecords.get(i).value()); + assertEquals(appendRecords[i], Utils.utf8(readRecords.get(i).value())); } } @@ -1889,8 +1828,8 @@ public class KafkaRaftClientTest { public void testMetrics() throws Exception { int localId = 0; int epoch = 1; - - RaftClientTestContext context = new RaftClientTestContext.Builder(localId, Collections.singleton(localId)).build(); + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, Collections.singleton(localId)) + .build(); assertNotNull(getMetric(context.metrics, "current-state")); assertNotNull(getMetric(context.metrics, "current-leader")); @@ -1916,13 +1855,7 @@ public class KafkaRaftClientTest { assertEquals((double) 1L, getMetric(context.metrics, "log-end-offset").metricValue()); assertEquals((double) epoch, getMetric(context.metrics, "log-end-epoch").metricValue()); - SimpleRecord[] appendRecords = new SimpleRecord[] { - new SimpleRecord("a".getBytes()), - new SimpleRecord("b".getBytes()), - new SimpleRecord("c".getBytes()) - }; - Records records = MemoryRecords.withRecords(0L, CompressionType.NONE, 1, appendRecords); - context.client.append(records, AckMode.LEADER, Integer.MAX_VALUE); + context.client.scheduleAppend(epoch, Arrays.asList("a", "b", "c")); context.client.poll(); assertEquals((double) 4L, getMetric(context.metrics, "high-watermark").metricValue()); @@ -1969,14 +1902,13 @@ public class KafkaRaftClientTest { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) .updateRandom(random -> { - Mockito.doReturn(0).when(random).nextInt(ELECTION_TIMEOUT_MS); + Mockito.doReturn(0).when(random).nextInt(DEFAULT_ELECTION_TIMEOUT_MS); }) .withUnknownLeader(epoch - 1) .build(); - context.time.sleep(context.electionTimeoutMs); - context.expectLeaderElection(epoch); + context.expectAndGrantVotes(epoch); context.pollUntilSend(); int correlationId = context.assertSentBeginQuorumEpochRequest(epoch); @@ -2032,4 +1964,5 @@ public class KafkaRaftClientTest { private static KafkaMetric getMetric(final Metrics metrics, 
final String name) { return metrics.metrics().get(metrics.metricName(name, "raft-metrics")); } + } diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLog.java b/raft/src/test/java/org/apache/kafka/raft/MockLog.java index 3b76aafeb73..57c519e001d 100644 --- a/raft/src/test/java/org/apache/kafka/raft/MockLog.java +++ b/raft/src/test/java/org/apache/kafka/raft/MockLog.java @@ -26,6 +26,7 @@ import org.apache.kafka.common.record.RecordBatch; import org.apache.kafka.common.record.Records; import org.apache.kafka.common.record.SimpleRecord; import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.Utils; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -35,17 +36,18 @@ import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.OptionalLong; -import java.util.UUID; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; import java.util.stream.Collectors; public class MockLog implements ReplicatedLog { + private static final AtomicLong ID_GENERATOR = new AtomicLong(); + private final List epochStartOffsets = new ArrayList<>(); private final List log = new ArrayList<>(); private final TopicPartition topicPartition; - private UUID nextId = UUID.randomUUID(); + private long nextId = ID_GENERATOR.getAndIncrement(); private LogOffsetMetadata highWatermark = new LogOffsetMetadata(0L, Optional.empty()); private long lastFlushedOffset = 0L; @@ -109,11 +111,11 @@ public class MockLog implements ReplicatedLog { return; } - UUID id = ((MockOffsetMetadata) offsetMetadata.metadata.get()).id; + long id = ((MockOffsetMetadata) offsetMetadata.metadata.get()).id; long offset = offsetMetadata.offset; metadataForOffset(offset).ifPresent(metadata -> { - UUID entryId = ((MockOffsetMetadata) metadata).id; + long entryId = ((MockOffsetMetadata) metadata).id; if (entryId != id) { throw new IllegalArgumentException("High watermark " + offset + " metadata uuid " + id + " does not match the " + @@ -180,14 +182,26 @@ public class MockLog implements ReplicatedLog { List entries = new ArrayList<>(); for (Record record : batch) { long offset = offsetSupplier.apply(record); - entries.add(buildEntry(offset, new SimpleRecord(record))); + long timestamp = record.timestamp(); + ByteBuffer key = copy(record.key()); + ByteBuffer value = copy(record.value()); + entries.add(buildEntry(offset, new SimpleRecord(timestamp, key, value))); } return entries; } + private ByteBuffer copy(ByteBuffer nullableByteBuffer) { + if (nullableByteBuffer == null) { + return null; + } else { + byte[] array = Utils.toArray(nullableByteBuffer, nullableByteBuffer.position(), nullableByteBuffer.limit()); + return ByteBuffer.wrap(array); + } + } + private LogEntry buildEntry(Long offset, SimpleRecord record) { - UUID id = nextId; - nextId = UUID.randomUUID(); + long id = nextId; + nextId = ID_GENERATOR.getAndIncrement(); return new LogEntry(new MockOffsetMetadata(id), offset, record); } @@ -338,15 +352,17 @@ public class MockLog implements ReplicatedLog { } static class MockOffsetMetadata implements OffsetMetadata { - final UUID id; + final long id; - MockOffsetMetadata(UUID id) { + MockOffsetMetadata(long id) { this.id = id; } @Override public String toString() { - return id.toString(); + return "MockOffsetMetadata(" + + "id=" + id + + ')'; } @Override @@ -354,7 +370,7 @@ public class MockLog implements ReplicatedLog { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; MockOffsetMetadata that = 
(MockOffsetMetadata) o; - return Objects.equals(id, that.id); + return id == that.id; } @Override diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java index 285c805109f..6ed14202ef7 100644 --- a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java @@ -37,7 +37,6 @@ import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.UUID; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -315,7 +314,7 @@ public class MockLogTest { // Now update to a high watermark with invalid metadata assertThrows(IllegalArgumentException.class, () -> log.updateHighWatermark(new LogOffsetMetadata(10L, - Optional.of(new MockLog.MockOffsetMetadata(UUID.randomUUID()))))); + Optional.of(new MockLog.MockOffsetMetadata(98230980L))))); // Ensure we can update the high watermark to the end offset LogFetchInfo readFromEndInfo = log.read(15L, Isolation.UNCOMMITTED); diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java index 970bc3519ed..97866722ad6 100644 --- a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java +++ b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java @@ -20,8 +20,10 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -33,6 +35,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.memory.MemoryPool; import org.apache.kafka.common.message.BeginQuorumEpochRequestData; import org.apache.kafka.common.message.BeginQuorumEpochResponseData; import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState; @@ -65,6 +68,7 @@ import org.apache.kafka.common.requests.VoteResponse; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.internals.StringSerde; import org.apache.kafka.test.TestUtils; import org.mockito.Mockito; import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition; @@ -74,25 +78,25 @@ import static org.junit.jupiter.api.Assertions.assertTrue; final class RaftClientTestContext { final TopicPartition metadataPartition = Builder.METADATA_PARTITION; final int electionBackoffMaxMs = Builder.ELECTION_BACKOFF_MAX_MS; - final int electionTimeoutMs = Builder.ELECTION_TIMEOUT_MS; + final int electionTimeoutMs = Builder.DEFAULT_ELECTION_TIMEOUT_MS; final int electionFetchMaxWaitMs = Builder.FETCH_MAX_WAIT_MS; final int fetchTimeoutMs = Builder.FETCH_TIMEOUT_MS; final int requestTimeoutMs = Builder.REQUEST_TIMEOUT_MS; final int retryBackoffMs = Builder.RETRY_BACKOFF_MS; private final QuorumStateStore quorumStateStore; - private final Random random; - final int localId; - final KafkaRaftClient client; + final KafkaRaftClient client; final Metrics metrics; final MockLog log; final MockNetworkChannel channel; final MockTime time; + final MockListener listener; + final Set voters; 
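The fields above show the shape of the new test harness: the context owns a typed KafkaRaftClient (parameterized over String in these tests) plus the MockListener through which committed records are observed. A minimal sketch of that listener contract, reconstructed from the MockListener class defined near the end of this file (the Listener type parameter and the handleCommit signature are inferred from it; CommitCollector itself is illustrative, not part of the patch):

    import java.util.ArrayList;
    import java.util.List;
    import org.apache.kafka.raft.RaftClient;

    // Collects values in commit order, the way MockListener does below.
    class CommitCollector implements RaftClient.Listener<String> {
        private final List<String> committed = new ArrayList<>();

        @Override
        public void handleCommit(int epoch, long lastOffset, List<String> records) {
            // Values arrive already deserialized by the client's RecordSerde,
            // so tests assert on plain strings rather than raw MemoryRecords.
            committed.addAll(records);
        }
    }

The Builder below wires its listener into the client via client.initialize(listener).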
    public static final class Builder {
-        static final int ELECTION_TIMEOUT_MS = 10000;
+        static final int DEFAULT_ELECTION_TIMEOUT_MS = 10000;
         private static final TopicPartition METADATA_PARTITION = new TopicPartition("metadata", 0);
         private static final int ELECTION_BACKOFF_MAX_MS = 100;
@@ -101,6 +105,7 @@ final class RaftClientTestContext {
         private static final int FETCH_TIMEOUT_MS = 50000;
         private static final int REQUEST_TIMEOUT_MS = 5000;
         private static final int RETRY_BACKOFF_MS = 50;
+        private static final int DEFAULT_APPEND_LINGER_MS = 0;
 
         private final QuorumStateStore quorumStateStore = new MockQuorumStateStore();
         private final Random random = Mockito.spy(new Random(1));
@@ -108,6 +113,10 @@
         private final Set<Integer> voters;
         private final int localId;
 
+        private int electionTimeoutMs = DEFAULT_ELECTION_TIMEOUT_MS;
+        private int appendLingerMs = DEFAULT_APPEND_LINGER_MS;
+        private MemoryPool memoryPool = MemoryPool.NONE;
+
         Builder(int localId, Set<Integer> voters) {
             this.voters = voters;
             this.localId = localId;
@@ -133,6 +142,16 @@
             return this;
         }
 
+        Builder withMemoryPool(MemoryPool pool) {
+            this.memoryPool = pool;
+            return this;
+        }
+
+        Builder withAppendLingerMs(int appendLingerMs) {
+            this.appendLingerMs = appendLingerMs;
+            return this;
+        }
+
         Builder updateLog(Consumer<MockLog> consumer) {
             consumer.accept(log);
             return this;
         }
@@ -143,8 +162,9 @@
             Metrics metrics = new Metrics(time);
             MockNetworkChannel channel = new MockNetworkChannel();
             LogContext logContext = new LogContext();
-            QuorumState quorum = new QuorumState(localId, voters, ELECTION_TIMEOUT_MS, FETCH_TIMEOUT_MS,
+            QuorumState quorum = new QuorumState(localId, voters, electionTimeoutMs, FETCH_TIMEOUT_MS,
                 quorumStateStore, time, logContext, random);
+            MockListener listener = new MockListener();
 
             Map<Integer, InetSocketAddress> voterAddresses = voters
                 .stream()
@@ -153,26 +173,52 @@
                     RaftClientTestContext::mockAddress
                 ));
 
-            KafkaRaftClient client = new KafkaRaftClient(channel, log, quorum, time, metrics,
-                new MockFuturePurgatory<>(time), new MockFuturePurgatory<>(time), voterAddresses,
-                ELECTION_BACKOFF_MAX_MS, RETRY_BACKOFF_MS, REQUEST_TIMEOUT_MS, FETCH_MAX_WAIT_MS, logContext, random);
+            KafkaRaftClient<String> client = new KafkaRaftClient<>(
+                new StringSerde(),
+                channel,
+                log,
+                quorum,
+                memoryPool,
+                time,
+                metrics,
+                new MockFuturePurgatory<>(time),
+                new MockFuturePurgatory<>(time),
+                voterAddresses,
+                ELECTION_BACKOFF_MAX_MS,
+                RETRY_BACKOFF_MS,
+                REQUEST_TIMEOUT_MS,
+                FETCH_MAX_WAIT_MS,
+                appendLingerMs,
+                logContext,
+                random
+            );
 
-            client.initialize();
+            client.initialize(listener);
 
-            return new RaftClientTestContext(localId, client, log, channel, time, quorumStateStore, voters, random, metrics);
+            return new RaftClientTestContext(
+                localId,
+                client,
+                log,
+                channel,
+                time,
+                quorumStateStore,
+                voters,
+                metrics,
+                listener
+            );
         }
     }
 
     private RaftClientTestContext(
         int localId,
-        KafkaRaftClient client,
+        KafkaRaftClient<String> client,
         MockLog log,
         MockNetworkChannel channel,
         MockTime time,
         QuorumStateStore quorumStateStore,
         Set<Integer> voters,
-        Random random,
-        Metrics metrics
+        Metrics metrics,
+        MockListener listener
     ) {
         this.localId = localId;
         this.client = client;
@@ -181,8 +227,8 @@
         this.time = time;
         this.quorumStateStore = quorumStateStore;
         this.voters = voters;
-        this.random = random;
         this.metrics = metrics;
+        this.listener = listener;
     }
 
     static RaftClientTestContext initializeAsLeader(int localId,
Set voters, int epoch) throws Exception { @@ -191,30 +237,22 @@ final class RaftClientTestContext { } RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) - .updateRandom(random -> { - Mockito.doReturn(0).when(random).nextInt(Builder.ELECTION_TIMEOUT_MS); - }) .withUnknownLeader(epoch - 1) .build(); context.assertUnknownLeader(epoch - 1); - - // Advance the clock so that we become a candidate - context.time.sleep(context.electionTimeoutMs); - context.expectLeaderElection(epoch); - - // Handle BeginEpoch - context.pollUntilSend(); - for (RaftRequest.Outbound request : context.collectBeginEpochRequests(epoch)) { - BeginQuorumEpochResponseData beginEpochResponse = context.beginEpochResponse(epoch, localId); - context.deliverResponse(request.correlationId, request.destinationId(), beginEpochResponse); - } - - context.client.poll(); + context.becomeLeader(); return context; } - void expectLeaderElection( + void becomeLeader() throws Exception { + int currentEpoch = client.currentLeaderAndEpoch().epoch; + time.sleep(electionTimeoutMs * 2); + expectAndGrantVotes(currentEpoch + 1); + expectBeginEpoch(currentEpoch + 1); + } + + void expectAndGrantVotes( int epoch ) throws Exception { pollUntilSend(); @@ -231,6 +269,17 @@ final class RaftClientTestContext { assertElectedLeader(epoch, localId); } + void expectBeginEpoch( + int epoch + ) throws Exception { + pollUntilSend(); + for (RaftRequest.Outbound request : collectBeginEpochRequests(epoch)) { + BeginQuorumEpochResponseData beginEpochResponse = beginEpochResponse(epoch, localId); + deliverResponse(request.correlationId, request.destinationId(), beginEpochResponse); + } + client.poll(); + } + void pollUntilSend() throws InterruptedException { TestUtils.waitForCondition(() -> { client.poll(); @@ -311,9 +360,7 @@ final class RaftClientTestContext { for (RaftMessage raftMessage : channel.drainSendQueue()) { if (raftMessage.data() instanceof VoteRequestData) { VoteRequestData request = (VoteRequestData) raftMessage.data(); - assertTrue(hasValidTopicPartition(request, metadataPartition)); - - VoteRequestData.PartitionData partitionRequest = request.topics().get(0).partitions().get(0); + VoteRequestData.PartitionData partitionRequest = unwrap(request); assertEquals(epoch, partitionRequest.candidateEpoch()); assertEquals(localId, partitionRequest.candidateId()); @@ -428,6 +475,16 @@ final class RaftClientTestContext { return (MemoryRecords) partitionResponse.recordSet(); } + void validateLocalRead( + OffsetAndEpoch fetchOffsetAndEpoch, + Isolation isolation, + String[] expectedRecords + ) throws Exception { + CompletableFuture future = client.read(fetchOffsetAndEpoch, isolation, 0L); + assertTrue(future.isDone()); + assertMatchingRecords(expectedRecords, future.get()); + } + void validateLocalRead( OffsetAndEpoch fetchOffsetAndEpoch, Isolation isolation, @@ -469,10 +526,7 @@ final class RaftClientTestContext { assertSentFetchResponse(0L, epoch); // Append some records, so that the close follower will be able to advance further. 
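The removal/addition pair just below is the recurring shape of this migration: callers stop assembling MemoryRecords by hand and instead hand typed values to the client, which batches and serializes them through its RecordSerde. Side by side, as a sketch (assumes a leader client of type KafkaRaftClient<String> in scope; the epoch lookup is borrowed from the KafkaRaftClientTest hunk earlier in this diff):

    // Before: raw batches plus an explicit ack mode and timeout.
    MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE,
        new SimpleRecord("foo".getBytes()),
        new SimpleRecord("bar".getBytes()));
    client.append(records, AckMode.LEADER, Integer.MAX_VALUE);

    // After: typed values plus the leader epoch as a fencing token. The append
    // is only scheduled here; a subsequent poll() drains the accumulator and
    // writes the serialized batch to the log.
    int epoch = client.currentLeaderAndEpoch().epoch;
    client.scheduleAppend(epoch, Arrays.asList("foo", "bar"));
    client.poll();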
-        MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE,
-            new SimpleRecord("foo".getBytes()),
-            new SimpleRecord("bar".getBytes()));
-        client.append(records, AckMode.LEADER, Integer.MAX_VALUE);
+        client.scheduleAppend(epoch, Arrays.asList("foo", "bar"));
         client.poll();
 
         deliverRequest(fetchRequest(epoch, closeFollower, 1L, epoch, 0));
@@ -608,6 +662,24 @@
         );
     }
 
+    private VoteRequestData.PartitionData unwrap(VoteRequestData voteRequest) {
+        assertTrue(RaftUtil.hasValidTopicPartition(voteRequest, metadataPartition));
+        return voteRequest.topics().get(0).partitions().get(0);
+    }
+
+    static void assertMatchingRecords(
+        String[] expected,
+        Records actual
+    ) {
+        List<Record> recordList = Utils.toList(actual.records());
+        assertEquals(expected.length, recordList.size());
+        for (int i = 0; i < expected.length; i++) {
+            Record record = recordList.get(i);
+            assertEquals(expected[i], Utils.utf8(record.value()),
+                "Record at offset " + record.offset() + " does not match expected");
+        }
+    }
+
     static void assertMatchingRecords(
         SimpleRecord[] expected,
         Records actual
@@ -715,4 +787,19 @@
                 .setEndOffset(divergingEpochEndOffset);
         });
     }
+
+    static class MockListener implements RaftClient.Listener<String> {
+        final LinkedHashMap<OffsetAndEpoch, List<String>> commits = new LinkedHashMap<>();
+
+        @Override
+        public void handleCommit(int epoch, long lastOffset, List<String> records) {
+            OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(lastOffset, epoch);
+            if (commits.containsKey(offsetAndEpoch)) {
+                throw new AssertionError("Found duplicate append at " + offsetAndEpoch);
+            }
+            commits.put(offsetAndEpoch, records);
+        }
+
+    }
+
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
index 6336663d1aa..8f18ade1fef 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
@@ -17,7 +17,9 @@
 package org.apache.kafka.raft;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.memory.MemoryPool;
 import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.protocol.Writable;
 import org.apache.kafka.common.protocol.types.Type;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
@@ -25,7 +27,9 @@ import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.raft.MockLog.LogBatch;
 import org.apache.kafka.raft.MockLog.LogEntry;
+import org.apache.kafka.raft.internals.BatchMemoryPool;
 import org.apache.kafka.raft.internals.LogOffset;
+import org.junit.jupiter.api.Test;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
@@ -48,7 +52,6 @@ import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
-import org.junit.jupiter.api.Test;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -63,6 +66,7 @@ public class RaftEventSimulationTest {
     private static final int RETRY_BACKOFF_MS = 50;
     private static final int REQUEST_TIMEOUT_MS = 500;
     private static final int FETCH_MAX_WAIT_MS = 100;
+    private static final int LINGER_MS = 0;
 
     @Test
     public void testInitialLeaderElectionQuorumSizeOne() {
@@ -758,9 +762,28 @@
 
             persistentState.log.reopen();
 
-            KafkaRaftClient client = new KafkaRaftClient(channel, persistentState.log, quorum, time, metrics,
-                fetchPurgatory, appendPurgatory, voterConnectionMap, ELECTION_JITTER_MS,
-                RETRY_BACKOFF_MS, REQUEST_TIMEOUT_MS, FETCH_MAX_WAIT_MS, logContext, random);
+            IntSerde serde = new IntSerde();
+            MemoryPool memoryPool = new BatchMemoryPool(2, KafkaRaftClient.MAX_BATCH_SIZE);
+
+            KafkaRaftClient<Integer> client = new KafkaRaftClient<>(
+                serde,
+                channel,
+                persistentState.log,
+                quorum,
+                memoryPool,
+                time,
+                metrics,
+                fetchPurgatory,
+                appendPurgatory,
+                voterConnectionMap,
+                ELECTION_JITTER_MS,
+                RETRY_BACKOFF_MS,
+                REQUEST_TIMEOUT_MS,
+                FETCH_MAX_WAIT_MS,
+                LINGER_MS,
+                logContext,
+                random
+            );
 
             RaftNode node = new RaftNode(nodeId, client, persistentState.log, channel,
                 persistentState.store, quorum, logContext);
             node.initialize();
@@ -770,7 +793,7 @@
 
     private static class RaftNode {
         final int nodeId;
-        final KafkaRaftClient client;
+        final KafkaRaftClient<Integer> client;
         final MockLog log;
         final MockNetworkChannel channel;
         final MockQuorumStateStore store;
@@ -779,7 +802,7 @@
         final ReplicatedCounter counter;
 
         private RaftNode(int nodeId,
-                         KafkaRaftClient client,
+                         KafkaRaftClient<Integer> client,
                          MockLog log,
                          MockNetworkChannel channel,
                          MockQuorumStateStore store,
@@ -797,7 +820,7 @@
         void initialize() {
             try {
-                client.initialize();
+                client.initialize(counter);
             } catch (IOException e) {
                 throw new RuntimeException(e);
             }
@@ -806,7 +829,7 @@
         void poll() {
             try {
                 client.poll();
-                counter.poll(0L);
+                counter.poll();
             } catch (IOException e) {
                 throw new RuntimeException(e);
             }
@@ -1108,4 +1131,17 @@
         }
     }
 
+    private static class IntSerde implements RecordSerde<Integer> {
+
+        @Override
+        public int recordSize(Integer data, Object context) {
+            return Type.INT32.sizeOf(data);
+        }
+
+        @Override
+        public void write(Integer data, Object context, Writable out) {
+            out.writeInt(data);
+        }
+    }
+
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BatchAccumulatorTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BatchAccumulatorTest.java
new file mode 100644
index 00000000000..c33c9e1a2de
--- /dev/null
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/BatchAccumulatorTest.java
@@ -0,0 +1,296 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.protocol.Writable; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BatchAccumulatorTest { + private final MemoryPool memoryPool = Mockito.mock(MemoryPool.class); + private final MockTime time = new MockTime(); + private final StringSerde serde = new StringSerde(); + + private BatchAccumulator buildAccumulator( + int leaderEpoch, + long baseOffset, + int lingerMs, + int maxBatchSize + ) { + return new BatchAccumulator<>( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize, + memoryPool, + time, + CompressionType.NONE, + serde + ); + } + + @Test + public void testLingerIgnoredIfAccumulatorEmpty() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + assertFalse(acc.needsDrain(time.milliseconds())); + assertEquals(Long.MAX_VALUE - time.milliseconds(), acc.timeUntilDrain(time.milliseconds())); + } + + @Test + public void testLingerBeginsOnFirstWrite() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(ByteBuffer.allocate(maxBatchSize)); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + time.sleep(15); + assertEquals(baseOffset, acc.append(leaderEpoch, singletonList("a"))); + assertEquals(lingerMs, acc.timeUntilDrain(time.milliseconds())); + + time.sleep(lingerMs / 2); + assertEquals(lingerMs / 2, acc.timeUntilDrain(time.milliseconds())); + + time.sleep(lingerMs / 2); + assertEquals(0, acc.timeUntilDrain(time.milliseconds())); + assertTrue(acc.needsDrain(time.milliseconds())); + } + + @Test + public void testCompletedBatchReleaseBuffer() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + ByteBuffer buffer = ByteBuffer.allocate(maxBatchSize); + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(buffer); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + assertEquals(baseOffset, acc.append(leaderEpoch, singletonList("a"))); + time.sleep(lingerMs); + + List> batches = acc.drain(); + assertEquals(1, batches.size()); + + BatchAccumulator.CompletedBatch batch = batches.get(0); + batch.release(); + Mockito.verify(memoryPool).release(buffer); + } + + @Test + public void testUnflushedBuffersReleasedByClose() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + ByteBuffer buffer = ByteBuffer.allocate(maxBatchSize); + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(buffer); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + 
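+        // The append below parks "a" in a batch that is never drained; close()
+        // must still return the pooled buffer, otherwise a leader that resigns
+        // with an unflushed batch would leak it from the MemoryPool.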
assertEquals(baseOffset, acc.append(leaderEpoch, singletonList("a"))); + acc.close(); + Mockito.verify(memoryPool).release(buffer); + } + + @Test + public void testSingleBatchAccumulation() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(ByteBuffer.allocate(maxBatchSize)); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + List records = asList("a", "b", "c", "d", "e", "f", "g", "h", "i"); + assertEquals(baseOffset, acc.append(leaderEpoch, records.subList(0, 1))); + assertEquals(baseOffset + 2, acc.append(leaderEpoch, records.subList(1, 3))); + assertEquals(baseOffset + 5, acc.append(leaderEpoch, records.subList(3, 6))); + assertEquals(baseOffset + 7, acc.append(leaderEpoch, records.subList(6, 8))); + assertEquals(baseOffset + 8, acc.append(leaderEpoch, records.subList(8, 9))); + + time.sleep(lingerMs); + assertTrue(acc.needsDrain(time.milliseconds())); + + List> batches = acc.drain(); + assertEquals(1, batches.size()); + assertFalse(acc.needsDrain(time.milliseconds())); + assertEquals(Long.MAX_VALUE - time.milliseconds(), acc.timeUntilDrain(time.milliseconds())); + + BatchAccumulator.CompletedBatch batch = batches.get(0); + assertEquals(records, batch.records); + assertEquals(baseOffset, batch.baseOffset); + } + + @Test + public void testMultipleBatchAccumulation() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 256; + + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(ByteBuffer.allocate(maxBatchSize)); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + while (acc.count() < 3) { + acc.append(leaderEpoch, singletonList("foo")); + } + + List> batches = acc.drain(); + assertEquals(3, batches.size()); + assertTrue(batches.stream().allMatch(batch -> batch.data.sizeInBytes() <= maxBatchSize)); + } + + @Test + public void testCloseWhenEmpty() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 256; + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + acc.close(); + Mockito.verifyNoInteractions(memoryPool); + } + + @Test + public void testDrainDoesNotBlockWithConcurrentAppend() throws Exception { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 256; + + StringSerde serde = Mockito.spy(new StringSerde()); + BatchAccumulator acc = new BatchAccumulator<>( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize, + memoryPool, + time, + CompressionType.NONE, + serde + ); + + CountDownLatch acquireLockLatch = new CountDownLatch(1); + CountDownLatch releaseLockLatch = new CountDownLatch(1); + + // Do the first append outside the thread to start the linger timer + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(ByteBuffer.allocate(maxBatchSize)); + acc.append(leaderEpoch, singletonList("a")); + + // Let the serde block to simulate a slow append + Mockito.doAnswer(invocation -> { + Writable writable = invocation.getArgument(2); + acquireLockLatch.countDown(); + releaseLockLatch.await(); + writable.writeByteArray(Utils.utf8("b")); + return null; + }).when(serde) + .write(Mockito.eq("b"), Mockito.eq(null), Mockito.any(Writable.class)); + + Thread appendThread = new Thread(() -> acc.append(leaderEpoch, singletonList("b"))); + appendThread.start(); + + // 
Attempt to drain while the append thread is holding the lock + acquireLockLatch.await(); + time.sleep(lingerMs); + assertTrue(acc.needsDrain(time.milliseconds())); + assertEquals(Collections.emptyList(), acc.drain()); + assertTrue(acc.needsDrain(time.milliseconds())); + + // Now let the append thread complete and verify that we can finish the drain + releaseLockLatch.countDown(); + appendThread.join(); + List> drained = acc.drain(); + assertEquals(1, drained.size()); + assertEquals(Long.MAX_VALUE - time.milliseconds(), acc.timeUntilDrain(time.milliseconds())); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BatchBuilderTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BatchBuilderTest.java new file mode 100644 index 00000000000..f860df7afd1 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/BatchBuilderTest.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MutableRecordBatch; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BatchBuilderTest { + private StringSerde serde = new StringSerde(); + private MockTime time = new MockTime(); + + @ParameterizedTest + @EnumSource(CompressionType.class) + void testBuildBatch(CompressionType compressionType) { + ByteBuffer buffer = ByteBuffer.allocate(1024); + long baseOffset = 57; + long logAppendTime = time.milliseconds(); + boolean isControlBatch = false; + int leaderEpoch = 15; + + BatchBuilder builder = new BatchBuilder<>( + buffer, + serde, + compressionType, + baseOffset, + logAppendTime, + isControlBatch, + leaderEpoch, + buffer.limit() + ); + + List records = Arrays.asList( + "a", + "ap", + "app", + "appl", + "apple" + ); + + records.forEach(record -> builder.appendRecord(record, null)); + MemoryRecords builtRecordSet = builder.build(); + assertFalse(builder.hasRoomFor(1)); + assertThrows(IllegalArgumentException.class, () -> builder.appendRecord("a", null)); + + List builtBatches = Utils.toList(builtRecordSet.batchIterator()); + assertEquals(1, 
builtBatches.size()); + assertEquals(records, builder.records()); + + MutableRecordBatch batch = builtBatches.get(0); + assertEquals(5, batch.countOrNull()); + assertEquals(compressionType, batch.compressionType()); + assertEquals(baseOffset, batch.baseOffset()); + assertEquals(logAppendTime, batch.maxTimestamp()); + assertEquals(isControlBatch, batch.isControlBatch()); + assertEquals(leaderEpoch, batch.partitionLeaderEpoch()); + + List builtRecords = Utils.toList(batch).stream() + .map(record -> Utils.utf8(record.value())) + .collect(Collectors.toList()); + assertEquals(records, builtRecords); + } + + + @ParameterizedTest + @ValueSource(ints = {128, 157, 256, 433, 512, 777, 1024}) + public void testHasRoomForUncompressed(int batchSize) { + ByteBuffer buffer = ByteBuffer.allocate(batchSize); + long baseOffset = 57; + long logAppendTime = time.milliseconds(); + boolean isControlBatch = false; + int leaderEpoch = 15; + + BatchBuilder builder = new BatchBuilder<>( + buffer, + serde, + CompressionType.NONE, + baseOffset, + logAppendTime, + isControlBatch, + leaderEpoch, + buffer.limit() + ); + + String record = "i am a record"; + int recordSize = serde.recordSize(record); + + while (builder.hasRoomFor(recordSize)) { + builder.appendRecord(record, null); + } + + // Approximate size should be exact when compression is not used + int sizeInBytes = builder.approximateSizeInBytes(); + MemoryRecords records = builder.build(); + assertEquals(sizeInBytes, records.sizeInBytes()); + assertTrue(sizeInBytes <= batchSize, "Built batch size " + + sizeInBytes + " is larger than max batch size " + batchSize); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BatchMemoryPoolTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BatchMemoryPoolTest.java new file mode 100644 index 00000000000..4177de145fe --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/BatchMemoryPoolTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft.internals; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BatchMemoryPoolTest { + + @Test + public void testAllocateAndRelease() { + int batchSize = 1024; + int maxBatches = 1; + + BatchMemoryPool pool = new BatchMemoryPool(maxBatches, batchSize); + assertEquals(batchSize, pool.availableMemory()); + assertFalse(pool.isOutOfMemory()); + + ByteBuffer allocated = pool.tryAllocate(batchSize); + assertNotNull(allocated); + assertEquals(0, allocated.position()); + assertEquals(batchSize, allocated.limit()); + assertEquals(0, pool.availableMemory()); + assertTrue(pool.isOutOfMemory()); + assertNull(pool.tryAllocate(batchSize)); + + allocated.position(512); + allocated.limit(724); + + pool.release(allocated); + ByteBuffer reallocated = pool.tryAllocate(batchSize); + assertSame(allocated, reallocated); + assertEquals(0, allocated.position()); + assertEquals(batchSize, allocated.limit()); + } + + @Test + public void testMultipleAllocations() { + int batchSize = 1024; + int maxBatches = 3; + + BatchMemoryPool pool = new BatchMemoryPool(maxBatches, batchSize); + assertEquals(batchSize * maxBatches, pool.availableMemory()); + + ByteBuffer batch1 = pool.tryAllocate(batchSize); + assertNotNull(batch1); + + ByteBuffer batch2 = pool.tryAllocate(batchSize); + assertNotNull(batch2); + + ByteBuffer batch3 = pool.tryAllocate(batchSize); + assertNotNull(batch3); + + assertNull(pool.tryAllocate(batchSize)); + + pool.release(batch2); + assertSame(batch2, pool.tryAllocate(batchSize)); + + pool.release(batch1); + pool.release(batch3); + ByteBuffer buffer = pool.tryAllocate(batchSize); + assertTrue(buffer == batch1 || buffer == batch3); + } + + @Test + public void testOversizeAllocation() { + int batchSize = 1024; + int maxBatches = 3; + + BatchMemoryPool pool = new BatchMemoryPool(maxBatches, batchSize); + assertThrows(IllegalArgumentException.class, () -> pool.tryAllocate(batchSize + 1)); + } + + @Test + public void testReleaseBufferNotMatchingBatchSize() { + int batchSize = 1024; + int maxBatches = 3; + + BatchMemoryPool pool = new BatchMemoryPool(maxBatches, batchSize); + ByteBuffer buffer = ByteBuffer.allocate(1023); + assertThrows(IllegalArgumentException.class, () -> pool.release(buffer)); + } + +} diff --git a/tools/src/main/java/org/apache/kafka/tools/ProducerPerformance.java b/tools/src/main/java/org/apache/kafka/tools/ProducerPerformance.java index b76ce94ddb3..12a0fccea8e 100644 --- a/tools/src/main/java/org/apache/kafka/tools/ProducerPerformance.java +++ b/tools/src/main/java/org/apache/kafka/tools/ProducerPerformance.java @@ -356,9 +356,9 @@ public class ProducerPerformance { } public void printWindow() { - long ellapsed = System.currentTimeMillis() - windowStart; - double recsPerSec = 1000.0 * windowCount / (double) ellapsed; - double mbPerSec = 1000.0 * this.windowBytes / (double) ellapsed / (1024.0 * 1024.0); + long elapsed = System.currentTimeMillis() - windowStart; + double recsPerSec = 1000.0 * windowCount / (double) elapsed; + double mbPerSec = 1000.0 * this.windowBytes / (double) elapsed / (1024.0 * 
1024.0);
             System.out.printf("%d records sent, %.1f records/sec (%.2f MB/sec), %.1f ms avg latency, %.1f ms max latency.%n",
                               windowCount,
                               recsPerSec,
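Taken together, the new raft.internals pieces compose like this. The walkthrough below is a sketch assembled purely from the constructor and method shapes exercised by the tests above (it mirrors BatchAccumulatorTest rather than documenting the production classes, and the offset comments restate what testSingleBatchAccumulation asserts):

    import java.util.Arrays;
    import org.apache.kafka.common.memory.MemoryPool;
    import org.apache.kafka.common.record.CompressionType;
    import org.apache.kafka.common.utils.MockTime;
    import org.apache.kafka.raft.internals.BatchAccumulator;
    import org.apache.kafka.raft.internals.BatchMemoryPool;
    import org.apache.kafka.raft.internals.StringSerde;

    public class AccumulatorWalkthrough {
        public static void main(String[] args) {
            MockTime time = new MockTime();
            // Two pooled 512-byte buffers bound how much appended data may sit undrained.
            MemoryPool pool = new BatchMemoryPool(2, 512);

            BatchAccumulator<String> acc = new BatchAccumulator<>(
                17,     // leaderEpoch
                157L,   // baseOffset assigned to the first appended record
                50,     // lingerMs: how long a non-empty accumulator waits for more data
                512,    // maxBatchSize in bytes
                pool,
                time,
                CompressionType.NONE,
                new StringSerde()
            );

            // append() returns the offset of the last record it assigned:
            // "a" -> 157, "b" -> 158, "c" -> 159, so this call yields
            // baseOffset + recordCount - 1 = 159.
            long lastOffset = acc.append(17, Arrays.asList("a", "b", "c"));
            System.out.println("last assigned offset: " + lastOffset);

            // Nothing is drainable until lingerMs elapses or a batch fills up.
            time.sleep(50);
            if (acc.needsDrain(time.milliseconds())) {
                for (BatchAccumulator.CompletedBatch<String> batch : acc.drain()) {
                    // batch.data holds the serialized MemoryRecords destined for the
                    // log; release() hands the buffer back to the pool, which is
                    // exactly what testCompletedBatchReleaseBuffer verifies.
                    batch.release();
                }
            }
            acc.close(); // releases any batches that were never drained
        }
    }

Note that only the sizing and write halves of RecordSerde appear in this diff (recordSize and write); the read side that produces the deserialized values handed to handleCommit lies outside this excerpt.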