KAFKA-16525; Dynamic KRaft network manager and channel (#15986)

Allow KRaft replicas to send requests to any node (Node) not just the nodes configured in the
controller.quorum.voters property. This flexibility is needed so KRaft can implement the
controller.quorum.voters configuration, send request to the dynamically changing set of voters and
send request to the leader endpoint (Node) discovered through the KRaft RPCs (specially
BeginQuorumEpoch request and Fetch response).

This was achieved by changing the RequestManager API to accept Node instead of just the replica ID.
Internally, the request manager tracks connection state using the Node.idString method to match the
connection management used by NetworkClient.

The API for RequestManager is also changed so that the ConnectState class is not exposed in the
API. This allows the request manager to reclaim heap memory for any connection that is ready.

The NetworkChannel was updated to receive the endpoint information (Node) through the outbound raft
request (RaftRequent.Outbound). This makes the network channel more flexible as it doesn't need to
be configured with the list of all possible endpoints. RaftRequest.Outbound and
RaftResponse.Inbound were updated to include the remote node instead of just the remote id.

The follower state tracked by KRaft replicas was updated to include both the leader id and the
leader's endpoint (Node). In this comment the node value is computed from the set of voters. In
future commit this will be updated so that it is sent through KRaft RPCs. For example
BeginQuorumEpoch request and Fetch response.

Support for configuring controller.quorum.bootstrap.servers was added. This includes changes to
KafkaConfig, QuorumConfig, etc. All of the tests using QuorumTestHarness were changed to use the
controller.quorum.bootstrap.servers instead of the controller.quorum.voters for the broker
configuration. Finally, the node id for the bootstrap server will be decreasing negative numbers
starting with -2.

Reviewers: Jason Gustafson <jason@confluent.io>, Luke Chen <showuon@gmail.com>, Colin P. McCabe <cmccabe@apache.org>
This commit is contained in:
José Armando García Sancio 2024-06-03 17:24:48 -04:00 committed by GitHub
parent 8a882a77a4
commit 459da4795a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
45 changed files with 2163 additions and 990 deletions

View File

@ -441,6 +441,7 @@
<allow pkg="org.apache.kafka.common.message" /> <allow pkg="org.apache.kafka.common.message" />
<allow pkg="org.apache.kafka.common.metadata" /> <allow pkg="org.apache.kafka.common.metadata" />
<allow pkg="org.apache.kafka.common.metrics" /> <allow pkg="org.apache.kafka.common.metrics" />
<allow pkg="org.apache.kafka.common.network" />
<allow pkg="org.apache.kafka.common.protocol" /> <allow pkg="org.apache.kafka.common.protocol" />
<allow pkg="org.apache.kafka.common.record" /> <allow pkg="org.apache.kafka.common.record" />
<allow pkg="org.apache.kafka.common.requests" /> <allow pkg="org.apache.kafka.common.requests" />

View File

@ -23,6 +23,7 @@ import java.nio.file.Paths
import java.util.OptionalInt import java.util.OptionalInt
import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletableFuture
import java.util.{Map => JMap} import java.util.{Map => JMap}
import java.util.{Collection => JCollection}
import kafka.log.LogManager import kafka.log.LogManager
import kafka.log.UnifiedLog import kafka.log.UnifiedLog
import kafka.server.KafkaConfig import kafka.server.KafkaConfig
@ -133,7 +134,7 @@ trait RaftManager[T] {
def replicatedLog: ReplicatedLog def replicatedLog: ReplicatedLog
def voterNode(id: Int, listener: String): Option[Node] def voterNode(id: Int, listener: ListenerName): Option[Node]
} }
class KafkaRaftManager[T]( class KafkaRaftManager[T](
@ -147,6 +148,7 @@ class KafkaRaftManager[T](
metrics: Metrics, metrics: Metrics,
threadNamePrefixOpt: Option[String], threadNamePrefixOpt: Option[String],
val controllerQuorumVotersFuture: CompletableFuture[JMap[Integer, InetSocketAddress]], val controllerQuorumVotersFuture: CompletableFuture[JMap[Integer, InetSocketAddress]],
bootstrapServers: JCollection[InetSocketAddress],
fatalFaultHandler: FaultHandler fatalFaultHandler: FaultHandler
) extends RaftManager[T] with Logging { ) extends RaftManager[T] with Logging {
@ -185,7 +187,6 @@ class KafkaRaftManager[T](
def startup(): Unit = { def startup(): Unit = {
client.initialize( client.initialize(
controllerQuorumVotersFuture.get(), controllerQuorumVotersFuture.get(),
config.controllerListenerNames.head,
new FileQuorumStateStore(new File(dataDir, FileQuorumStateStore.DEFAULT_FILE_NAME)), new FileQuorumStateStore(new File(dataDir, FileQuorumStateStore.DEFAULT_FILE_NAME)),
metrics metrics
) )
@ -228,14 +229,15 @@ class KafkaRaftManager[T](
expirationService, expirationService,
logContext, logContext,
clusterId, clusterId,
bootstrapServers,
raftConfig raftConfig
) )
client client
} }
private def buildNetworkChannel(): KafkaNetworkChannel = { private def buildNetworkChannel(): KafkaNetworkChannel = {
val netClient = buildNetworkClient() val (listenerName, netClient) = buildNetworkClient()
new KafkaNetworkChannel(time, netClient, config.quorumRequestTimeoutMs, threadNamePrefix) new KafkaNetworkChannel(time, listenerName, netClient, config.quorumRequestTimeoutMs, threadNamePrefix)
} }
private def createDataDir(): File = { private def createDataDir(): File = {
@ -254,7 +256,7 @@ class KafkaRaftManager[T](
) )
} }
private def buildNetworkClient(): NetworkClient = { private def buildNetworkClient(): (ListenerName, NetworkClient) = {
val controllerListenerName = new ListenerName(config.controllerListenerNames.head) val controllerListenerName = new ListenerName(config.controllerListenerNames.head)
val controllerSecurityProtocol = config.effectiveListenerSecurityProtocolMap.getOrElse( val controllerSecurityProtocol = config.effectiveListenerSecurityProtocolMap.getOrElse(
controllerListenerName, controllerListenerName,
@ -292,7 +294,7 @@ class KafkaRaftManager[T](
val reconnectBackoffMsMs = 500 val reconnectBackoffMsMs = 500
val discoverBrokerVersions = true val discoverBrokerVersions = true
new NetworkClient( val networkClient = new NetworkClient(
selector, selector,
new ManualMetadataUpdater(), new ManualMetadataUpdater(),
clientId, clientId,
@ -309,13 +311,15 @@ class KafkaRaftManager[T](
apiVersions, apiVersions,
logContext logContext
) )
(controllerListenerName, networkClient)
} }
override def leaderAndEpoch: LeaderAndEpoch = { override def leaderAndEpoch: LeaderAndEpoch = {
client.leaderAndEpoch client.leaderAndEpoch
} }
override def voterNode(id: Int, listener: String): Option[Node] = { override def voterNode(id: Int, listener: ListenerName): Option[Node] = {
client.voterNode(id, listener).toScala client.voterNode(id, listener).toScala
} }
} }

View File

@ -439,6 +439,7 @@ object KafkaConfig {
/** ********* Raft Quorum Configuration *********/ /** ********* Raft Quorum Configuration *********/
.define(QuorumConfig.QUORUM_VOTERS_CONFIG, LIST, QuorumConfig.DEFAULT_QUORUM_VOTERS, new QuorumConfig.ControllerQuorumVotersValidator(), HIGH, QuorumConfig.QUORUM_VOTERS_DOC) .define(QuorumConfig.QUORUM_VOTERS_CONFIG, LIST, QuorumConfig.DEFAULT_QUORUM_VOTERS, new QuorumConfig.ControllerQuorumVotersValidator(), HIGH, QuorumConfig.QUORUM_VOTERS_DOC)
.define(QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG, LIST, QuorumConfig.DEFAULT_QUORUM_BOOTSTRAP_SERVERS, new QuorumConfig.ControllerQuorumBootstrapServersValidator(), HIGH, QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_DOC)
.define(QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_ELECTION_TIMEOUT_MS, null, HIGH, QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_DOC) .define(QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_ELECTION_TIMEOUT_MS, null, HIGH, QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_DOC)
.define(QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_FETCH_TIMEOUT_MS, null, HIGH, QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_DOC) .define(QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_FETCH_TIMEOUT_MS, null, HIGH, QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_DOC)
.define(QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_ELECTION_BACKOFF_MAX_MS, null, HIGH, QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_DOC) .define(QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_ELECTION_BACKOFF_MAX_MS, null, HIGH, QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_DOC)
@ -1055,6 +1056,7 @@ class KafkaConfig private(doLog: Boolean, val props: java.util.Map[_, _], dynami
/** ********* Raft Quorum Configuration *********/ /** ********* Raft Quorum Configuration *********/
val quorumVoters = getList(QuorumConfig.QUORUM_VOTERS_CONFIG) val quorumVoters = getList(QuorumConfig.QUORUM_VOTERS_CONFIG)
val quorumBootstrapServers = getList(QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG)
val quorumElectionTimeoutMs = getInt(QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG) val quorumElectionTimeoutMs = getInt(QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG)
val quorumFetchTimeoutMs = getInt(QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG) val quorumFetchTimeoutMs = getInt(QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG)
val quorumElectionBackoffMs = getInt(QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG) val quorumElectionBackoffMs = getInt(QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG)

View File

@ -71,6 +71,7 @@ class KafkaRaftServer(
time, time,
metrics, metrics,
CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)), CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)),
QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers),
new StandardFaultHandlerFactory(), new StandardFaultHandlerFactory(),
) )

View File

@ -70,9 +70,9 @@ import java.net.{InetAddress, SocketTimeoutException}
import java.nio.file.{Files, Paths} import java.nio.file.{Files, Paths}
import java.time.Duration import java.time.Duration
import java.util import java.util
import java.util.{Optional, OptionalInt, OptionalLong}
import java.util.concurrent._ import java.util.concurrent._
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
import java.util.{Optional, OptionalInt, OptionalLong}
import scala.collection.{Map, Seq} import scala.collection.{Map, Seq}
import scala.compat.java8.OptionConverters.RichOptionForJava8 import scala.compat.java8.OptionConverters.RichOptionForJava8
import scala.jdk.CollectionConverters._ import scala.jdk.CollectionConverters._
@ -439,6 +439,7 @@ class KafkaServer(
metrics, metrics,
threadNamePrefix, threadNamePrefix,
CompletableFuture.completedFuture(quorumVoters), CompletableFuture.completedFuture(quorumVoters),
QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers),
fatalFaultHandler = new LoggingFaultHandler("raftManager", () => shutdown()) fatalFaultHandler = new LoggingFaultHandler("raftManager", () => shutdown())
) )
quorumControllerNodeProvider = RaftControllerNodeProvider(raftManager, config) quorumControllerNodeProvider = RaftControllerNodeProvider(raftManager, config)

View File

@ -112,7 +112,7 @@ class RaftControllerNodeProvider(
val saslMechanism: String val saslMechanism: String
) extends ControllerNodeProvider with Logging { ) extends ControllerNodeProvider with Logging {
private def idToNode(id: Int): Option[Node] = raftManager.voterNode(id, listenerName.value()) private def idToNode(id: Int): Option[Node] = raftManager.voterNode(id, listenerName)
override def getControllerInfo(): ControllerInformation = override def getControllerInfo(): ControllerInformation =
ControllerInformation(raftManager.leaderAndEpoch.leaderId.asScala.flatMap(idToNode), ControllerInformation(raftManager.leaderAndEpoch.leaderId.asScala.flatMap(idToNode),

View File

@ -41,6 +41,7 @@ import java.util.Arrays
import java.util.Optional import java.util.Optional
import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.{CompletableFuture, TimeUnit} import java.util.concurrent.{CompletableFuture, TimeUnit}
import java.util.{Collection => JCollection}
import java.util.{Map => JMap} import java.util.{Map => JMap}
@ -94,6 +95,7 @@ class SharedServer(
val time: Time, val time: Time,
private val _metrics: Metrics, private val _metrics: Metrics,
val controllerQuorumVotersFuture: CompletableFuture[JMap[Integer, InetSocketAddress]], val controllerQuorumVotersFuture: CompletableFuture[JMap[Integer, InetSocketAddress]],
val bootstrapServers: JCollection[InetSocketAddress],
val faultHandlerFactory: FaultHandlerFactory val faultHandlerFactory: FaultHandlerFactory
) extends Logging { ) extends Logging {
private val logContext: LogContext = new LogContext(s"[SharedServer id=${sharedServerConfig.nodeId}] ") private val logContext: LogContext = new LogContext(s"[SharedServer id=${sharedServerConfig.nodeId}] ")
@ -265,6 +267,7 @@ class SharedServer(
metrics, metrics,
Some(s"kafka-${sharedServerConfig.nodeId}-raft"), // No dash expected at the end Some(s"kafka-${sharedServerConfig.nodeId}-raft"), // No dash expected at the end
controllerQuorumVotersFuture, controllerQuorumVotersFuture,
bootstrapServers,
raftManagerFaultHandler raftManagerFaultHandler
) )
raftManager = _raftManager raftManager = _raftManager

View File

@ -502,7 +502,7 @@ object StorageTool extends Logging {
metaPropertiesEnsemble.verify(metaProperties.clusterId(), metaProperties.nodeId(), metaPropertiesEnsemble.verify(metaProperties.clusterId(), metaProperties.nodeId(),
util.EnumSet.noneOf(classOf[VerificationFlag])) util.EnumSet.noneOf(classOf[VerificationFlag]))
System.out.println(s"metaPropertiesEnsemble=$metaPropertiesEnsemble") stream.println(s"metaPropertiesEnsemble=$metaPropertiesEnsemble")
val copier = new MetaPropertiesEnsemble.Copier(metaPropertiesEnsemble) val copier = new MetaPropertiesEnsemble.Copier(metaPropertiesEnsemble)
if (!(ignoreFormatted || copier.logDirProps().isEmpty)) { if (!(ignoreFormatted || copier.logDirProps().isEmpty)) {
val firstLogDir = copier.logDirProps().keySet().iterator().next() val firstLogDir = copier.logDirProps().keySet().iterator().next()

View File

@ -95,6 +95,7 @@ class TestRaftServer(
metrics, metrics,
Some(threadNamePrefix), Some(threadNamePrefix),
CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)), CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)),
QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers),
new ProcessTerminatingFaultHandler.Builder().build() new ProcessTerminatingFaultHandler.Builder().build()
) )

View File

@ -239,12 +239,15 @@ public class KafkaClusterTestKit implements AutoCloseable {
ThreadUtils.createThreadFactory("kafka-cluster-test-kit-executor-%d", false)); ThreadUtils.createThreadFactory("kafka-cluster-test-kit-executor-%d", false));
for (ControllerNode node : nodes.controllerNodes().values()) { for (ControllerNode node : nodes.controllerNodes().values()) {
setupNodeDirectories(baseDirectory, node.metadataDirectory(), Collections.emptyList()); setupNodeDirectories(baseDirectory, node.metadataDirectory(), Collections.emptyList());
SharedServer sharedServer = new SharedServer(createNodeConfig(node), SharedServer sharedServer = new SharedServer(
node.initialMetaPropertiesEnsemble(), createNodeConfig(node),
Time.SYSTEM, node.initialMetaPropertiesEnsemble(),
new Metrics(), Time.SYSTEM,
connectFutureManager.future, new Metrics(),
faultHandlerFactory); connectFutureManager.future,
Collections.emptyList(),
faultHandlerFactory
);
ControllerServer controller = null; ControllerServer controller = null;
try { try {
controller = new ControllerServer( controller = new ControllerServer(
@ -267,13 +270,18 @@ public class KafkaClusterTestKit implements AutoCloseable {
jointServers.put(node.id(), sharedServer); jointServers.put(node.id(), sharedServer);
} }
for (BrokerNode node : nodes.brokerNodes().values()) { for (BrokerNode node : nodes.brokerNodes().values()) {
SharedServer sharedServer = jointServers.computeIfAbsent(node.id(), SharedServer sharedServer = jointServers.computeIfAbsent(
id -> new SharedServer(createNodeConfig(node), node.id(),
id -> new SharedServer(
createNodeConfig(node),
node.initialMetaPropertiesEnsemble(), node.initialMetaPropertiesEnsemble(),
Time.SYSTEM, Time.SYSTEM,
new Metrics(), new Metrics(),
connectFutureManager.future, connectFutureManager.future,
faultHandlerFactory)); Collections.emptyList(),
faultHandlerFactory
)
);
BrokerServer broker = null; BrokerServer broker = null;
try { try {
broker = new BrokerServer(sharedServer); broker = new BrokerServer(sharedServer);

View File

@ -21,6 +21,5 @@ log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n
log4j.logger.kafka=WARN log4j.logger.kafka=WARN
log4j.logger.org.apache.kafka=WARN log4j.logger.org.apache.kafka=WARN
# zkclient can be verbose, during debugging it is common to adjust it separately # zkclient can be verbose, during debugging it is common to adjust it separately
log4j.logger.org.apache.zookeeper=WARN log4j.logger.org.apache.zookeeper=WARN

View File

@ -21,12 +21,12 @@ import kafka.utils.{TestInfoUtils, TestUtils}
import org.apache.kafka.clients.admin.{NewPartitions, NewTopic} import org.apache.kafka.clients.admin.{NewPartitions, NewTopic}
import org.apache.kafka.clients.consumer._ import org.apache.kafka.clients.consumer._
import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord} import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord}
import org.apache.kafka.common.{KafkaException, MetricName, TopicPartition}
import org.apache.kafka.common.config.TopicConfig import org.apache.kafka.common.config.TopicConfig
import org.apache.kafka.common.errors.{InvalidGroupIdException, InvalidTopicException, TimeoutException, WakeupException} import org.apache.kafka.common.errors.{InvalidGroupIdException, InvalidTopicException, TimeoutException, WakeupException}
import org.apache.kafka.common.header.Headers import org.apache.kafka.common.header.Headers
import org.apache.kafka.common.record.{CompressionType, TimestampType} import org.apache.kafka.common.record.{CompressionType, TimestampType}
import org.apache.kafka.common.serialization._ import org.apache.kafka.common.serialization._
import org.apache.kafka.common.{KafkaException, MetricName, TopicPartition}
import org.apache.kafka.test.{MockConsumerInterceptor, MockProducerInterceptor} import org.apache.kafka.test.{MockConsumerInterceptor, MockProducerInterceptor}
import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Timeout import org.junit.jupiter.api.Timeout

View File

@ -124,12 +124,15 @@ class KRaftQuorumImplementation(
metaPropertiesEnsemble.verify(Optional.of(clusterId), metaPropertiesEnsemble.verify(Optional.of(clusterId),
OptionalInt.of(config.nodeId), OptionalInt.of(config.nodeId),
util.EnumSet.of(REQUIRE_AT_LEAST_ONE_VALID, REQUIRE_METADATA_LOG_DIR)) util.EnumSet.of(REQUIRE_AT_LEAST_ONE_VALID, REQUIRE_METADATA_LOG_DIR))
val sharedServer = new SharedServer(config, val sharedServer = new SharedServer(
config,
metaPropertiesEnsemble, metaPropertiesEnsemble,
time, time,
new Metrics(), new Metrics(),
controllerQuorumVotersFuture, controllerQuorumVotersFuture,
faultHandlerFactory) controllerQuorumVotersFuture.get().values(),
faultHandlerFactory
)
var broker: BrokerServer = null var broker: BrokerServer = null
try { try {
broker = new BrokerServer(sharedServer) broker = new BrokerServer(sharedServer)
@ -371,12 +374,15 @@ abstract class QuorumTestHarness extends Logging {
metaPropertiesEnsemble.verify(Optional.of(metaProperties.clusterId().get()), metaPropertiesEnsemble.verify(Optional.of(metaProperties.clusterId().get()),
OptionalInt.of(nodeId), OptionalInt.of(nodeId),
util.EnumSet.of(REQUIRE_AT_LEAST_ONE_VALID, REQUIRE_METADATA_LOG_DIR)) util.EnumSet.of(REQUIRE_AT_LEAST_ONE_VALID, REQUIRE_METADATA_LOG_DIR))
val sharedServer = new SharedServer(config, val sharedServer = new SharedServer(
config,
metaPropertiesEnsemble, metaPropertiesEnsemble,
Time.SYSTEM, Time.SYSTEM,
new Metrics(), new Metrics(),
controllerQuorumVotersFuture, controllerQuorumVotersFuture,
faultHandlerFactory) Collections.emptyList(),
faultHandlerFactory
)
var controllerServer: ControllerServer = null var controllerServer: ControllerServer = null
try { try {
controllerServer = new ControllerServer( controllerServer = new ControllerServer(

View File

@ -86,7 +86,7 @@ class KafkaConfigTest {
@Test @Test
def testBrokerRoleNodeIdValidation(): Unit = { def testBrokerRoleNodeIdValidation(): Unit = {
// Ensure that validation is happening at startup to check that brokers do not use their node.id as a voter in controller.quorum.voters // Ensure that validation is happening at startup to check that brokers do not use their node.id as a voter in controller.quorum.voters
val propertiesFile = new Properties val propertiesFile = new Properties
propertiesFile.setProperty(KRaftConfigs.PROCESS_ROLES_CONFIG, "broker") propertiesFile.setProperty(KRaftConfigs.PROCESS_ROLES_CONFIG, "broker")
propertiesFile.setProperty(KRaftConfigs.NODE_ID_CONFIG, "1") propertiesFile.setProperty(KRaftConfigs.NODE_ID_CONFIG, "1")
@ -102,7 +102,7 @@ class KafkaConfigTest {
@Test @Test
def testControllerRoleNodeIdValidation(): Unit = { def testControllerRoleNodeIdValidation(): Unit = {
// Ensure that validation is happening at startup to check that controllers use their node.id as a voter in controller.quorum.voters // Ensure that validation is happening at startup to check that controllers use their node.id as a voter in controller.quorum.voters
val propertiesFile = new Properties val propertiesFile = new Properties
propertiesFile.setProperty(KRaftConfigs.PROCESS_ROLES_CONFIG, "controller") propertiesFile.setProperty(KRaftConfigs.PROCESS_ROLES_CONFIG, "controller")
propertiesFile.setProperty(KRaftConfigs.NODE_ID_CONFIG, "1") propertiesFile.setProperty(KRaftConfigs.NODE_ID_CONFIG, "1")

View File

@ -118,6 +118,7 @@ class RaftManagerTest {
new Metrics(Time.SYSTEM), new Metrics(Time.SYSTEM),
Option.empty, Option.empty,
CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)), CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)),
QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers),
mock(classOf[FaultHandler]) mock(classOf[FaultHandler])
) )
} }

View File

@ -19,7 +19,7 @@ package kafka.server
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.util import java.util
import java.util.{Collections, Properties} import java.util.{Arrays, Collections, Properties}
import kafka.cluster.EndPoint import kafka.cluster.EndPoint
import kafka.security.authorizer.AclAuthorizer import kafka.security.authorizer.AclAuthorizer
import kafka.utils.TestUtils.assertBadConfigContainingMessage import kafka.utils.TestUtils.assertBadConfigContainingMessage
@ -1032,6 +1032,7 @@ class KafkaConfigTest {
// Raft Quorum Configs // Raft Quorum Configs
case QuorumConfig.QUORUM_VOTERS_CONFIG => // ignore string case QuorumConfig.QUORUM_VOTERS_CONFIG => // ignore string
case QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG => // ignore string
case QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") case QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number")
case QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") case QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number")
case QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") case QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number")
@ -1402,6 +1403,23 @@ class KafkaConfigTest {
assertEquals(expectedVoters, addresses) assertEquals(expectedVoters, addresses)
} }
@Test
def testParseQuorumBootstrapServers(): Unit = {
val expected = Arrays.asList(
InetSocketAddress.createUnresolved("kafka1", 9092),
InetSocketAddress.createUnresolved("kafka2", 9092)
)
val props = TestUtils.createBrokerConfig(0, null)
props.setProperty(QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG, "kafka1:9092,kafka2:9092")
val addresses = QuorumConfig.parseBootstrapServers(
KafkaConfig.fromProps(props).quorumBootstrapServers
)
assertEquals(expected, addresses)
}
@Test @Test
def testAcceptsLargeNodeIdForRaftBasedCase(): Unit = { def testAcceptsLargeNodeIdForRaftBasedCase(): Unit = {
// Generation of Broker IDs is not supported when using Raft-based controller quorums, // Generation of Broker IDs is not supported when using Raft-based controller quorums,

View File

@ -22,8 +22,8 @@ import java.nio.ByteBuffer
import java.util import java.util
import java.util.Collections import java.util.Collections
import java.util.Optional import java.util.Optional
import java.util.Arrays
import java.util.Properties import java.util.Properties
import java.util.stream.IntStream
import kafka.log.{LogTestUtils, UnifiedLog} import kafka.log.{LogTestUtils, UnifiedLog}
import kafka.raft.{KafkaMetadataLog, MetadataLogConfig} import kafka.raft.{KafkaMetadataLog, MetadataLogConfig}
import kafka.server.{BrokerTopicStats, KafkaRaftServer} import kafka.server.{BrokerTopicStats, KafkaRaftServer}
@ -338,7 +338,7 @@ class DumpLogSegmentsTest {
.setLastContainedLogTimestamp(lastContainedLogTimestamp) .setLastContainedLogTimestamp(lastContainedLogTimestamp)
.setRawSnapshotWriter(metadataLog.createNewSnapshot(new OffsetAndEpoch(0, 0)).get) .setRawSnapshotWriter(metadataLog.createNewSnapshot(new OffsetAndEpoch(0, 0)).get)
.setKraftVersion(1) .setKraftVersion(1)
.setVoterSet(Optional.of(VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)))) .setVoterSet(Optional.of(VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true))))
.build(MetadataRecordSerde.INSTANCE) .build(MetadataRecordSerde.INSTANCE)
) { snapshotWriter => ) { snapshotWriter =>
snapshotWriter.append(metadataRecords.asJava) snapshotWriter.append(metadataRecords.asJava)

View File

@ -30,9 +30,9 @@ import org.apache.kafka.raft.internals.ReplicaKey;
* Encapsulate election state stored on disk after every state change. * Encapsulate election state stored on disk after every state change.
*/ */
final public class ElectionState { final public class ElectionState {
private static int unknownLeaderId = -1; private static final int UNKNOWN_LEADER_ID = -1;
private static int notVoted = -1; private static final int NOT_VOTED = -1;
private static Uuid noVotedDirectoryId = Uuid.ZERO_UUID; private static final Uuid NO_VOTED_DIRECTORY_ID = Uuid.ZERO_UUID;
private final int epoch; private final int epoch;
private final OptionalInt leaderId; private final OptionalInt leaderId;
@ -95,7 +95,7 @@ final public class ElectionState {
} }
public int leaderIdOrSentinel() { public int leaderIdOrSentinel() {
return leaderId.orElse(unknownLeaderId); return leaderId.orElse(UNKNOWN_LEADER_ID);
} }
public OptionalInt optionalLeaderId() { public OptionalInt optionalLeaderId() {
@ -126,7 +126,7 @@ final public class ElectionState {
QuorumStateData data = new QuorumStateData() QuorumStateData data = new QuorumStateData()
.setLeaderEpoch(epoch) .setLeaderEpoch(epoch)
.setLeaderId(leaderIdOrSentinel()) .setLeaderId(leaderIdOrSentinel())
.setVotedId(votedKey.map(ReplicaKey::id).orElse(notVoted)); .setVotedId(votedKey.map(ReplicaKey::id).orElse(NOT_VOTED));
if (version == 0) { if (version == 0) {
List<QuorumStateData.Voter> dataVoters = voters List<QuorumStateData.Voter> dataVoters = voters
@ -135,7 +135,7 @@ final public class ElectionState {
.collect(Collectors.toList()); .collect(Collectors.toList());
data.setCurrentVoters(dataVoters); data.setCurrentVoters(dataVoters);
} else if (version == 1) { } else if (version == 1) {
data.setVotedDirectoryId(votedKey.flatMap(ReplicaKey::directoryId).orElse(noVotedDirectoryId)); data.setVotedDirectoryId(votedKey.flatMap(ReplicaKey::directoryId).orElse(NO_VOTED_DIRECTORY_ID));
} else { } else {
throw new IllegalStateException( throw new IllegalStateException(
String.format( String.format(
@ -198,17 +198,17 @@ final public class ElectionState {
} }
public static ElectionState fromQuorumStateData(QuorumStateData data) { public static ElectionState fromQuorumStateData(QuorumStateData data) {
Optional<Uuid> votedDirectoryId = data.votedDirectoryId().equals(noVotedDirectoryId) ? Optional<Uuid> votedDirectoryId = data.votedDirectoryId().equals(NO_VOTED_DIRECTORY_ID) ?
Optional.empty() : Optional.empty() :
Optional.of(data.votedDirectoryId()); Optional.of(data.votedDirectoryId());
Optional<ReplicaKey> votedKey = data.votedId() == notVoted ? Optional<ReplicaKey> votedKey = data.votedId() == NOT_VOTED ?
Optional.empty() : Optional.empty() :
Optional.of(ReplicaKey.of(data.votedId(), votedDirectoryId)); Optional.of(ReplicaKey.of(data.votedId(), votedDirectoryId));
return new ElectionState( return new ElectionState(
data.leaderEpoch(), data.leaderEpoch(),
data.leaderId() == unknownLeaderId ? OptionalInt.empty() : OptionalInt.of(data.leaderId()), data.leaderId() == UNKNOWN_LEADER_ID ? OptionalInt.empty() : OptionalInt.of(data.leaderId()),
votedKey, votedKey,
data.currentVoters().stream().map(QuorumStateData.Voter::voterId).collect(Collectors.toSet()) data.currentVoters().stream().map(QuorumStateData.Voter::voterId).collect(Collectors.toSet())
); );

View File

@ -19,6 +19,7 @@ package org.apache.kafka.raft;
import java.util.Optional; import java.util.Optional;
import java.util.OptionalLong; import java.util.OptionalLong;
import java.util.Set; import java.util.Set;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Timer; import org.apache.kafka.common.utils.Timer;
@ -29,7 +30,7 @@ import org.slf4j.Logger;
public class FollowerState implements EpochState { public class FollowerState implements EpochState {
private final int fetchTimeoutMs; private final int fetchTimeoutMs;
private final int epoch; private final int epoch;
private final int leaderId; private final Node leader;
private final Set<Integer> voters; private final Set<Integer> voters;
// Used for tracking the expiration of both the Fetch and FetchSnapshot requests // Used for tracking the expiration of both the Fetch and FetchSnapshot requests
private final Timer fetchTimer; private final Timer fetchTimer;
@ -37,14 +38,14 @@ public class FollowerState implements EpochState {
/* Used to track the currently fetching snapshot. When fetching snapshot regular /* Used to track the currently fetching snapshot. When fetching snapshot regular
* Fetch request are paused * Fetch request are paused
*/ */
private Optional<RawSnapshotWriter> fetchingSnapshot; private Optional<RawSnapshotWriter> fetchingSnapshot = Optional.empty();
private final Logger log; private final Logger log;
public FollowerState( public FollowerState(
Time time, Time time,
int epoch, int epoch,
int leaderId, Node leader,
Set<Integer> voters, Set<Integer> voters,
Optional<LogOffsetMetadata> highWatermark, Optional<LogOffsetMetadata> highWatermark,
int fetchTimeoutMs, int fetchTimeoutMs,
@ -52,17 +53,16 @@ public class FollowerState implements EpochState {
) { ) {
this.fetchTimeoutMs = fetchTimeoutMs; this.fetchTimeoutMs = fetchTimeoutMs;
this.epoch = epoch; this.epoch = epoch;
this.leaderId = leaderId; this.leader = leader;
this.voters = voters; this.voters = voters;
this.fetchTimer = time.timer(fetchTimeoutMs); this.fetchTimer = time.timer(fetchTimeoutMs);
this.highWatermark = highWatermark; this.highWatermark = highWatermark;
this.fetchingSnapshot = Optional.empty();
this.log = logContext.logger(FollowerState.class); this.log = logContext.logger(FollowerState.class);
} }
@Override @Override
public ElectionState election() { public ElectionState election() {
return ElectionState.withElectedLeader(epoch, leaderId, voters); return ElectionState.withElectedLeader(epoch, leader.id(), voters);
} }
@Override @Override
@ -80,8 +80,8 @@ public class FollowerState implements EpochState {
return fetchTimer.remainingMs(); return fetchTimer.remainingMs();
} }
public int leaderId() { public Node leader() {
return leaderId; return leader;
} }
public boolean hasFetchTimeoutExpired(long currentTimeMs) { public boolean hasFetchTimeoutExpired(long currentTimeMs) {
@ -156,7 +156,7 @@ public class FollowerState implements EpochState {
log.debug( log.debug(
"Rejecting vote request from candidate ({}) since we already have a leader {} in epoch {}", "Rejecting vote request from candidate ({}) since we already have a leader {} in epoch {}",
candidateKey, candidateKey,
leaderId(), leader,
epoch epoch
); );
return false; return false;
@ -164,14 +164,16 @@ public class FollowerState implements EpochState {
@Override @Override
public String toString() { public String toString() {
return "FollowerState(" + return String.format(
"fetchTimeoutMs=" + fetchTimeoutMs + "FollowerState(fetchTimeoutMs=%d, epoch=%d, leader=%s voters=%s, highWatermark=%s, " +
", epoch=" + epoch + "fetchingSnapshot=%s)",
", leaderId=" + leaderId + fetchTimeoutMs,
", voters=" + voters + epoch,
", highWatermark=" + highWatermark + leader,
", fetchingSnapshot=" + fetchingSnapshot + voters,
')'; highWatermark,
fetchingSnapshot
);
} }
@Override @Override

View File

@ -24,6 +24,7 @@ import org.apache.kafka.common.message.EndQuorumEpochRequestData;
import org.apache.kafka.common.message.FetchRequestData; import org.apache.kafka.common.message.FetchRequestData;
import org.apache.kafka.common.message.FetchSnapshotRequestData; import org.apache.kafka.common.message.FetchSnapshotRequestData;
import org.apache.kafka.common.message.VoteRequestData; import org.apache.kafka.common.message.VoteRequestData;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.Errors;
@ -39,12 +40,9 @@ import org.apache.kafka.server.util.RequestAndCompletionHandler;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.net.InetSocketAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -83,9 +81,17 @@ public class KafkaNetworkChannel implements NetworkChannel {
private final SendThread requestThread; private final SendThread requestThread;
private final AtomicInteger correlationIdCounter = new AtomicInteger(0); private final AtomicInteger correlationIdCounter = new AtomicInteger(0);
private final Map<Integer, Node> endpoints = new HashMap<>();
public KafkaNetworkChannel(Time time, KafkaClient client, int requestTimeoutMs, String threadNamePrefix) { private final ListenerName listenerName;
public KafkaNetworkChannel(
Time time,
ListenerName listenerName,
KafkaClient client,
int requestTimeoutMs,
String threadNamePrefix
) {
this.listenerName = listenerName;
this.requestThread = new SendThread( this.requestThread = new SendThread(
threadNamePrefix + "-outbound-request-thread", threadNamePrefix + "-outbound-request-thread",
client, client,
@ -102,23 +108,23 @@ public class KafkaNetworkChannel implements NetworkChannel {
@Override @Override
public void send(RaftRequest.Outbound request) { public void send(RaftRequest.Outbound request) {
Node node = endpoints.get(request.destinationId()); Node node = request.destination();
if (node != null) { if (node != null) {
requestThread.sendRequest(new RequestAndCompletionHandler( requestThread.sendRequest(new RequestAndCompletionHandler(
request.createdTimeMs, request.createdTimeMs(),
node, node,
buildRequest(request.data), buildRequest(request.data()),
response -> sendOnComplete(request, response) response -> sendOnComplete(request, response)
)); ));
} else } else
sendCompleteFuture(request, errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)); sendCompleteFuture(request, errorResponse(request.data(), Errors.BROKER_NOT_AVAILABLE));
} }
private void sendCompleteFuture(RaftRequest.Outbound request, ApiMessage message) { private void sendCompleteFuture(RaftRequest.Outbound request, ApiMessage message) {
RaftResponse.Inbound response = new RaftResponse.Inbound( RaftResponse.Inbound response = new RaftResponse.Inbound(
request.correlationId, request.correlationId(),
message, message,
request.destinationId() request.destination()
); );
request.completion.complete(response); request.completion.complete(response);
} }
@ -127,16 +133,16 @@ public class KafkaNetworkChannel implements NetworkChannel {
ApiMessage response; ApiMessage response;
if (clientResponse.versionMismatch() != null) { if (clientResponse.versionMismatch() != null) {
log.error("Request {} failed due to unsupported version error", request, clientResponse.versionMismatch()); log.error("Request {} failed due to unsupported version error", request, clientResponse.versionMismatch());
response = errorResponse(request.data, Errors.UNSUPPORTED_VERSION); response = errorResponse(request.data(), Errors.UNSUPPORTED_VERSION);
} else if (clientResponse.authenticationException() != null) { } else if (clientResponse.authenticationException() != null) {
// For now we treat authentication errors as retriable. We use the // For now we treat authentication errors as retriable. We use the
// `NETWORK_EXCEPTION` error code for lack of a good alternative. // `NETWORK_EXCEPTION` error code for lack of a good alternative.
// Note that `NodeToControllerChannelManager` will still log the // Note that `NodeToControllerChannelManager` will still log the
// authentication errors so that users have a chance to fix the problem. // authentication errors so that users have a chance to fix the problem.
log.error("Request {} failed due to authentication error", request, clientResponse.authenticationException()); log.error("Request {} failed due to authentication error", request, clientResponse.authenticationException());
response = errorResponse(request.data, Errors.NETWORK_EXCEPTION); response = errorResponse(request.data(), Errors.NETWORK_EXCEPTION);
} else if (clientResponse.wasDisconnected()) { } else if (clientResponse.wasDisconnected()) {
response = errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE); response = errorResponse(request.data(), Errors.BROKER_NOT_AVAILABLE);
} else { } else {
response = clientResponse.responseBody().data(); response = clientResponse.responseBody().data();
} }
@ -149,9 +155,8 @@ public class KafkaNetworkChannel implements NetworkChannel {
} }
@Override @Override
public void updateEndpoint(int id, InetSocketAddress address) { public ListenerName listenerName() {
Node node = new Node(id, address.getHostString(), address.getPort()); return listenerName;
endpoints.put(id, node);
} }
public void start() { public void start() {

View File

@ -37,6 +37,7 @@ import org.apache.kafka.common.message.FetchSnapshotResponseData;
import org.apache.kafka.common.message.VoteRequestData; import org.apache.kafka.common.message.VoteRequestData;
import org.apache.kafka.common.message.VoteResponseData; import org.apache.kafka.common.message.VoteResponseData;
import org.apache.kafka.common.metrics.Metrics; import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.Errors;
@ -60,7 +61,6 @@ import org.apache.kafka.common.utils.BufferSupplier;
import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Timer; import org.apache.kafka.common.utils.Timer;
import org.apache.kafka.raft.RequestManager.ConnectionState;
import org.apache.kafka.raft.errors.NotLeaderException; import org.apache.kafka.raft.errors.NotLeaderException;
import org.apache.kafka.raft.internals.BatchAccumulator; import org.apache.kafka.raft.internals.BatchAccumulator;
import org.apache.kafka.raft.internals.BatchMemoryPool; import org.apache.kafka.raft.internals.BatchMemoryPool;
@ -85,6 +85,7 @@ import org.apache.kafka.snapshot.SnapshotWriter;
import org.slf4j.Logger; import org.slf4j.Logger;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.IdentityHashMap; import java.util.IdentityHashMap;
import java.util.Iterator; import java.util.Iterator;
@ -100,8 +101,10 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors;
import static java.util.concurrent.CompletableFuture.completedFuture; import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition; import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition;
@ -209,6 +212,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
ExpirationService expirationService, ExpirationService expirationService,
LogContext logContext, LogContext logContext,
String clusterId, String clusterId,
Collection<InetSocketAddress> bootstrapServers,
QuorumConfig quorumConfig QuorumConfig quorumConfig
) { ) {
this( this(
@ -223,6 +227,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
expirationService, expirationService,
MAX_FETCH_WAIT_MS, MAX_FETCH_WAIT_MS,
clusterId, clusterId,
bootstrapServers,
logContext, logContext,
new Random(), new Random(),
quorumConfig quorumConfig
@ -241,6 +246,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
ExpirationService expirationService, ExpirationService expirationService,
int fetchMaxWaitMs, int fetchMaxWaitMs,
String clusterId, String clusterId,
Collection<InetSocketAddress> bootstrapServers,
LogContext logContext, LogContext logContext,
Random random, Random random,
QuorumConfig quorumConfig QuorumConfig quorumConfig
@ -262,6 +268,30 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
this.random = random; this.random = random;
this.quorumConfig = quorumConfig; this.quorumConfig = quorumConfig;
this.snapshotCleaner = new RaftMetadataLogCleanerManager(logger, time, 60000, log::maybeClean); this.snapshotCleaner = new RaftMetadataLogCleanerManager(logger, time, 60000, log::maybeClean);
if (!bootstrapServers.isEmpty()) {
// generate Node objects from network addresses by using decreasing negative ids
AtomicInteger id = new AtomicInteger(-2);
List<Node> bootstrapNodes = bootstrapServers
.stream()
.map(address ->
new Node(
id.getAndDecrement(),
address.getHostString(),
address.getPort()
)
)
.collect(Collectors.toList());
logger.info("Starting request manager with bootstrap servers: {}", bootstrapNodes);
requestManager = new RequestManager(
bootstrapNodes,
quorumConfig.retryBackoffMs(),
quorumConfig.requestTimeoutMs(),
random
);
}
} }
private void updateFollowerHighWatermark( private void updateFollowerHighWatermark(
@ -378,12 +408,11 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
public void initialize( public void initialize(
Map<Integer, InetSocketAddress> voterAddresses, Map<Integer, InetSocketAddress> voterAddresses,
String listenerName,
QuorumStateStore quorumStateStore, QuorumStateStore quorumStateStore,
Metrics metrics Metrics metrics
) { ) {
partitionState = new KRaftControlRecordStateMachine( partitionState = new KRaftControlRecordStateMachine(
Optional.of(VoterSet.fromInetSocketAddresses(listenerName, voterAddresses)), Optional.of(VoterSet.fromInetSocketAddresses(channel.listenerName(), voterAddresses)),
log, log,
serde, serde,
BufferSupplier.create(), BufferSupplier.create(),
@ -394,17 +423,35 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
logger.info("Reading KRaft snapshot and log as part of the initialization"); logger.info("Reading KRaft snapshot and log as part of the initialization");
partitionState.updateState(); partitionState.updateState();
VoterSet lastVoterSet = partitionState.lastVoterSet(); if (requestManager == null) {
requestManager = new RequestManager( // The request manager wasn't created using the bootstrap servers
lastVoterSet.voterIds(), // create it using the voters static configuration
quorumConfig.retryBackoffMs(), List<Node> bootstrapNodes = voterAddresses
quorumConfig.requestTimeoutMs(), .entrySet()
random .stream()
); .map(entry ->
new Node(
entry.getKey(),
entry.getValue().getHostString(),
entry.getValue().getPort()
)
)
.collect(Collectors.toList());
logger.info("Starting request manager with static voters: {}", bootstrapNodes);
requestManager = new RequestManager(
bootstrapNodes,
quorumConfig.retryBackoffMs(),
quorumConfig.requestTimeoutMs(),
random
);
}
quorum = new QuorumState( quorum = new QuorumState(
nodeId, nodeId,
nodeDirectoryId, nodeDirectoryId,
channel.listenerName(),
partitionState::lastVoterSet, partitionState::lastVoterSet,
partitionState::lastKraftVersion, partitionState::lastKraftVersion,
quorumConfig.electionTimeoutMs(), quorumConfig.electionTimeoutMs(),
@ -420,10 +467,6 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
// so there are no unknown voter connections. Report this metric as 0. // so there are no unknown voter connections. Report this metric as 0.
kafkaRaftMetrics.updateNumUnknownVoterConnections(0); kafkaRaftMetrics.updateNumUnknownVoterConnections(0);
for (Integer voterId : lastVoterSet.voterIds()) {
channel.updateEndpoint(voterId, lastVoterSet.voterAddress(voterId, listenerName).get());
}
quorum.initialize(new OffsetAndEpoch(log.endOffset().offset, log.lastFetchedEpoch())); quorum.initialize(new OffsetAndEpoch(log.endOffset().offset, log.lastFetchedEpoch()));
long currentTimeMs = time.milliseconds(); long currentTimeMs = time.milliseconds();
@ -569,10 +612,10 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
private void transitionToFollower( private void transitionToFollower(
int epoch, int epoch,
int leaderId, Node leader,
long currentTimeMs long currentTimeMs
) { ) {
quorum.transitionToFollower(epoch, leaderId); quorum.transitionToFollower(epoch, leader);
maybeFireLeaderChange(); maybeFireLeaderChange();
onBecomeFollower(currentTimeMs); onBecomeFollower(currentTimeMs);
} }
@ -601,7 +644,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
private VoteResponseData handleVoteRequest( private VoteResponseData handleVoteRequest(
RaftRequest.Inbound requestMetadata RaftRequest.Inbound requestMetadata
) { ) {
VoteRequestData request = (VoteRequestData) requestMetadata.data; VoteRequestData request = (VoteRequestData) requestMetadata.data();
if (!hasValidClusterId(request.clusterId())) { if (!hasValidClusterId(request.clusterId())) {
return new VoteResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); return new VoteResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code());
@ -652,8 +695,8 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftResponse.Inbound responseMetadata, RaftResponse.Inbound responseMetadata,
long currentTimeMs long currentTimeMs
) { ) {
int remoteNodeId = responseMetadata.sourceId(); int remoteNodeId = responseMetadata.source().id();
VoteResponseData response = (VoteResponseData) responseMetadata.data; VoteResponseData response = (VoteResponseData) responseMetadata.data();
Errors topLevelError = Errors.forCode(response.errorCode()); Errors topLevelError = Errors.forCode(response.errorCode());
if (topLevelError != Errors.NONE) { if (topLevelError != Errors.NONE) {
return handleTopLevelError(topLevelError, responseMetadata); return handleTopLevelError(topLevelError, responseMetadata);
@ -751,7 +794,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftRequest.Inbound requestMetadata, RaftRequest.Inbound requestMetadata,
long currentTimeMs long currentTimeMs
) { ) {
BeginQuorumEpochRequestData request = (BeginQuorumEpochRequestData) requestMetadata.data; BeginQuorumEpochRequestData request = (BeginQuorumEpochRequestData) requestMetadata.data();
if (!hasValidClusterId(request.clusterId())) { if (!hasValidClusterId(request.clusterId())) {
return new BeginQuorumEpochResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); return new BeginQuorumEpochResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code());
@ -773,7 +816,11 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
return buildBeginQuorumEpochResponse(errorOpt.get()); return buildBeginQuorumEpochResponse(errorOpt.get());
} }
maybeTransition(OptionalInt.of(requestLeaderId), requestEpoch, currentTimeMs); maybeTransition(
partitionState.lastVoterSet().voterNode(requestLeaderId, channel.listenerName()),
requestEpoch,
currentTimeMs
);
return buildBeginQuorumEpochResponse(Errors.NONE); return buildBeginQuorumEpochResponse(Errors.NONE);
} }
@ -781,8 +828,8 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftResponse.Inbound responseMetadata, RaftResponse.Inbound responseMetadata,
long currentTimeMs long currentTimeMs
) { ) {
int remoteNodeId = responseMetadata.sourceId(); int remoteNodeId = responseMetadata.source().id();
BeginQuorumEpochResponseData response = (BeginQuorumEpochResponseData) responseMetadata.data; BeginQuorumEpochResponseData response = (BeginQuorumEpochResponseData) responseMetadata.data();
Errors topLevelError = Errors.forCode(response.errorCode()); Errors topLevelError = Errors.forCode(response.errorCode());
if (topLevelError != Errors.NONE) { if (topLevelError != Errors.NONE) {
return handleTopLevelError(topLevelError, responseMetadata); return handleTopLevelError(topLevelError, responseMetadata);
@ -840,7 +887,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftRequest.Inbound requestMetadata, RaftRequest.Inbound requestMetadata,
long currentTimeMs long currentTimeMs
) { ) {
EndQuorumEpochRequestData request = (EndQuorumEpochRequestData) requestMetadata.data; EndQuorumEpochRequestData request = (EndQuorumEpochRequestData) requestMetadata.data();
if (!hasValidClusterId(request.clusterId())) { if (!hasValidClusterId(request.clusterId())) {
return new EndQuorumEpochResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); return new EndQuorumEpochResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code());
@ -861,11 +908,15 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
if (errorOpt.isPresent()) { if (errorOpt.isPresent()) {
return buildEndQuorumEpochResponse(errorOpt.get()); return buildEndQuorumEpochResponse(errorOpt.get());
} }
maybeTransition(OptionalInt.of(requestLeaderId), requestEpoch, currentTimeMs); maybeTransition(
partitionState.lastVoterSet().voterNode(requestLeaderId, channel.listenerName()),
requestEpoch,
currentTimeMs
);
if (quorum.isFollower()) { if (quorum.isFollower()) {
FollowerState state = quorum.followerStateOrThrow(); FollowerState state = quorum.followerStateOrThrow();
if (state.leaderId() == requestLeaderId) { if (state.leader().id() == requestLeaderId) {
List<Integer> preferredSuccessors = partitionRequest.preferredSuccessors(); List<Integer> preferredSuccessors = partitionRequest.preferredSuccessors();
long electionBackoffMs = endEpochElectionBackoff(preferredSuccessors); long electionBackoffMs = endEpochElectionBackoff(preferredSuccessors);
logger.debug("Overriding follower fetch timeout to {} after receiving " + logger.debug("Overriding follower fetch timeout to {} after receiving " +
@ -894,7 +945,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftResponse.Inbound responseMetadata, RaftResponse.Inbound responseMetadata,
long currentTimeMs long currentTimeMs
) { ) {
EndQuorumEpochResponseData response = (EndQuorumEpochResponseData) responseMetadata.data; EndQuorumEpochResponseData response = (EndQuorumEpochResponseData) responseMetadata.data();
Errors topLevelError = Errors.forCode(response.errorCode()); Errors topLevelError = Errors.forCode(response.errorCode());
if (topLevelError != Errors.NONE) { if (topLevelError != Errors.NONE) {
return handleTopLevelError(topLevelError, responseMetadata); return handleTopLevelError(topLevelError, responseMetadata);
@ -917,7 +968,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
return handled.get(); return handled.get();
} else if (partitionError == Errors.NONE) { } else if (partitionError == Errors.NONE) {
ResignedState resignedState = quorum.resignedStateOrThrow(); ResignedState resignedState = quorum.resignedStateOrThrow();
resignedState.acknowledgeResignation(responseMetadata.sourceId()); resignedState.acknowledgeResignation(responseMetadata.source().id());
return true; return true;
} else { } else {
return handleUnexpectedError(partitionError, responseMetadata); return handleUnexpectedError(partitionError, responseMetadata);
@ -998,7 +1049,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftRequest.Inbound requestMetadata, RaftRequest.Inbound requestMetadata,
long currentTimeMs long currentTimeMs
) { ) {
FetchRequestData request = (FetchRequestData) requestMetadata.data; FetchRequestData request = (FetchRequestData) requestMetadata.data();
if (!hasValidClusterId(request.clusterId())) { if (!hasValidClusterId(request.clusterId())) {
return completedFuture(new FetchResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code())); return completedFuture(new FetchResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()));
@ -1147,13 +1198,13 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftResponse.Inbound responseMetadata, RaftResponse.Inbound responseMetadata,
long currentTimeMs long currentTimeMs
) { ) {
FetchResponseData response = (FetchResponseData) responseMetadata.data; FetchResponseData response = (FetchResponseData) responseMetadata.data();
Errors topLevelError = Errors.forCode(response.errorCode()); Errors topLevelError = Errors.forCode(response.errorCode());
if (topLevelError != Errors.NONE) { if (topLevelError != Errors.NONE) {
return handleTopLevelError(topLevelError, responseMetadata); return handleTopLevelError(topLevelError, responseMetadata);
} }
if (!RaftUtil.hasValidTopicPartition(response, log.topicPartition(), log.topicId())) { if (!hasValidTopicPartition(response, log.topicPartition(), log.topicId())) {
return false; return false;
} }
// If the ID is valid, we can set the topic name. // If the ID is valid, we can set the topic name.
@ -1286,7 +1337,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftRequest.Inbound requestMetadata, RaftRequest.Inbound requestMetadata,
long currentTimeMs long currentTimeMs
) { ) {
DescribeQuorumRequestData describeQuorumRequestData = (DescribeQuorumRequestData) requestMetadata.data; DescribeQuorumRequestData describeQuorumRequestData = (DescribeQuorumRequestData) requestMetadata.data();
if (!hasValidTopicPartition(describeQuorumRequestData, log.topicPartition())) { if (!hasValidTopicPartition(describeQuorumRequestData, log.topicPartition())) {
return DescribeQuorumRequest.getPartitionLevelErrorResponse( return DescribeQuorumRequest.getPartitionLevelErrorResponse(
describeQuorumRequestData, Errors.UNKNOWN_TOPIC_OR_PARTITION); describeQuorumRequestData, Errors.UNKNOWN_TOPIC_OR_PARTITION);
@ -1325,7 +1376,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftRequest.Inbound requestMetadata, RaftRequest.Inbound requestMetadata,
long currentTimeMs long currentTimeMs
) { ) {
FetchSnapshotRequestData data = (FetchSnapshotRequestData) requestMetadata.data; FetchSnapshotRequestData data = (FetchSnapshotRequestData) requestMetadata.data();
if (!hasValidClusterId(data.clusterId())) { if (!hasValidClusterId(data.clusterId())) {
return new FetchSnapshotResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); return new FetchSnapshotResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code());
@ -1429,7 +1480,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftResponse.Inbound responseMetadata, RaftResponse.Inbound responseMetadata,
long currentTimeMs long currentTimeMs
) { ) {
FetchSnapshotResponseData data = (FetchSnapshotResponseData) responseMetadata.data; FetchSnapshotResponseData data = (FetchSnapshotResponseData) responseMetadata.data();
Errors topLevelError = Errors.forCode(data.errorCode()); Errors topLevelError = Errors.forCode(data.errorCode());
if (topLevelError != Errors.NONE) { if (topLevelError != Errors.NONE) {
return handleTopLevelError(topLevelError, responseMetadata); return handleTopLevelError(topLevelError, responseMetadata);
@ -1593,6 +1644,12 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
int epoch, int epoch,
long currentTimeMs long currentTimeMs
) { ) {
Optional<Node> leader = leaderId.isPresent() ?
partitionState
.lastVoterSet()
.voterNode(leaderId.getAsInt(), channel.listenerName()) :
Optional.empty();
if (epoch < quorum.epoch() || error == Errors.UNKNOWN_LEADER_EPOCH) { if (epoch < quorum.epoch() || error == Errors.UNKNOWN_LEADER_EPOCH) {
// We have a larger epoch, so the response is no longer relevant // We have a larger epoch, so the response is no longer relevant
return Optional.of(true); return Optional.of(true);
@ -1602,10 +1659,10 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
// The response indicates that the request had a stale epoch, but we need // The response indicates that the request had a stale epoch, but we need
// to validate the epoch from the response against our current state. // to validate the epoch from the response against our current state.
maybeTransition(leaderId, epoch, currentTimeMs); maybeTransition(leader, epoch, currentTimeMs);
return Optional.of(true); return Optional.of(true);
} else if (epoch == quorum.epoch() } else if (epoch == quorum.epoch()
&& leaderId.isPresent() && leader.isPresent()
&& !quorum.hasLeader()) { && !quorum.hasLeader()) {
// Since we are transitioning to Follower, we will only forward the // Since we are transitioning to Follower, we will only forward the
@ -1613,7 +1670,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
// the request be retried immediately (if needed) after the transition. // the request be retried immediately (if needed) after the transition.
// This handling allows an observer to discover the leader and append // This handling allows an observer to discover the leader and append
// to the log in the same Fetch request. // to the log in the same Fetch request.
transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs); transitionToFollower(epoch, leader.get(), currentTimeMs);
if (error == Errors.NONE) { if (error == Errors.NONE) {
return Optional.empty(); return Optional.empty();
} else { } else {
@ -1635,24 +1692,28 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
} }
private void maybeTransition( private void maybeTransition(
OptionalInt leaderId, Optional<Node> leader,
int epoch, int epoch,
long currentTimeMs long currentTimeMs
) { ) {
OptionalInt leaderId = leader.isPresent() ?
OptionalInt.of(leader.get().id()) :
OptionalInt.empty();
if (!hasConsistentLeader(epoch, leaderId)) { if (!hasConsistentLeader(epoch, leaderId)) {
throw new IllegalStateException("Received request or response with leader " + leaderId + throw new IllegalStateException("Received request or response with leader " + leader +
" and epoch " + epoch + " which is inconsistent with current leader " + " and epoch " + epoch + " which is inconsistent with current leader " +
quorum.leaderId() + " and epoch " + quorum.epoch()); quorum.leaderId() + " and epoch " + quorum.epoch());
} else if (epoch > quorum.epoch()) { } else if (epoch > quorum.epoch()) {
if (leaderId.isPresent()) { if (leader.isPresent()) {
transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs); transitionToFollower(epoch, leader.get(), currentTimeMs);
} else { } else {
transitionToUnattached(epoch); transitionToUnattached(epoch);
} }
} else if (leaderId.isPresent() && !quorum.hasLeader()) { } else if (leader.isPresent() && !quorum.hasLeader()) {
// The request or response indicates the leader of the current epoch, // The request or response indicates the leader of the current epoch,
// which is currently unknown // which is currently unknown
transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs); transitionToFollower(epoch, leader.get(), currentTimeMs);
} }
} }
@ -1668,13 +1729,13 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
private boolean handleUnexpectedError(Errors error, RaftResponse.Inbound response) { private boolean handleUnexpectedError(Errors error, RaftResponse.Inbound response) {
logger.error("Unexpected error {} in {} response: {}", logger.error("Unexpected error {} in {} response: {}",
error, ApiKeys.forId(response.data.apiKey()), response); error, ApiKeys.forId(response.data().apiKey()), response);
return false; return false;
} }
private void handleResponse(RaftResponse.Inbound response, long currentTimeMs) { private void handleResponse(RaftResponse.Inbound response, long currentTimeMs) {
// The response epoch matches the local epoch, so we can handle the response // The response epoch matches the local epoch, so we can handle the response
ApiKeys apiKey = ApiKeys.forId(response.data.apiKey()); ApiKeys apiKey = ApiKeys.forId(response.data().apiKey());
final boolean handledSuccessfully; final boolean handledSuccessfully;
switch (apiKey) { switch (apiKey) {
@ -1702,12 +1763,12 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
throw new IllegalArgumentException("Received unexpected response type: " + apiKey); throw new IllegalArgumentException("Received unexpected response type: " + apiKey);
} }
ConnectionState connection = requestManager.getOrCreate(response.sourceId()); requestManager.onResponseResult(
if (handledSuccessfully) { response.source(),
connection.onResponseReceived(response.correlationId); response.correlationId(),
} else { handledSuccessfully,
connection.onResponseError(response.correlationId, currentTimeMs); currentTimeMs
} );
} }
/** /**
@ -1749,7 +1810,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
} }
private void handleRequest(RaftRequest.Inbound request, long currentTimeMs) { private void handleRequest(RaftRequest.Inbound request, long currentTimeMs) {
ApiKeys apiKey = ApiKeys.forId(request.data.apiKey()); ApiKeys apiKey = ApiKeys.forId(request.data().apiKey());
final CompletableFuture<? extends ApiMessage> responseFuture; final CompletableFuture<? extends ApiMessage> responseFuture;
switch (apiKey) { switch (apiKey) {
@ -1803,8 +1864,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
handleRequest(request, currentTimeMs); handleRequest(request, currentTimeMs);
} else if (message instanceof RaftResponse.Inbound) { } else if (message instanceof RaftResponse.Inbound) {
RaftResponse.Inbound response = (RaftResponse.Inbound) message; RaftResponse.Inbound response = (RaftResponse.Inbound) message;
ConnectionState connection = requestManager.getOrCreate(response.sourceId()); if (requestManager.isResponseExpected(response.source(), response.correlationId())) {
if (connection.isResponseExpected(response.correlationId)) {
handleResponse(response, currentTimeMs); handleResponse(response, currentTimeMs);
} else { } else {
logger.debug("Ignoring response {} since it is no longer needed", response); logger.debug("Ignoring response {} since it is no longer needed", response);
@ -1819,25 +1879,23 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
*/ */
private long maybeSendRequest( private long maybeSendRequest(
long currentTimeMs, long currentTimeMs,
int destinationId, Node destination,
Supplier<ApiMessage> requestSupplier Supplier<ApiMessage> requestSupplier
) { ) {
ConnectionState connection = requestManager.getOrCreate(destinationId); if (requestManager.isBackingOff(destination, currentTimeMs)) {
long remainingBackoffMs = requestManager.remainingBackoffMs(destination, currentTimeMs);
if (connection.isBackingOff(currentTimeMs)) { logger.debug("Connection for {} is backing off for {} ms", destination, remainingBackoffMs);
long remainingBackoffMs = connection.remainingBackoffMs(currentTimeMs);
logger.debug("Connection for {} is backing off for {} ms", destinationId, remainingBackoffMs);
return remainingBackoffMs; return remainingBackoffMs;
} }
if (connection.isReady(currentTimeMs)) { if (requestManager.isReady(destination, currentTimeMs)) {
int correlationId = channel.newCorrelationId(); int correlationId = channel.newCorrelationId();
ApiMessage request = requestSupplier.get(); ApiMessage request = requestSupplier.get();
RaftRequest.Outbound requestMessage = new RaftRequest.Outbound( RaftRequest.Outbound requestMessage = new RaftRequest.Outbound(
correlationId, correlationId,
request, request,
destinationId, destination,
currentTimeMs currentTimeMs
); );
@ -1850,20 +1908,19 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
response = new RaftResponse.Inbound( response = new RaftResponse.Inbound(
correlationId, correlationId,
errorResponse, errorResponse,
destinationId destination
); );
} }
messageQueue.add(response); messageQueue.add(response);
}); });
requestManager.onRequestSent(destination, correlationId, currentTimeMs);
channel.send(requestMessage); channel.send(requestMessage);
logger.trace("Sent outbound request: {}", requestMessage); logger.trace("Sent outbound request: {}", requestMessage);
connection.onRequestSent(correlationId, currentTimeMs);
return Long.MAX_VALUE;
} }
return connection.remainingRequestTimeMs(currentTimeMs); return requestManager.remainingRequestTimeMs(destination, currentTimeMs);
} }
private EndQuorumEpochRequestData buildEndQuorumEpochRequest( private EndQuorumEpochRequestData buildEndQuorumEpochRequest(
@ -1880,12 +1937,12 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
private long maybeSendRequests( private long maybeSendRequests(
long currentTimeMs, long currentTimeMs,
Set<Integer> destinationIds, Set<Node> destinations,
Supplier<ApiMessage> requestSupplier Supplier<ApiMessage> requestSupplier
) { ) {
long minBackoffMs = Long.MAX_VALUE; long minBackoffMs = Long.MAX_VALUE;
for (Integer destinationId : destinationIds) { for (Node destination : destinations) {
long backoffMs = maybeSendRequest(currentTimeMs, destinationId, requestSupplier); long backoffMs = maybeSendRequest(currentTimeMs, destination, requestSupplier);
if (backoffMs < minBackoffMs) { if (backoffMs < minBackoffMs) {
minBackoffMs = backoffMs; minBackoffMs = backoffMs;
} }
@ -1929,15 +1986,15 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
} }
private long maybeSendAnyVoterFetch(long currentTimeMs) { private long maybeSendAnyVoterFetch(long currentTimeMs) {
OptionalInt readyVoterIdOpt = requestManager.findReadyVoter(currentTimeMs); Optional<Node> readyNode = requestManager.findReadyBootstrapServer(currentTimeMs);
if (readyVoterIdOpt.isPresent()) { if (readyNode.isPresent()) {
return maybeSendRequest( return maybeSendRequest(
currentTimeMs, currentTimeMs,
readyVoterIdOpt.getAsInt(), readyNode.get(),
this::buildFetchRequest this::buildFetchRequest
); );
} else { } else {
return requestManager.backoffBeforeAvailableVoter(currentTimeMs); return requestManager.backoffBeforeAvailableBootstrapServer(currentTimeMs);
} }
} }
@ -2038,7 +2095,9 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
ResignedState state = quorum.resignedStateOrThrow(); ResignedState state = quorum.resignedStateOrThrow();
long endQuorumBackoffMs = maybeSendRequests( long endQuorumBackoffMs = maybeSendRequests(
currentTimeMs, currentTimeMs,
state.unackedVoters(), partitionState
.lastVoterSet()
.voterNodes(state.unackedVoters().stream(), channel.listenerName()),
() -> buildEndQuorumEpochRequest(state) () -> buildEndQuorumEpochRequest(state)
); );
@ -2075,7 +2134,9 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
long timeUntilSend = maybeSendRequests( long timeUntilSend = maybeSendRequests(
currentTimeMs, currentTimeMs,
state.nonAcknowledgingVoters(), partitionState
.lastVoterSet()
.voterNodes(state.nonAcknowledgingVoters().stream(), channel.listenerName()),
this::buildBeginQuorumEpochRequest this::buildBeginQuorumEpochRequest
); );
@ -2090,7 +2151,9 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
if (!state.isVoteRejected()) { if (!state.isVoteRejected()) {
return maybeSendRequests( return maybeSendRequests(
currentTimeMs, currentTimeMs,
state.unrecordedVoters(), partitionState
.lastVoterSet()
.voterNodes(state.unrecordedVoters().stream(), channel.listenerName()),
this::buildVoteRequest this::buildVoteRequest
); );
} }
@ -2163,14 +2226,16 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
// If the current leader is backing off due to some failure or if the // If the current leader is backing off due to some failure or if the
// request has timed out, then we attempt to send the Fetch to another // request has timed out, then we attempt to send the Fetch to another
// voter in order to discover if there has been a leader change. // voter in order to discover if there has been a leader change.
ConnectionState connection = requestManager.getOrCreate(state.leaderId()); if (requestManager.hasRequestTimedOut(state.leader(), currentTimeMs)) {
if (connection.hasRequestTimedOut(currentTimeMs)) { // Once the request has timed out backoff the connection
requestManager.reset(state.leader());
backoffMs = maybeSendAnyVoterFetch(currentTimeMs); backoffMs = maybeSendAnyVoterFetch(currentTimeMs);
connection.reset(); } else if (requestManager.isBackingOff(state.leader(), currentTimeMs)) {
} else if (connection.isBackingOff(currentTimeMs)) {
backoffMs = maybeSendAnyVoterFetch(currentTimeMs); backoffMs = maybeSendAnyVoterFetch(currentTimeMs);
} else { } else if (!requestManager.hasAnyInflightRequest(currentTimeMs)) {
backoffMs = maybeSendFetchOrFetchSnapshot(state, currentTimeMs); backoffMs = maybeSendFetchOrFetchSnapshot(state, currentTimeMs);
} else {
backoffMs = requestManager.backoffBeforeAvailableBootstrapServer(currentTimeMs);
} }
return Math.min(backoffMs, state.remainingFetchTimeMs(currentTimeMs)); return Math.min(backoffMs, state.remainingFetchTimeMs(currentTimeMs));
@ -2189,7 +2254,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
requestSupplier = this::buildFetchRequest; requestSupplier = this::buildFetchRequest;
} }
return maybeSendRequest(currentTimeMs, state.leaderId(), requestSupplier); return maybeSendRequest(currentTimeMs, state.leader(), requestSupplier);
} }
private long pollVoted(long currentTimeMs) { private long pollVoted(long currentTimeMs) {
@ -2549,8 +2614,8 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
} }
} }
public Optional<Node> voterNode(int id, String listener) { public Optional<Node> voterNode(int id, ListenerName listenerName) {
return partitionState.lastVoterSet().voterNode(id, listener); return partitionState.lastVoterSet().voterNode(id, listenerName);
} }
// Visible only for test // Visible only for test

View File

@ -16,7 +16,7 @@
*/ */
package org.apache.kafka.raft; package org.apache.kafka.raft;
import java.net.InetSocketAddress; import org.apache.kafka.common.network.ListenerName;
/** /**
* A simple network interface with few assumptions. We do not assume ordering * A simple network interface with few assumptions. We do not assume ordering
@ -37,10 +37,11 @@ public interface NetworkChannel extends AutoCloseable {
void send(RaftRequest.Outbound request); void send(RaftRequest.Outbound request);
/** /**
* Update connection information for the given id. * The name of listener used when sending requests.
*
* @return the name of the listener
*/ */
void updateEndpoint(int id, InetSocketAddress address); ListenerName listenerName();
default void close() throws InterruptedException {} default void close() throws InterruptedException {}
} }

View File

@ -54,6 +54,13 @@ public class QuorumConfig {
"For example: <code>1@localhost:9092,2@localhost:9093,3@localhost:9094</code>"; "For example: <code>1@localhost:9092,2@localhost:9093,3@localhost:9094</code>";
public static final List<String> DEFAULT_QUORUM_VOTERS = Collections.emptyList(); public static final List<String> DEFAULT_QUORUM_VOTERS = Collections.emptyList();
public static final String QUORUM_BOOTSTRAP_SERVERS_CONFIG = QUORUM_PREFIX + "bootstrap.servers";
public static final String QUORUM_BOOTSTRAP_SERVERS_DOC = "List of endpoints to use for " +
"bootstrapping the cluster metadata. The endpoints are specified in comma-separated list " +
"of <code>{host}:{port}</code> entries. For example: " +
"<code>localhost:9092,localhost:9093,localhost:9094</code>.";
public static final List<String> DEFAULT_QUORUM_BOOTSTRAP_SERVERS = Collections.emptyList();
public static final String QUORUM_ELECTION_TIMEOUT_MS_CONFIG = QUORUM_PREFIX + "election.timeout.ms"; public static final String QUORUM_ELECTION_TIMEOUT_MS_CONFIG = QUORUM_PREFIX + "election.timeout.ms";
public static final String QUORUM_ELECTION_TIMEOUT_MS_DOC = "Maximum time in milliseconds to wait " + public static final String QUORUM_ELECTION_TIMEOUT_MS_DOC = "Maximum time in milliseconds to wait " +
"without being able to fetch from the leader before triggering a new election"; "without being able to fetch from the leader before triggering a new election";
@ -163,7 +170,7 @@ public class QuorumConfig {
List<String> voterEntries, List<String> voterEntries,
boolean requireRoutableAddresses boolean requireRoutableAddresses
) { ) {
Map<Integer, InetSocketAddress> voterMap = new HashMap<>(); Map<Integer, InetSocketAddress> voterMap = new HashMap<>(voterEntries.size());
for (String voterMapEntry : voterEntries) { for (String voterMapEntry : voterEntries) {
String[] idAndAddress = voterMapEntry.split("@"); String[] idAndAddress = voterMapEntry.split("@");
if (idAndAddress.length != 2) { if (idAndAddress.length != 2) {
@ -173,7 +180,7 @@ public class QuorumConfig {
Integer voterId = parseVoterId(idAndAddress[0]); Integer voterId = parseVoterId(idAndAddress[0]);
String host = Utils.getHost(idAndAddress[1]); String host = Utils.getHost(idAndAddress[1]);
if (host == null) { if (host == null || !Utils.validHostPattern(host)) {
throw new ConfigException("Failed to parse host name from entry " + voterMapEntry throw new ConfigException("Failed to parse host name from entry " + voterMapEntry
+ " for the configuration " + QUORUM_VOTERS_CONFIG + " for the configuration " + QUORUM_VOTERS_CONFIG
+ ". Each entry should be in the form `{id}@{host}:{port}`."); + ". Each entry should be in the form `{id}@{host}:{port}`.");
@ -199,6 +206,41 @@ public class QuorumConfig {
return voterMap; return voterMap;
} }
public static List<InetSocketAddress> parseBootstrapServers(List<String> bootstrapServers) {
return bootstrapServers
.stream()
.map(QuorumConfig::parseBootstrapServer)
.collect(Collectors.toList());
}
private static InetSocketAddress parseBootstrapServer(String bootstrapServer) {
String host = Utils.getHost(bootstrapServer);
if (host == null || !Utils.validHostPattern(host)) {
throw new ConfigException(
String.format(
"Failed to parse host name from {} for the configuration {}. Each " +
"entry should be in the form \"{host}:{port}\"",
bootstrapServer,
QUORUM_BOOTSTRAP_SERVERS_CONFIG
)
);
}
Integer port = Utils.getPort(bootstrapServer);
if (port == null) {
throw new ConfigException(
String.format(
"Failed to parse host port from {} for the configuration {}. Each " +
"entry should be in the form \"{host}:{port}\"",
bootstrapServer,
QUORUM_BOOTSTRAP_SERVERS_CONFIG
)
);
}
return InetSocketAddress.createUnresolved(host, port);
}
public static List<Node> quorumVoterStringsToNodes(List<String> voters) { public static List<Node> quorumVoterStringsToNodes(List<String> voters) {
return voterConnectionsToNodes(parseVoterConnections(voters)); return voterConnectionsToNodes(parseVoterConnections(voters));
} }
@ -231,4 +273,26 @@ public class QuorumConfig {
return "non-empty list"; return "non-empty list";
} }
} }
public static class ControllerQuorumBootstrapServersValidator implements ConfigDef.Validator {
@Override
public void ensureValid(String name, Object value) {
if (value == null) {
throw new ConfigException(name, null);
}
@SuppressWarnings("unchecked")
List<String> entries = (List<String>) value;
// Attempt to parse the connect strings
for (String entry : entries) {
parseBootstrapServer(entry);
}
}
@Override
public String toString() {
return "non-empty list";
}
}
} }

View File

@ -25,7 +25,9 @@ import java.util.OptionalInt;
import java.util.Random; import java.util.Random;
import java.util.function.Supplier; import java.util.function.Supplier;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.Uuid; import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Time;
import org.apache.kafka.raft.internals.BatchAccumulator; import org.apache.kafka.raft.internals.BatchAccumulator;
@ -81,6 +83,7 @@ public class QuorumState {
private final Time time; private final Time time;
private final Logger log; private final Logger log;
private final QuorumStateStore store; private final QuorumStateStore store;
private final ListenerName listenerName;
private final Supplier<VoterSet> latestVoterSet; private final Supplier<VoterSet> latestVoterSet;
private final Supplier<Short> latestKraftVersion; private final Supplier<Short> latestKraftVersion;
private final Random random; private final Random random;
@ -93,6 +96,7 @@ public class QuorumState {
public QuorumState( public QuorumState(
OptionalInt localId, OptionalInt localId,
Uuid localDirectoryId, Uuid localDirectoryId,
ListenerName listenerName,
Supplier<VoterSet> latestVoterSet, Supplier<VoterSet> latestVoterSet,
Supplier<Short> latestKraftVersion, Supplier<Short> latestKraftVersion,
int electionTimeoutMs, int electionTimeoutMs,
@ -104,6 +108,7 @@ public class QuorumState {
) { ) {
this.localId = localId; this.localId = localId;
this.localDirectoryId = localDirectoryId; this.localDirectoryId = localDirectoryId;
this.listenerName = listenerName;
this.latestVoterSet = latestVoterSet; this.latestVoterSet = latestVoterSet;
this.latestKraftVersion = latestKraftVersion; this.latestKraftVersion = latestKraftVersion;
this.electionTimeoutMs = electionTimeoutMs; this.electionTimeoutMs = electionTimeoutMs;
@ -115,16 +120,21 @@ public class QuorumState {
this.logContext = logContext; this.logContext = logContext;
} }
public void initialize(OffsetAndEpoch logEndOffsetAndEpoch) throws IllegalStateException { private ElectionState readElectionState() {
// We initialize in whatever state we were in on shutdown. If we were a leader
// or candidate, probably an election was held, but we will find out about it
// when we send Vote or BeginEpoch requests.
ElectionState election; ElectionState election;
election = store election = store
.readElectionState() .readElectionState()
.orElseGet(() -> ElectionState.withUnknownLeader(0, latestVoterSet.get().voterIds())); .orElseGet(() -> ElectionState.withUnknownLeader(0, latestVoterSet.get().voterIds()));
return election;
}
public void initialize(OffsetAndEpoch logEndOffsetAndEpoch) throws IllegalStateException {
// We initialize in whatever state we were in on shutdown. If we were a leader
// or candidate, probably an election was held, but we will find out about it
// when we send Vote or BeginEpoch requests.
ElectionState election = readElectionState();
final EpochState initialState; final EpochState initialState;
if (election.hasVoted() && !localId.isPresent()) { if (election.hasVoted() && !localId.isPresent()) {
throw new IllegalStateException( throw new IllegalStateException(
@ -191,10 +201,26 @@ public class QuorumState {
logContext logContext
); );
} else if (election.hasLeader()) { } else if (election.hasLeader()) {
/* KAFKA-16529 is going to change this so that the leader is not required to be in the set
* of voters. In other words, don't throw an IllegalStateException if the leader is not in
* the set of voters.
*/
Node leader = latestVoterSet
.get()
.voterNode(election.leaderId(), listenerName)
.orElseThrow(() ->
new IllegalStateException(
String.format(
"Leader %s must be in the voter set %s",
election.leaderId(),
latestVoterSet.get()
)
)
);
initialState = new FollowerState( initialState = new FollowerState(
time, time,
election.epoch(), election.epoch(),
election.leaderId(), leader,
latestVoterSet.get().voterIds(), latestVoterSet.get().voterIds(),
Optional.empty(), Optional.empty(),
fetchTimeoutMs, fetchTimeoutMs,
@ -400,28 +426,24 @@ public class QuorumState {
/** /**
* Become a follower of an elected leader so that we can begin fetching. * Become a follower of an elected leader so that we can begin fetching.
*/ */
public void transitionToFollower( public void transitionToFollower(int epoch, Node leader) {
int epoch,
int leaderId
) {
int currentEpoch = state.epoch(); int currentEpoch = state.epoch();
if (localId.isPresent() && leaderId == localId.getAsInt()) { if (localId.isPresent() && leader.id() == localId.getAsInt()) {
throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId + throw new IllegalStateException("Cannot transition to Follower with leader " + leader +
" and epoch=" + epoch + " since it matches the local broker.id=" + localId); " and epoch " + epoch + " since it matches the local broker.id " + localId);
} else if (epoch < currentEpoch) { } else if (epoch < currentEpoch) {
throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId + throw new IllegalStateException("Cannot transition to Follower with leader " + leader +
" and epoch=" + epoch + " since the current epoch " + currentEpoch + " is larger"); " and epoch " + epoch + " since the current epoch " + currentEpoch + " is larger");
} else if (epoch == currentEpoch } else if (epoch == currentEpoch && (isFollower() || isLeader())) {
&& (isFollower() || isLeader())) { throw new IllegalStateException("Cannot transition to Follower with leader " + leader +
throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId + " and epoch " + epoch + " from state " + state);
" and epoch=" + epoch + " from state " + state);
} }
durableTransitionTo( durableTransitionTo(
new FollowerState( new FollowerState(
time, time,
epoch, epoch,
leaderId, leader,
latestVoterSet.get().voterIds(), latestVoterSet.get().voterIds(),
state.highWatermark(), state.highWatermark(),
fetchTimeoutMs, fetchTimeoutMs,

View File

@ -17,13 +17,14 @@
package org.apache.kafka.raft; package org.apache.kafka.raft;
import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.Node;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
public abstract class RaftRequest implements RaftMessage { public abstract class RaftRequest implements RaftMessage {
protected final int correlationId; private final int correlationId;
protected final ApiMessage data; private final ApiMessage data;
protected final long createdTimeMs; private final long createdTimeMs;
public RaftRequest(int correlationId, ApiMessage data, long createdTimeMs) { public RaftRequest(int correlationId, ApiMessage data, long createdTimeMs) {
this.correlationId = correlationId; this.correlationId = correlationId;
@ -45,7 +46,7 @@ public abstract class RaftRequest implements RaftMessage {
return createdTimeMs; return createdTimeMs;
} }
public static class Inbound extends RaftRequest { public final static class Inbound extends RaftRequest {
public final CompletableFuture<RaftResponse.Outbound> completion = new CompletableFuture<>(); public final CompletableFuture<RaftResponse.Outbound> completion = new CompletableFuture<>();
public Inbound(int correlationId, ApiMessage data, long createdTimeMs) { public Inbound(int correlationId, ApiMessage data, long createdTimeMs) {
@ -54,35 +55,37 @@ public abstract class RaftRequest implements RaftMessage {
@Override @Override
public String toString() { public String toString() {
return "InboundRequest(" + return String.format(
"correlationId=" + correlationId + "InboundRequest(correlationId=%d, data=%s, createdTimeMs=%d)",
", data=" + data + correlationId(),
", createdTimeMs=" + createdTimeMs + data(),
')'; createdTimeMs()
);
} }
} }
public static class Outbound extends RaftRequest { public final static class Outbound extends RaftRequest {
private final int destinationId; private final Node destination;
public final CompletableFuture<RaftResponse.Inbound> completion = new CompletableFuture<>(); public final CompletableFuture<RaftResponse.Inbound> completion = new CompletableFuture<>();
public Outbound(int correlationId, ApiMessage data, int destinationId, long createdTimeMs) { public Outbound(int correlationId, ApiMessage data, Node destination, long createdTimeMs) {
super(correlationId, data, createdTimeMs); super(correlationId, data, createdTimeMs);
this.destinationId = destinationId; this.destination = destination;
} }
public int destinationId() { public Node destination() {
return destinationId; return destination;
} }
@Override @Override
public String toString() { public String toString() {
return "OutboundRequest(" + return String.format(
"correlationId=" + correlationId + "OutboundRequest(correlationId=%d, data=%s, createdTimeMs=%d, destination=%s)",
", data=" + data + correlationId(),
", createdTimeMs=" + createdTimeMs + data(),
", destinationId=" + destinationId + createdTimeMs(),
')'; destination
);
} }
} }
} }

View File

@ -16,11 +16,12 @@
*/ */
package org.apache.kafka.raft; package org.apache.kafka.raft;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.ApiMessage;
public abstract class RaftResponse implements RaftMessage { public abstract class RaftResponse implements RaftMessage {
protected final int correlationId; private final int correlationId;
protected final ApiMessage data; private final ApiMessage data;
protected RaftResponse(int correlationId, ApiMessage data) { protected RaftResponse(int correlationId, ApiMessage data) {
this.correlationId = correlationId; this.correlationId = correlationId;
@ -37,39 +38,41 @@ public abstract class RaftResponse implements RaftMessage {
return data; return data;
} }
public static class Inbound extends RaftResponse { public final static class Inbound extends RaftResponse {
private final int sourceId; private final Node source;
public Inbound(int correlationId, ApiMessage data, int sourceId) { public Inbound(int correlationId, ApiMessage data, Node source) {
super(correlationId, data); super(correlationId, data);
this.sourceId = sourceId; this.source = source;
} }
public int sourceId() { public Node source() {
return sourceId; return source;
} }
@Override @Override
public String toString() { public String toString() {
return "InboundResponse(" + return String.format(
"correlationId=" + correlationId + "InboundResponse(correlationId=%d, data=%s, source=%s)",
", data=" + data + correlationId(),
", sourceId=" + sourceId + data(),
')'; source
);
} }
} }
public static class Outbound extends RaftResponse { public final static class Outbound extends RaftResponse {
public Outbound(int requestId, ApiMessage data) { public Outbound(int requestId, ApiMessage data) {
super(requestId, data); super(requestId, data);
} }
@Override @Override
public String toString() { public String toString() {
return "OutboundResponse(" + return String.format(
"correlationId=" + correlationId + "OutboundResponse(correlationId=%d, data=%s)",
", data=" + data + correlationId(),
')'; data()
);
} }
} }
} }

View File

@ -25,6 +25,7 @@ import org.apache.kafka.common.message.EndQuorumEpochRequestData;
import org.apache.kafka.common.message.EndQuorumEpochResponseData; import org.apache.kafka.common.message.EndQuorumEpochResponseData;
import org.apache.kafka.common.message.FetchRequestData; import org.apache.kafka.common.message.FetchRequestData;
import org.apache.kafka.common.message.FetchResponseData; import org.apache.kafka.common.message.FetchResponseData;
import org.apache.kafka.common.message.FetchSnapshotResponseData;
import org.apache.kafka.common.message.VoteRequestData; import org.apache.kafka.common.message.VoteRequestData;
import org.apache.kafka.common.message.VoteResponseData; import org.apache.kafka.common.message.VoteResponseData;
import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiKeys;
@ -48,6 +49,8 @@ public class RaftUtil {
return new EndQuorumEpochResponseData().setErrorCode(error.code()); return new EndQuorumEpochResponseData().setErrorCode(error.code());
case FETCH: case FETCH:
return new FetchResponseData().setErrorCode(error.code()); return new FetchResponseData().setErrorCode(error.code());
case FETCH_SNAPSHOT:
return new FetchSnapshotResponseData().setErrorCode(error.code());
default: default:
throw new IllegalArgumentException("Received response for unexpected request type: " + apiKey); throw new IllegalArgumentException("Received response for unexpected request type: " + apiKey);
} }

View File

@ -17,96 +17,288 @@
package org.apache.kafka.raft; package org.apache.kafka.raft;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.OptionalInt; import java.util.Optional;
import java.util.OptionalLong; import java.util.OptionalLong;
import java.util.Random; import java.util.Random;
import java.util.Set; import org.apache.kafka.common.Node;
/**
* The request manager keeps tracks of the connection with remote replicas.
*
* When sending a request update this type by calling {@code onRequestSent(Node, long, long)}. When
* the RPC returns a response, update this manager with {@code onResponseResult(Node, long, boolean, long)}.
*
* Connections start in the ready state ({@code isReady(Node, long)} returns true).
*
* When a request times out or completes successfully the collection will transition back to the
* ready state.
*
* When a request completes with an error it still transition to the backoff state until
* {@code retryBackoffMs}.
*/
public class RequestManager { public class RequestManager {
private final Map<Integer, ConnectionState> connections = new HashMap<>(); private final Map<String, ConnectionState> connections = new HashMap<>();
private final List<Integer> voters = new ArrayList<>(); private final ArrayList<Node> bootstrapServers;
private final int retryBackoffMs; private final int retryBackoffMs;
private final int requestTimeoutMs; private final int requestTimeoutMs;
private final Random random; private final Random random;
public RequestManager(Set<Integer> voterIds, public RequestManager(
int retryBackoffMs, Collection<Node> bootstrapServers,
int requestTimeoutMs, int retryBackoffMs,
Random random) { int requestTimeoutMs,
Random random
) {
this.bootstrapServers = new ArrayList<>(bootstrapServers);
this.retryBackoffMs = retryBackoffMs; this.retryBackoffMs = retryBackoffMs;
this.requestTimeoutMs = requestTimeoutMs; this.requestTimeoutMs = requestTimeoutMs;
this.voters.addAll(voterIds);
this.random = random; this.random = random;
for (Integer voterId: voterIds) {
ConnectionState connection = new ConnectionState(voterId);
connections.put(voterId, connection);
}
} }
public ConnectionState getOrCreate(int id) { /**
return connections.computeIfAbsent(id, key -> new ConnectionState(id)); * Returns true if there any connection with pending requests.
} *
* This is useful for satisfying the invariant that there is only one pending Fetch request.
* If there are more than one pending fetch request, it is possible for the follower to write
* the same offset twice.
*
* @param currentTimeMs the current time
* @return true if the request manager is tracking at least one request
*/
public boolean hasAnyInflightRequest(long currentTimeMs) {
boolean result = false;
public OptionalInt findReadyVoter(long currentTimeMs) { Iterator<ConnectionState> iterator = connections.values().iterator();
int startIndex = random.nextInt(voters.size()); while (iterator.hasNext()) {
OptionalInt res = OptionalInt.empty(); ConnectionState connection = iterator.next();
for (int i = 0; i < voters.size(); i++) { if (connection.hasRequestTimedOut(currentTimeMs)) {
int index = (startIndex + i) % voters.size(); // Mark the node as ready after request timeout
Integer voterId = voters.get(index); iterator.remove();
ConnectionState connection = connections.get(voterId); } else if (connection.isBackoffComplete(currentTimeMs)) {
boolean isReady = connection.isReady(currentTimeMs); // Mark the node as ready after completed backoff
iterator.remove();
if (isReady) { } else if (connection.hasInflightRequest(currentTimeMs)) {
res = OptionalInt.of(voterId); // If there is at least one inflight request, it is enough
} else if (connection.inFlightCorrelationId.isPresent()) { // to stop checking the rest of the connections
res = OptionalInt.empty(); result = true;
break; break;
} }
} }
return res;
return result;
} }
public long backoffBeforeAvailableVoter(long currentTimeMs) { /**
long minBackoffMs = Long.MAX_VALUE; * Returns a random bootstrap node that is ready to receive a request.
for (Integer voterId : voters) { *
ConnectionState connection = connections.get(voterId); * This method doesn't return a node if there is at least one request pending. In general this
if (connection.isReady(currentTimeMs)) { * method is used to send Fetch requests. Fetch requests have the invariant that there can
return 0L; * only be one pending Fetch request for the LEO.
} else if (connection.isBackingOff(currentTimeMs)) { *
minBackoffMs = Math.min(minBackoffMs, connection.remainingBackoffMs(currentTimeMs)); * @param currentTimeMs the current time
} else { * @return a random ready bootstrap node
minBackoffMs = Math.min(minBackoffMs, connection.remainingRequestTimeMs(currentTimeMs)); */
public Optional<Node> findReadyBootstrapServer(long currentTimeMs) {
// Check that there are no infilght requests accross any of the known nodes not just
// the bootstrap servers
if (hasAnyInflightRequest(currentTimeMs)) {
return Optional.empty();
}
int startIndex = random.nextInt(bootstrapServers.size());
Optional<Node> result = Optional.empty();
for (int i = 0; i < bootstrapServers.size(); i++) {
int index = (startIndex + i) % bootstrapServers.size();
Node node = bootstrapServers.get(index);
if (isReady(node, currentTimeMs)) {
result = Optional.of(node);
break;
} }
} }
return result;
}
/**
* Computes the amount of time needed to wait before a bootstrap server is ready for a Fetch
* request.
*
* If there is a connection with a pending request it returns the amount of time to wait until
* the request times out.
*
* Returns zero, if there are no pending request and at least one of the boorstrap servers is
* ready.
*
* If all of the bootstrap servers are backing off and there are no pending requests, return
* the minimum amount of time until a bootstrap server becomes ready.
*
* @param currentTimeMs the current time
* @return the amount of time to wait until bootstrap server can accept a Fetch request
*/
public long backoffBeforeAvailableBootstrapServer(long currentTimeMs) {
long minBackoffMs = retryBackoffMs;
Iterator<ConnectionState> iterator = connections.values().iterator();
while (iterator.hasNext()) {
ConnectionState connection = iterator.next();
if (connection.hasRequestTimedOut(currentTimeMs)) {
// Mark the node as ready after request timeout
iterator.remove();
} else if (connection.isBackoffComplete(currentTimeMs)) {
// Mark the node as ready after completed backoff
iterator.remove();
} else if (connection.hasInflightRequest(currentTimeMs)) {
// There can be at most one inflight fetch request
return connection.remainingRequestTimeMs(currentTimeMs);
} else if (connection.isBackingOff(currentTimeMs)) {
minBackoffMs = Math.min(minBackoffMs, connection.remainingBackoffMs(currentTimeMs));
}
}
// There are no inflight fetch requests so check if there is a ready bootstrap server
for (Node node : bootstrapServers) {
if (isReady(node, currentTimeMs)) {
return 0L;
}
}
// There are no ready bootstrap servers and inflight fetch requests, return the backoff
return minBackoffMs; return minBackoffMs;
} }
public boolean hasRequestTimedOut(Node node, long timeMs) {
ConnectionState state = connections.get(node.idString());
if (state == null) {
return false;
}
return state.hasRequestTimedOut(timeMs);
}
public boolean isReady(Node node, long timeMs) {
ConnectionState state = connections.get(node.idString());
if (state == null) {
return true;
}
boolean ready = state.isReady(timeMs);
if (ready) {
reset(node);
}
return ready;
}
public boolean isBackingOff(Node node, long timeMs) {
ConnectionState state = connections.get(node.idString());
if (state == null) {
return false;
}
return state.isBackingOff(timeMs);
}
public long remainingRequestTimeMs(Node node, long timeMs) {
ConnectionState state = connections.get(node.idString());
if (state == null) {
return 0;
}
return state.remainingRequestTimeMs(timeMs);
}
public long remainingBackoffMs(Node node, long timeMs) {
ConnectionState state = connections.get(node.idString());
if (state == null) {
return 0;
}
return state.remainingBackoffMs(timeMs);
}
public boolean isResponseExpected(Node node, long correlationId) {
ConnectionState state = connections.get(node.idString());
if (state == null) {
return false;
}
return state.isResponseExpected(correlationId);
}
/**
* Updates the manager when a response is received.
*
* @param node the source of the response
* @param correlationId the correlation id of the response
* @param success true if the request was successful, false otherwise
* @param timeMs the current time
*/
public void onResponseResult(Node node, long correlationId, boolean success, long timeMs) {
if (isResponseExpected(node, correlationId)) {
if (success) {
// Mark the connection as ready by reseting it
reset(node);
} else {
// Backoff the connection
connections.get(node.idString()).onResponseError(correlationId, timeMs);
}
}
}
/**
* Updates the manager when a request is sent.
*
* @param node the destination of the request
* @param correlationId the correlation id of the request
* @param timeMs the current time
*/
public void onRequestSent(Node node, long correlationId, long timeMs) {
ConnectionState state = connections.computeIfAbsent(
node.idString(),
key -> new ConnectionState(node, retryBackoffMs, requestTimeoutMs)
);
state.onRequestSent(correlationId, timeMs);
}
public void reset(Node node) {
connections.remove(node.idString());
}
public void resetAll() { public void resetAll() {
for (ConnectionState connectionState : connections.values()) connections.clear();
connectionState.reset();
} }
private enum State { private enum State {
AWAITING_REQUEST, AWAITING_RESPONSE,
BACKING_OFF, BACKING_OFF,
READY READY
} }
public class ConnectionState { private final static class ConnectionState {
private final long id; private final Node node;
private final int retryBackoffMs;
private final int requestTimeoutMs;
private State state = State.READY; private State state = State.READY;
private long lastSendTimeMs = 0L; private long lastSendTimeMs = 0L;
private long lastFailTimeMs = 0L; private long lastFailTimeMs = 0L;
private OptionalLong inFlightCorrelationId = OptionalLong.empty(); private OptionalLong inFlightCorrelationId = OptionalLong.empty();
public ConnectionState(long id) { private ConnectionState(
this.id = id; Node node,
int retryBackoffMs,
int requestTimeoutMs
) {
this.node = node;
this.retryBackoffMs = retryBackoffMs;
this.requestTimeoutMs = requestTimeoutMs;
} }
private boolean isBackoffComplete(long timeMs) { private boolean isBackoffComplete(long timeMs) {
@ -114,11 +306,7 @@ public class RequestManager {
} }
boolean hasRequestTimedOut(long timeMs) { boolean hasRequestTimedOut(long timeMs) {
return state == State.AWAITING_REQUEST && timeMs >= lastSendTimeMs + requestTimeoutMs; return state == State.AWAITING_RESPONSE && timeMs >= lastSendTimeMs + requestTimeoutMs;
}
public long id() {
return id;
} }
boolean isReady(long timeMs) { boolean isReady(long timeMs) {
@ -136,8 +324,8 @@ public class RequestManager {
} }
} }
boolean hasInflightRequest(long timeMs) { private boolean hasInflightRequest(long timeMs) {
if (state != State.AWAITING_REQUEST) { if (state != State.AWAITING_RESPONSE) {
return false; return false;
} else { } else {
return !hasRequestTimedOut(timeMs); return !hasRequestTimedOut(timeMs);
@ -174,41 +362,22 @@ public class RequestManager {
}); });
} }
void onResponseReceived(long correlationId) {
inFlightCorrelationId.ifPresent(inflightRequestId -> {
if (inflightRequestId == correlationId) {
state = State.READY;
inFlightCorrelationId = OptionalLong.empty();
}
});
}
void onRequestSent(long correlationId, long timeMs) { void onRequestSent(long correlationId, long timeMs) {
lastSendTimeMs = timeMs; lastSendTimeMs = timeMs;
inFlightCorrelationId = OptionalLong.of(correlationId); inFlightCorrelationId = OptionalLong.of(correlationId);
state = State.AWAITING_REQUEST; state = State.AWAITING_RESPONSE;
}
/**
* Ignore in-flight requests or backoff and become available immediately. This is used
* when there is a state change which usually means in-flight requests are obsolete
* and we need to send new requests.
*/
void reset() {
state = State.READY;
inFlightCorrelationId = OptionalLong.empty();
} }
@Override @Override
public String toString() { public String toString() {
return "ConnectionState(" + return String.format(
"id=" + id + "ConnectionState(node=%s, state=%s, lastSendTimeMs=%d, lastFailTimeMs=%d, inFlightCorrelationId=%d)",
", state=" + state + node,
", lastSendTimeMs=" + lastSendTimeMs + state,
", lastFailTimeMs=" + lastFailTimeMs + lastSendTimeMs,
", inFlightCorrelationId=" + inFlightCorrelationId + lastFailTimeMs,
')'; inFlightCorrelationId
);
} }
} }
} }

View File

@ -28,11 +28,13 @@ import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.kafka.common.Node; import org.apache.kafka.common.Node;
import org.apache.kafka.common.Uuid; import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.feature.SupportedVersionRange; import org.apache.kafka.common.feature.SupportedVersionRange;
import org.apache.kafka.common.message.VotersRecord; import org.apache.kafka.common.message.VotersRecord;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.utils.Utils;
/** /**
@ -55,15 +57,41 @@ final public class VoterSet {
} }
/** /**
* Returns the socket address for a given voter at a given listener. * Returns the node information for all the given voter ids and listener.
* *
* @param voter the id of the voter * @param voterIds the ids of the voters
* @param listener the name of the listener * @param listenerName the name of the listener
* @return the socket address if it exists, otherwise {@code Optional.empty()} * @return the node information for all of the voter ids
* @throws IllegalArgumentException if there are missing endpoints
*/ */
public Optional<InetSocketAddress> voterAddress(int voter, String listener) { public Set<Node> voterNodes(Stream<Integer> voterIds, ListenerName listenerName) {
return Optional.ofNullable(voters.get(voter)) return voterIds
.flatMap(voterNode -> voterNode.address(listener)); .map(voterId ->
voterNode(voterId, listenerName).orElseThrow(() ->
new IllegalArgumentException(
String.format(
"Unable to find endpoint for voter %d and listener %s in %s",
voterId,
listenerName,
voters
)
)
)
)
.collect(Collectors.toSet());
}
/**
* Returns the node information for a given voter id and listener.
*
* @param voterId the id of the voter
* @param listenerName the name of the listener
* @return the node information if it exists, otherwise {@code Optional.empty()}
*/
public Optional<Node> voterNode(int voterId, ListenerName listenerName) {
return Optional.ofNullable(voters.get(voterId))
.flatMap(voterNode -> voterNode.address(listenerName))
.map(address -> new Node(voterId, address.getHostString(), address.getPort()));
} }
/** /**
@ -166,7 +194,7 @@ final public class VoterSet {
.stream() .stream()
.map(entry -> .map(entry ->
new VotersRecord.Endpoint() new VotersRecord.Endpoint()
.setName(entry.getKey()) .setName(entry.getKey().value())
.setHost(entry.getValue().getHostString()) .setHost(entry.getValue().getHostString())
.setPort(entry.getValue().getPort()) .setPort(entry.getValue().getPort())
) )
@ -247,12 +275,12 @@ final public class VoterSet {
public final static class VoterNode { public final static class VoterNode {
private final ReplicaKey voterKey; private final ReplicaKey voterKey;
private final Map<String, InetSocketAddress> listeners; private final Map<ListenerName, InetSocketAddress> listeners;
private final SupportedVersionRange supportedKRaftVersion; private final SupportedVersionRange supportedKRaftVersion;
VoterNode( VoterNode(
ReplicaKey voterKey, ReplicaKey voterKey,
Map<String, InetSocketAddress> listeners, Map<ListenerName, InetSocketAddress> listeners,
SupportedVersionRange supportedKRaftVersion SupportedVersionRange supportedKRaftVersion
) { ) {
this.voterKey = voterKey; this.voterKey = voterKey;
@ -264,7 +292,7 @@ final public class VoterSet {
return voterKey; return voterKey;
} }
Map<String, InetSocketAddress> listeners() { Map<ListenerName, InetSocketAddress> listeners() {
return listeners; return listeners;
} }
@ -273,7 +301,7 @@ final public class VoterSet {
} }
Optional<InetSocketAddress> address(String listener) { Optional<InetSocketAddress> address(ListenerName listener) {
return Optional.ofNullable(listeners.get(listener)); return Optional.ofNullable(listeners.get(listener));
} }
@ -323,9 +351,12 @@ final public class VoterSet {
directoryId = Optional.empty(); directoryId = Optional.empty();
} }
Map<String, InetSocketAddress> listeners = new HashMap<>(voter.endpoints().size()); Map<ListenerName, InetSocketAddress> listeners = new HashMap<>(voter.endpoints().size());
for (VotersRecord.Endpoint endpoint : voter.endpoints()) { for (VotersRecord.Endpoint endpoint : voter.endpoints()) {
listeners.put(endpoint.name(), InetSocketAddress.createUnresolved(endpoint.host(), endpoint.port())); listeners.put(
ListenerName.normalised(endpoint.name()),
InetSocketAddress.createUnresolved(endpoint.host(), endpoint.port())
);
} }
voterNodes.put( voterNodes.put(
@ -351,7 +382,7 @@ final public class VoterSet {
* @param voters the socket addresses by voter id * @param voters the socket addresses by voter id
* @return the voter set * @return the voter set
*/ */
public static VoterSet fromInetSocketAddresses(String listener, Map<Integer, InetSocketAddress> voters) { public static VoterSet fromInetSocketAddresses(ListenerName listener, Map<Integer, InetSocketAddress> voters) {
Map<Integer, VoterNode> voterNodes = voters Map<Integer, VoterNode> voterNodes = voters
.entrySet() .entrySet()
.stream() .stream()
@ -368,16 +399,4 @@ final public class VoterSet {
return new VoterSet(voterNodes); return new VoterSet(voterNodes);
} }
public Optional<Node> voterNode(int id, String listener) {
VoterNode voterNode = voters.get(id);
if (voterNode == null) {
return Optional.empty();
}
InetSocketAddress address = voterNode.listeners.get(listener);
if (address == null) {
return Optional.empty();
}
return Optional.of(new Node(id, address.getHostString(), address.getPort()));
}
} }

View File

@ -26,11 +26,10 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
@ -60,7 +59,7 @@ public class CandidateStateTest {
@Test @Test
public void testSingleNodeQuorum() { public void testSingleNodeQuorum() {
CandidateState state = newCandidateState(voterSetWithLocal(Collections.emptyList())); CandidateState state = newCandidateState(voterSetWithLocal(IntStream.empty()));
assertTrue(state.isVoteGranted()); assertTrue(state.isVoteGranted());
assertFalse(state.isVoteRejected()); assertFalse(state.isVoteRejected());
assertEquals(Collections.emptySet(), state.unrecordedVoters()); assertEquals(Collections.emptySet(), state.unrecordedVoters());
@ -70,7 +69,7 @@ public class CandidateStateTest {
public void testTwoNodeQuorumVoteRejected() { public void testTwoNodeQuorumVoteRejected() {
int otherNodeId = 1; int otherNodeId = 1;
CandidateState state = newCandidateState( CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId)) voterSetWithLocal(IntStream.of(otherNodeId))
); );
assertFalse(state.isVoteGranted()); assertFalse(state.isVoteGranted());
assertFalse(state.isVoteRejected()); assertFalse(state.isVoteRejected());
@ -84,7 +83,7 @@ public class CandidateStateTest {
public void testTwoNodeQuorumVoteGranted() { public void testTwoNodeQuorumVoteGranted() {
int otherNodeId = 1; int otherNodeId = 1;
CandidateState state = newCandidateState( CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId)) voterSetWithLocal(IntStream.of(otherNodeId))
); );
assertFalse(state.isVoteGranted()); assertFalse(state.isVoteGranted());
assertFalse(state.isVoteRejected()); assertFalse(state.isVoteRejected());
@ -100,7 +99,7 @@ public class CandidateStateTest {
int node1 = 1; int node1 = 1;
int node2 = 2; int node2 = 2;
CandidateState state = newCandidateState( CandidateState state = newCandidateState(
voterSetWithLocal(Arrays.asList(node1, node2)) voterSetWithLocal(IntStream.of(node1, node2))
); );
assertFalse(state.isVoteGranted()); assertFalse(state.isVoteGranted());
assertFalse(state.isVoteRejected()); assertFalse(state.isVoteRejected());
@ -120,7 +119,7 @@ public class CandidateStateTest {
int node1 = 1; int node1 = 1;
int node2 = 2; int node2 = 2;
CandidateState state = newCandidateState( CandidateState state = newCandidateState(
voterSetWithLocal(Arrays.asList(node1, node2)) voterSetWithLocal(IntStream.of(node1, node2))
); );
assertFalse(state.isVoteGranted()); assertFalse(state.isVoteGranted());
assertFalse(state.isVoteRejected()); assertFalse(state.isVoteRejected());
@ -139,7 +138,7 @@ public class CandidateStateTest {
public void testCannotRejectVoteFromLocalId() { public void testCannotRejectVoteFromLocalId() {
int otherNodeId = 1; int otherNodeId = 1;
CandidateState state = newCandidateState( CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId)) voterSetWithLocal(IntStream.of(otherNodeId))
); );
assertThrows( assertThrows(
IllegalArgumentException.class, IllegalArgumentException.class,
@ -151,7 +150,7 @@ public class CandidateStateTest {
public void testCannotChangeVoteGrantedToRejected() { public void testCannotChangeVoteGrantedToRejected() {
int otherNodeId = 1; int otherNodeId = 1;
CandidateState state = newCandidateState( CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId)) voterSetWithLocal(IntStream.of(otherNodeId))
); );
assertTrue(state.recordGrantedVote(otherNodeId)); assertTrue(state.recordGrantedVote(otherNodeId));
assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(otherNodeId)); assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(otherNodeId));
@ -162,7 +161,7 @@ public class CandidateStateTest {
public void testCannotChangeVoteRejectedToGranted() { public void testCannotChangeVoteRejectedToGranted() {
int otherNodeId = 1; int otherNodeId = 1;
CandidateState state = newCandidateState( CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId)) voterSetWithLocal(IntStream.of(otherNodeId))
); );
assertTrue(state.recordRejectedVote(otherNodeId)); assertTrue(state.recordRejectedVote(otherNodeId));
assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(otherNodeId)); assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(otherNodeId));
@ -172,7 +171,7 @@ public class CandidateStateTest {
@Test @Test
public void testCannotGrantOrRejectNonVoters() { public void testCannotGrantOrRejectNonVoters() {
int nonVoterId = 1; int nonVoterId = 1;
CandidateState state = newCandidateState(voterSetWithLocal(Collections.emptyList())); CandidateState state = newCandidateState(voterSetWithLocal(IntStream.empty()));
assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(nonVoterId)); assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(nonVoterId));
assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(nonVoterId)); assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(nonVoterId));
} }
@ -181,7 +180,7 @@ public class CandidateStateTest {
public void testIdempotentGrant() { public void testIdempotentGrant() {
int otherNodeId = 1; int otherNodeId = 1;
CandidateState state = newCandidateState( CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId)) voterSetWithLocal(IntStream.of(otherNodeId))
); );
assertTrue(state.recordGrantedVote(otherNodeId)); assertTrue(state.recordGrantedVote(otherNodeId));
assertFalse(state.recordGrantedVote(otherNodeId)); assertFalse(state.recordGrantedVote(otherNodeId));
@ -191,7 +190,7 @@ public class CandidateStateTest {
public void testIdempotentReject() { public void testIdempotentReject() {
int otherNodeId = 1; int otherNodeId = 1;
CandidateState state = newCandidateState( CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId)) voterSetWithLocal(IntStream.of(otherNodeId))
); );
assertTrue(state.recordRejectedVote(otherNodeId)); assertTrue(state.recordRejectedVote(otherNodeId));
assertFalse(state.recordRejectedVote(otherNodeId)); assertFalse(state.recordRejectedVote(otherNodeId));
@ -201,7 +200,7 @@ public class CandidateStateTest {
@ValueSource(booleans = {true, false}) @ValueSource(booleans = {true, false})
public void testGrantVote(boolean isLogUpToDate) { public void testGrantVote(boolean isLogUpToDate) {
CandidateState state = newCandidateState( CandidateState state = newCandidateState(
voterSetWithLocal(Arrays.asList(1, 2, 3)) voterSetWithLocal(IntStream.of(1, 2, 3))
); );
assertFalse(state.canGrantVote(ReplicaKey.of(0, Optional.empty()), isLogUpToDate)); assertFalse(state.canGrantVote(ReplicaKey.of(0, Optional.empty()), isLogUpToDate));
@ -212,7 +211,7 @@ public class CandidateStateTest {
@Test @Test
public void testElectionState() { public void testElectionState() {
VoterSet voters = voterSetWithLocal(Arrays.asList(1, 2, 3)); VoterSet voters = voterSetWithLocal(IntStream.of(1, 2, 3));
CandidateState state = newCandidateState(voters); CandidateState state = newCandidateState(voters);
assertEquals( assertEquals(
ElectionState.withVotedCandidate( ElectionState.withVotedCandidate(
@ -228,11 +227,11 @@ public class CandidateStateTest {
public void testInvalidVoterSet() { public void testInvalidVoterSet() {
assertThrows( assertThrows(
IllegalArgumentException.class, IllegalArgumentException.class,
() -> newCandidateState(VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))) () -> newCandidateState(VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)))
); );
} }
private VoterSet voterSetWithLocal(Collection<Integer> remoteVoters) { private VoterSet voterSetWithLocal(IntStream remoteVoters) {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(remoteVoters, true); Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(remoteVoters, true);
voterMap.put(localNode.voterKey().id(), localNode); voterMap.put(localNode.voterKey().id(), localNode);

View File

@ -16,6 +16,7 @@
*/ */
package org.apache.kafka.raft; package org.apache.kafka.raft;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.utils.Utils;
@ -38,7 +39,7 @@ public class FollowerStateTest {
private final LogContext logContext = new LogContext(); private final LogContext logContext = new LogContext();
private final int epoch = 5; private final int epoch = 5;
private final int fetchTimeoutMs = 15000; private final int fetchTimeoutMs = 15000;
int leaderId = 3; private final Node leader = new Node(3, "mock-host-3", 1234);
private FollowerState newFollowerState( private FollowerState newFollowerState(
Set<Integer> voters, Set<Integer> voters,
@ -47,7 +48,7 @@ public class FollowerStateTest {
return new FollowerState( return new FollowerState(
time, time,
epoch, epoch,
leaderId, leader,
voters, voters,
highWatermark, highWatermark,
fetchTimeoutMs, fetchTimeoutMs,
@ -96,4 +97,10 @@ public class FollowerStateTest {
assertFalse(state.canGrantVote(ReplicaKey.of(3, Optional.empty()), isLogUpToDate)); assertFalse(state.canGrantVote(ReplicaKey.of(3, Optional.empty()), isLogUpToDate));
} }
@Test
public void testLeaderNode() {
FollowerState state = newFollowerState(Utils.mkSet(0, 1, 2), Optional.empty());
assertEquals(leader, state.leader());
}
} }

View File

@ -26,7 +26,10 @@ import org.apache.kafka.common.message.BeginQuorumEpochResponseData;
import org.apache.kafka.common.message.EndQuorumEpochResponseData; import org.apache.kafka.common.message.EndQuorumEpochResponseData;
import org.apache.kafka.common.message.FetchRequestData; import org.apache.kafka.common.message.FetchRequestData;
import org.apache.kafka.common.message.FetchResponseData; import org.apache.kafka.common.message.FetchResponseData;
import org.apache.kafka.common.message.FetchSnapshotRequestData;
import org.apache.kafka.common.message.FetchSnapshotResponseData;
import org.apache.kafka.common.message.VoteResponseData; import org.apache.kafka.common.message.VoteResponseData;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.Errors;
@ -39,6 +42,8 @@ import org.apache.kafka.common.requests.EndQuorumEpochRequest;
import org.apache.kafka.common.requests.EndQuorumEpochResponse; import org.apache.kafka.common.requests.EndQuorumEpochResponse;
import org.apache.kafka.common.requests.FetchRequest; import org.apache.kafka.common.requests.FetchRequest;
import org.apache.kafka.common.requests.FetchResponse; import org.apache.kafka.common.requests.FetchResponse;
import org.apache.kafka.common.requests.FetchSnapshotRequest;
import org.apache.kafka.common.requests.FetchSnapshotResponse;
import org.apache.kafka.common.requests.VoteRequest; import org.apache.kafka.common.requests.VoteRequest;
import org.apache.kafka.common.requests.VoteResponse; import org.apache.kafka.common.requests.VoteResponse;
import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.MockTime;
@ -47,8 +52,8 @@ import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import java.net.InetSocketAddress;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -80,7 +85,8 @@ public class KafkaNetworkChannelTest {
ApiKeys.VOTE, ApiKeys.VOTE,
ApiKeys.BEGIN_QUORUM_EPOCH, ApiKeys.BEGIN_QUORUM_EPOCH,
ApiKeys.END_QUORUM_EPOCH, ApiKeys.END_QUORUM_EPOCH,
ApiKeys.FETCH ApiKeys.FETCH,
ApiKeys.FETCH_SNAPSHOT
); );
private final int requestTimeoutMs = 30000; private final int requestTimeoutMs = 30000;
@ -88,35 +94,40 @@ public class KafkaNetworkChannelTest {
private final MockClient client = new MockClient(time, new StubMetadataUpdater()); private final MockClient client = new MockClient(time, new StubMetadataUpdater());
private final TopicPartition topicPartition = new TopicPartition("topic", 0); private final TopicPartition topicPartition = new TopicPartition("topic", 0);
private final Uuid topicId = Uuid.randomUuid(); private final Uuid topicId = Uuid.randomUuid();
private final KafkaNetworkChannel channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, "test-raft"); private final KafkaNetworkChannel channel = new KafkaNetworkChannel(
time,
ListenerName.normalised("NAME"),
client,
requestTimeoutMs,
"test-raft"
);
private Node nodeWithId(boolean withId) {
int id = withId ? 2 : -2;
return new Node(id, "127.0.0.1", 9092);
}
@BeforeEach @BeforeEach
public void setupSupportedApis() { public void setupSupportedApis() {
List<ApiVersionsResponseData.ApiVersion> supportedApis = RAFT_APIS.stream().map( List<ApiVersionsResponseData.ApiVersion> supportedApis = RAFT_APIS
ApiVersionsResponse::toApiVersion).collect(Collectors.toList()); .stream()
.map(ApiVersionsResponse::toApiVersion)
.collect(Collectors.toList());
client.setNodeApiVersions(NodeApiVersions.create(supportedApis)); client.setNodeApiVersions(NodeApiVersions.create(supportedApis));
} }
@Test @ParameterizedTest
public void testSendToUnknownDestination() throws ExecutionException, InterruptedException { @ValueSource(booleans = {true, false})
int destinationId = 2; public void testSendToBlackedOutDestination(boolean withDestinationId) throws ExecutionException, InterruptedException {
assertBrokerNotAvailable(destinationId); Node destination = nodeWithId(withDestinationId);
} client.backoff(destination, 500);
assertBrokerNotAvailable(destination);
@Test
public void testSendToBlackedOutDestination() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
client.backoff(destinationNode, 500);
assertBrokerNotAvailable(destinationId);
} }
@Test @Test
public void testWakeupClientOnSend() throws InterruptedException, ExecutionException { public void testWakeupClientOnSend() throws InterruptedException, ExecutionException {
int destinationId = 2; int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
client.enableBlockingUntilWakeup(1); client.enableBlockingUntilWakeup(1);
@ -132,7 +143,7 @@ public class KafkaNetworkChannelTest {
client.prepareResponseFrom(response, destinationNode, false); client.prepareResponseFrom(response, destinationNode, false);
ioThread.start(); ioThread.start();
RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationId); RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationNode);
ioThread.join(); ioThread.join();
assertResponseCompleted(request, Errors.INVALID_REQUEST); assertResponseCompleted(request, Errors.INVALID_REQUEST);
@ -142,12 +153,11 @@ public class KafkaNetworkChannelTest {
public void testSendAndDisconnect() throws ExecutionException, InterruptedException { public void testSendAndDisconnect() throws ExecutionException, InterruptedException {
int destinationId = 2; int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
for (ApiKeys apiKey : RAFT_APIS) { for (ApiKeys apiKey : RAFT_APIS) {
AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)); AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST));
client.prepareResponseFrom(response, destinationNode, true); client.prepareResponseFrom(response, destinationNode, true);
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE); sendAndAssertErrorResponse(apiKey, destinationNode, Errors.BROKER_NOT_AVAILABLE);
} }
} }
@ -155,35 +165,33 @@ public class KafkaNetworkChannelTest {
public void testSendAndFailAuthentication() throws ExecutionException, InterruptedException { public void testSendAndFailAuthentication() throws ExecutionException, InterruptedException {
int destinationId = 2; int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
for (ApiKeys apiKey : RAFT_APIS) { for (ApiKeys apiKey : RAFT_APIS) {
client.createPendingAuthenticationError(destinationNode, 100); client.createPendingAuthenticationError(destinationNode, 100);
sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION); sendAndAssertErrorResponse(apiKey, destinationNode, Errors.NETWORK_EXCEPTION);
// reset to clear backoff time // reset to clear backoff time
client.reset(); client.reset();
} }
} }
private void assertBrokerNotAvailable(int destinationId) throws ExecutionException, InterruptedException { private void assertBrokerNotAvailable(Node destination) throws ExecutionException, InterruptedException {
for (ApiKeys apiKey : RAFT_APIS) { for (ApiKeys apiKey : RAFT_APIS) {
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE); sendAndAssertErrorResponse(apiKey, destination, Errors.BROKER_NOT_AVAILABLE);
} }
} }
@Test @ParameterizedTest
public void testSendAndReceiveOutboundRequest() throws ExecutionException, InterruptedException { @ValueSource(booleans = {true, false})
int destinationId = 2; public void testSendAndReceiveOutboundRequest(boolean withDestinationId) throws ExecutionException, InterruptedException {
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); Node destination = nodeWithId(withDestinationId);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
for (ApiKeys apiKey : RAFT_APIS) { for (ApiKeys apiKey : RAFT_APIS) {
Errors expectedError = Errors.INVALID_REQUEST; Errors expectedError = Errors.INVALID_REQUEST;
AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, expectedError)); AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, expectedError));
client.prepareResponseFrom(response, destinationNode); client.prepareResponseFrom(response, destination);
System.out.println("api key " + apiKey + ", response " + response); System.out.println("api key " + apiKey + ", response " + response);
sendAndAssertErrorResponse(apiKey, destinationId, expectedError); sendAndAssertErrorResponse(apiKey, destination, expectedError);
} }
} }
@ -191,11 +199,10 @@ public class KafkaNetworkChannelTest {
public void testUnsupportedVersionError() throws ExecutionException, InterruptedException { public void testUnsupportedVersionError() throws ExecutionException, InterruptedException {
int destinationId = 2; int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
for (ApiKeys apiKey : RAFT_APIS) { for (ApiKeys apiKey : RAFT_APIS) {
client.prepareUnsupportedVersionResponse(request -> request.apiKey() == apiKey); client.prepareUnsupportedVersionResponse(request -> request.apiKey() == apiKey);
sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION); sendAndAssertErrorResponse(apiKey, destinationNode, Errors.UNSUPPORTED_VERSION);
} }
} }
@ -204,8 +211,7 @@ public class KafkaNetworkChannelTest {
public void testFetchRequestDowngrade(short version) { public void testFetchRequestDowngrade(short version) {
int destinationId = 2; int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port())); sendTestRequest(ApiKeys.FETCH, destinationNode);
sendTestRequest(ApiKeys.FETCH, destinationId);
channel.pollOnce(); channel.pollOnce();
assertEquals(1, client.requests().size()); assertEquals(1, client.requests().size());
@ -220,27 +226,39 @@ public class KafkaNetworkChannelTest {
} }
} }
private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, int destinationId) { private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, Node destination) {
int correlationId = channel.newCorrelationId(); int correlationId = channel.newCorrelationId();
long createdTimeMs = time.milliseconds(); long createdTimeMs = time.milliseconds();
ApiMessage apiRequest = buildTestRequest(apiKey); ApiMessage apiRequest = buildTestRequest(apiKey);
RaftRequest.Outbound request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs); RaftRequest.Outbound request = new RaftRequest.Outbound(
correlationId,
apiRequest,
destination,
createdTimeMs
);
channel.send(request); channel.send(request);
return request; return request;
} }
private void assertResponseCompleted(RaftRequest.Outbound request, Errors expectedError) throws ExecutionException, InterruptedException { private void assertResponseCompleted(
RaftRequest.Outbound request,
Errors expectedError
) throws ExecutionException, InterruptedException {
assertTrue(request.completion.isDone()); assertTrue(request.completion.isDone());
RaftResponse.Inbound response = request.completion.get(); RaftResponse.Inbound response = request.completion.get();
assertEquals(request.destinationId(), response.sourceId()); assertEquals(request.destination(), response.source());
assertEquals(request.correlationId, response.correlationId); assertEquals(request.correlationId(), response.correlationId());
assertEquals(request.data.apiKey(), response.data.apiKey()); assertEquals(request.data().apiKey(), response.data().apiKey());
assertEquals(expectedError, extractError(response.data)); assertEquals(expectedError, extractError(response.data()));
} }
private void sendAndAssertErrorResponse(ApiKeys apiKey, int destinationId, Errors error) throws ExecutionException, InterruptedException { private void sendAndAssertErrorResponse(
RaftRequest.Outbound request = sendTestRequest(apiKey, destinationId); ApiKeys apiKey,
Node destination,
Errors error
) throws ExecutionException, InterruptedException {
RaftRequest.Outbound request = sendTestRequest(apiKey, destination);
channel.pollOnce(); channel.pollOnce();
assertResponseCompleted(request, error); assertResponseCompleted(request, error);
} }
@ -252,12 +270,20 @@ public class KafkaNetworkChannelTest {
switch (key) { switch (key) {
case BEGIN_QUORUM_EPOCH: case BEGIN_QUORUM_EPOCH:
return BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId); return BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId);
case END_QUORUM_EPOCH: case END_QUORUM_EPOCH:
return EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, leaderEpoch, return EndQuorumEpochRequest.singletonRequest(
Collections.singletonList(2)); topicPartition,
clusterId,
leaderId,
leaderEpoch,
Collections.singletonList(2)
);
case VOTE: case VOTE:
int lastEpoch = 4; int lastEpoch = 4;
return VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329); return VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329);
case FETCH: case FETCH:
FetchRequestData request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition -> { FetchRequestData request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition -> {
fetchPartition fetchPartition
@ -267,6 +293,21 @@ public class KafkaNetworkChannelTest {
}); });
request.setReplicaState(new FetchRequestData.ReplicaState().setReplicaId(1)); request.setReplicaState(new FetchRequestData.ReplicaState().setReplicaId(1));
return request; return request;
case FETCH_SNAPSHOT:
return FetchSnapshotRequest.singleton(
clusterId,
1,
topicPartition,
snapshotPartition -> snapshotPartition
.setCurrentLeaderEpoch(5)
.setSnapshotId(new FetchSnapshotRequestData.SnapshotId()
.setEpoch(4)
.setEndOffset(323)
)
.setPosition(10)
);
default: default:
throw new AssertionError("Unexpected api " + key); throw new AssertionError("Unexpected api " + key);
} }
@ -282,6 +323,8 @@ public class KafkaNetworkChannelTest {
return VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false); return VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false);
case FETCH: case FETCH:
return new FetchResponseData().setErrorCode(error.code()); return new FetchResponseData().setErrorCode(error.code());
case FETCH_SNAPSHOT:
return new FetchSnapshotResponseData().setErrorCode(error.code());
default: default:
throw new AssertionError("Unexpected api " + key); throw new AssertionError("Unexpected api " + key);
} }
@ -289,28 +332,36 @@ public class KafkaNetworkChannelTest {
private Errors extractError(ApiMessage response) { private Errors extractError(ApiMessage response) {
short code; short code;
if (response instanceof BeginQuorumEpochResponseData) if (response instanceof BeginQuorumEpochResponseData) {
code = ((BeginQuorumEpochResponseData) response).errorCode(); code = ((BeginQuorumEpochResponseData) response).errorCode();
else if (response instanceof EndQuorumEpochResponseData) } else if (response instanceof EndQuorumEpochResponseData) {
code = ((EndQuorumEpochResponseData) response).errorCode(); code = ((EndQuorumEpochResponseData) response).errorCode();
else if (response instanceof FetchResponseData) } else if (response instanceof FetchResponseData) {
code = ((FetchResponseData) response).errorCode(); code = ((FetchResponseData) response).errorCode();
else if (response instanceof VoteResponseData) } else if (response instanceof VoteResponseData) {
code = ((VoteResponseData) response).errorCode(); code = ((VoteResponseData) response).errorCode();
else } else if (response instanceof FetchSnapshotResponseData) {
code = ((FetchSnapshotResponseData) response).errorCode();
} else {
throw new IllegalArgumentException("Unexpected type for responseData: " + response); throw new IllegalArgumentException("Unexpected type for responseData: " + response);
}
return Errors.forCode(code); return Errors.forCode(code);
} }
private AbstractResponse buildResponse(ApiMessage responseData) { private AbstractResponse buildResponse(ApiMessage responseData) {
if (responseData instanceof VoteResponseData) if (responseData instanceof VoteResponseData) {
return new VoteResponse((VoteResponseData) responseData); return new VoteResponse((VoteResponseData) responseData);
if (responseData instanceof BeginQuorumEpochResponseData) } else if (responseData instanceof BeginQuorumEpochResponseData) {
return new BeginQuorumEpochResponse((BeginQuorumEpochResponseData) responseData); return new BeginQuorumEpochResponse((BeginQuorumEpochResponseData) responseData);
if (responseData instanceof EndQuorumEpochResponseData) } else if (responseData instanceof EndQuorumEpochResponseData) {
return new EndQuorumEpochResponse((EndQuorumEpochResponseData) responseData); return new EndQuorumEpochResponse((EndQuorumEpochResponseData) responseData);
if (responseData instanceof FetchResponseData) } else if (responseData instanceof FetchResponseData) {
return new FetchResponse((FetchResponseData) responseData); return new FetchResponse((FetchResponseData) responseData);
throw new IllegalArgumentException("Unexpected type for responseData: " + responseData); } else if (responseData instanceof FetchSnapshotResponseData) {
return new FetchSnapshotResponse((FetchSnapshotResponseData) responseData);
} else {
throw new IllegalArgumentException("Unexpected type for responseData: " + responseData);
}
} }
} }

View File

@ -153,8 +153,8 @@ final public class KafkaRaftClientSnapshotTest {
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
context.assertFetchRequestData(fetchRequest, epoch, localLogEndOffset, snapshotId.epoch()); context.assertFetchRequestData(fetchRequest, epoch, localLogEndOffset, snapshotId.epoch());
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, localLogEndOffset, Errors.NONE) context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, localLogEndOffset, Errors.NONE)
); );
@ -195,8 +195,8 @@ final public class KafkaRaftClientSnapshotTest {
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
context.assertFetchRequestData(fetchRequest, epoch, localLogEndOffset, snapshotId.epoch()); context.assertFetchRequestData(fetchRequest, epoch, localLogEndOffset, snapshotId.epoch());
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, localLogEndOffset, Errors.NONE) context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, localLogEndOffset, Errors.NONE)
); );
@ -1032,8 +1032,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, invalidEpoch, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, invalidEpoch, 200L)
); );
@ -1049,8 +1049,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, invalidEndOffset, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, invalidEndOffset, 200L)
); );
@ -1091,8 +1091,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
); );
@ -1116,8 +1116,8 @@ final public class KafkaRaftClientSnapshotTest {
} }
context.deliverResponse( context.deliverResponse(
snapshotRequest.correlationId, snapshotRequest.correlationId(),
snapshotRequest.destinationId(), snapshotRequest.destination(),
fetchSnapshotResponse( fetchSnapshotResponse(
context.metadataPartition, context.metadataPartition,
epoch, epoch,
@ -1162,8 +1162,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
); );
@ -1190,8 +1190,8 @@ final public class KafkaRaftClientSnapshotTest {
sendingBuffer.limit(sendingBuffer.limit() / 2); sendingBuffer.limit(sendingBuffer.limit() / 2);
context.deliverResponse( context.deliverResponse(
snapshotRequest.correlationId, snapshotRequest.correlationId(),
snapshotRequest.destinationId(), snapshotRequest.destination(),
fetchSnapshotResponse( fetchSnapshotResponse(
context.metadataPartition, context.metadataPartition,
epoch, epoch,
@ -1219,8 +1219,8 @@ final public class KafkaRaftClientSnapshotTest {
sendingBuffer.position(Math.toIntExact(request.position())); sendingBuffer.position(Math.toIntExact(request.position()));
context.deliverResponse( context.deliverResponse(
snapshotRequest.correlationId, snapshotRequest.correlationId(),
snapshotRequest.destinationId(), snapshotRequest.destination(),
fetchSnapshotResponse( fetchSnapshotResponse(
context.metadataPartition, context.metadataPartition,
epoch, epoch,
@ -1265,8 +1265,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
); );
@ -1284,8 +1284,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with a snapshot not found error // Reply with a snapshot not found error
context.deliverResponse( context.deliverResponse(
snapshotRequest.correlationId, snapshotRequest.correlationId(),
snapshotRequest.destinationId(), snapshotRequest.destination(),
FetchSnapshotResponse.singleton( FetchSnapshotResponse.singleton(
context.metadataPartition, context.metadataPartition,
responsePartitionSnapshot -> { responsePartitionSnapshot -> {
@ -1323,8 +1323,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, firstLeaderId, snapshotId, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, firstLeaderId, snapshotId, 200L)
); );
@ -1342,8 +1342,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with new leader response // Reply with new leader response
context.deliverResponse( context.deliverResponse(
snapshotRequest.correlationId, snapshotRequest.correlationId(),
snapshotRequest.destinationId(), snapshotRequest.destination(),
FetchSnapshotResponse.singleton( FetchSnapshotResponse.singleton(
context.metadataPartition, context.metadataPartition,
responsePartitionSnapshot -> { responsePartitionSnapshot -> {
@ -1380,8 +1380,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
); );
@ -1399,8 +1399,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with new leader epoch // Reply with new leader epoch
context.deliverResponse( context.deliverResponse(
snapshotRequest.correlationId, snapshotRequest.correlationId(),
snapshotRequest.destinationId(), snapshotRequest.destination(),
FetchSnapshotResponse.singleton( FetchSnapshotResponse.singleton(
context.metadataPartition, context.metadataPartition,
responsePartitionSnapshot -> { responsePartitionSnapshot -> {
@ -1437,8 +1437,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
); );
@ -1456,8 +1456,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with unknown leader epoch // Reply with unknown leader epoch
context.deliverResponse( context.deliverResponse(
snapshotRequest.correlationId, snapshotRequest.correlationId(),
snapshotRequest.destinationId(), snapshotRequest.destination(),
FetchSnapshotResponse.singleton( FetchSnapshotResponse.singleton(
context.metadataPartition, context.metadataPartition,
responsePartitionSnapshot -> { responsePartitionSnapshot -> {
@ -1504,8 +1504,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
); );
@ -1523,8 +1523,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with an invalid snapshot id endOffset // Reply with an invalid snapshot id endOffset
context.deliverResponse( context.deliverResponse(
snapshotRequest.correlationId, snapshotRequest.correlationId(),
snapshotRequest.destinationId(), snapshotRequest.destination(),
FetchSnapshotResponse.singleton( FetchSnapshotResponse.singleton(
context.metadataPartition, context.metadataPartition,
responsePartitionSnapshot -> { responsePartitionSnapshot -> {
@ -1550,8 +1550,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
); );
@ -1570,8 +1570,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with an invalid snapshot id epoch // Reply with an invalid snapshot id epoch
context.deliverResponse( context.deliverResponse(
snapshotRequest.correlationId, snapshotRequest.correlationId(),
snapshotRequest.destinationId(), snapshotRequest.destination(),
FetchSnapshotResponse.singleton( FetchSnapshotResponse.singleton(
context.metadataPartition, context.metadataPartition,
responsePartitionSnapshot -> { responsePartitionSnapshot -> {
@ -1614,8 +1614,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse( context.deliverResponse(
fetchRequest.correlationId, fetchRequest.correlationId(),
fetchRequest.destinationId(), fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
); );
@ -1642,8 +1642,8 @@ final public class KafkaRaftClientSnapshotTest {
// Send the response late // Send the response late
context.deliverResponse( context.deliverResponse(
snapshotRequest.correlationId, snapshotRequest.correlationId(),
snapshotRequest.destinationId(), snapshotRequest.destination(),
FetchSnapshotResponse.singleton( FetchSnapshotResponse.singleton(
context.metadataPartition, context.metadataPartition,
responsePartitionSnapshot -> { responsePartitionSnapshot -> {
@ -1805,14 +1805,17 @@ final public class KafkaRaftClientSnapshotTest {
// Poll for our first fetch request // Poll for our first fetch request
context.pollUntilRequest(); context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
// The response does not advance the high watermark // The response does not advance the high watermark
List<String> records1 = Arrays.asList("a", "b", "c"); List<String> records1 = Arrays.asList("a", "b", "c");
MemoryRecords batch1 = context.buildBatch(0L, 3, records1); MemoryRecords batch1 = context.buildBatch(0L, 3, records1);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, leaderId, batch1, 0L, Errors.NONE)); fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, batch1, 0L, Errors.NONE)
);
context.client.poll(); context.client.poll();
// 2) The high watermark must be larger than or equal to the snapshotId's endOffset // 2) The high watermark must be larger than or equal to the snapshotId's endOffset
@ -1827,13 +1830,16 @@ final public class KafkaRaftClientSnapshotTest {
// The high watermark advances to be larger than log.endOffsetForEpoch(3), to test the case 3 // The high watermark advances to be larger than log.endOffsetForEpoch(3), to test the case 3
context.pollUntilRequest(); context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest(); fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, epoch, 3L, 3); context.assertFetchRequestData(fetchRequest, epoch, 3L, 3);
List<String> records2 = Arrays.asList("d", "e", "f"); List<String> records2 = Arrays.asList("d", "e", "f");
MemoryRecords batch2 = context.buildBatch(3L, 4, records2); MemoryRecords batch2 = context.buildBatch(3L, 4, records2);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, leaderId, batch2, 6L, Errors.NONE)); fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, batch2, 6L, Errors.NONE)
);
context.client.poll(); context.client.poll();
assertEquals(6L, context.client.highWatermark().getAsLong()); assertEquals(6L, context.client.highWatermark().getAsLong());

View File

@ -51,6 +51,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mockito; import org.mockito.Mockito;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -62,6 +63,7 @@ import java.util.OptionalLong;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static org.apache.kafka.raft.RaftClientTestContext.Builder.DEFAULT_ELECTION_TIMEOUT_MS; import static org.apache.kafka.raft.RaftClientTestContext.Builder.DEFAULT_ELECTION_TIMEOUT_MS;
@ -274,8 +276,12 @@ public class KafkaRaftClientTest {
assertThrows(NotLeaderException.class, () -> context.client.scheduleAppend(epoch, Arrays.asList("a", "b"))); assertThrows(NotLeaderException.class, () -> context.client.scheduleAppend(epoch, Arrays.asList("a", "b")));
context.pollUntilRequest(); context.pollUntilRequest();
int correlationId = context.assertSentEndQuorumEpochRequest(epoch, 1); RaftRequest.Outbound request = context.assertSentEndQuorumEpochRequest(epoch, 1);
context.deliverResponse(correlationId, 1, context.endEpochResponse(epoch, OptionalInt.of(localId))); context.deliverResponse(
request.correlationId(),
request.destination(),
context.endEpochResponse(epoch, OptionalInt.of(localId))
);
context.client.poll(); context.client.poll();
context.time.sleep(context.electionTimeoutMs()); context.time.sleep(context.electionTimeoutMs());
@ -389,14 +395,17 @@ public class KafkaRaftClientTest {
// Respond to one of the requests so that we can verify that no additional // Respond to one of the requests so that we can verify that no additional
// request to this node is sent. // request to this node is sent.
RaftRequest.Outbound endEpochOutbound = requests.get(0); RaftRequest.Outbound endEpochOutbound = requests.get(0);
context.deliverResponse(endEpochOutbound.correlationId, endEpochOutbound.destinationId(), context.deliverResponse(
context.endEpochResponse(epoch, OptionalInt.of(localId))); endEpochOutbound.correlationId(),
endEpochOutbound.destination(),
context.endEpochResponse(epoch, OptionalInt.of(localId))
);
context.client.poll(); context.client.poll();
assertEquals(Collections.emptyList(), context.channel.drainSendQueue()); assertEquals(Collections.emptyList(), context.channel.drainSendQueue());
// Now sleep for the request timeout and verify that we get only one // Now sleep for the request timeout and verify that we get only one
// retried request from the voter that hasn't responded yet. // retried request from the voter that hasn't responded yet.
int nonRespondedId = requests.get(1).destinationId(); int nonRespondedId = requests.get(1).destination().id();
context.time.sleep(6000); context.time.sleep(6000);
context.pollUntilRequest(); context.pollUntilRequest();
List<RaftRequest.Outbound> retries = context.collectEndQuorumRequests( List<RaftRequest.Outbound> retries = context.collectEndQuorumRequests(
@ -573,7 +582,7 @@ public class KafkaRaftClientTest {
context.pollUntil(context.client.quorum()::isResigned); context.pollUntil(context.client.quorum()::isResigned);
context.pollUntilRequest(); context.pollUntilRequest();
int correlationId = context.assertSentEndQuorumEpochRequest(resignedEpoch, otherNodeId); RaftRequest.Outbound request = context.assertSentEndQuorumEpochRequest(resignedEpoch, otherNodeId);
EndQuorumEpochResponseData response = EndQuorumEpochResponse.singletonResponse( EndQuorumEpochResponseData response = EndQuorumEpochResponse.singletonResponse(
Errors.NONE, Errors.NONE,
@ -583,7 +592,7 @@ public class KafkaRaftClientTest {
localId localId
); );
context.deliverResponse(correlationId, otherNodeId, response); context.deliverResponse(request.correlationId(), request.destination(), response);
context.client.poll(); context.client.poll();
// We do not resend `EndQuorumRequest` once the other voter has acknowledged it. // We do not resend `EndQuorumRequest` once the other voter has acknowledged it.
@ -644,11 +653,14 @@ public class KafkaRaftClientTest {
context.pollUntilRequest(); context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0); context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll(); context.client.poll();
context.assertElectedLeader(epoch, leaderId); context.assertElectedLeader(epoch, leaderId);
@ -686,8 +698,12 @@ public class KafkaRaftClientTest {
context.pollUntilRequest(); context.pollUntilRequest();
context.assertVotedCandidate(1, localId); context.assertVotedCandidate(1, localId);
int correlationId = context.assertSentVoteRequest(1, 0, 0L, 1); RaftRequest.Outbound request = context.assertSentVoteRequest(1, 0, 0L, 1);
context.deliverResponse(correlationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); context.deliverResponse(
request.correlationId(),
request.destination(),
context.voteResponse(true, Optional.empty(), 1)
);
// Become leader after receiving the vote // Become leader after receiving the vote
context.pollUntil(() -> context.log.endOffset().offset == 1L); context.pollUntil(() -> context.log.endOffset().offset == 1L);
@ -726,8 +742,12 @@ public class KafkaRaftClientTest {
context.pollUntilRequest(); context.pollUntilRequest();
context.assertVotedCandidate(1, localId); context.assertVotedCandidate(1, localId);
int correlationId = context.assertSentVoteRequest(1, 0, 0L, 2); RaftRequest.Outbound request = context.assertSentVoteRequest(1, 0, 0L, 2);
context.deliverResponse(correlationId, firstNodeId, context.voteResponse(true, Optional.empty(), 1)); context.deliverResponse(
request.correlationId(),
request.destination(),
context.voteResponse(true, Optional.empty(), 1)
);
// Become leader after receiving the vote // Become leader after receiving the vote
context.pollUntil(() -> context.log.endOffset().offset == 1L); context.pollUntil(() -> context.log.endOffset().offset == 1L);
@ -1102,19 +1122,27 @@ public class KafkaRaftClientTest {
context.pollUntilRequest(); context.pollUntilRequest();
context.assertVotedCandidate(epoch, localId); context.assertVotedCandidate(epoch, localId);
int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); RaftRequest.Outbound request = context.assertSentVoteRequest(epoch, 0, 0L, 1);
context.time.sleep(context.requestTimeoutMs()); context.time.sleep(context.requestTimeoutMs());
context.client.poll(); context.client.poll();
int retryCorrelationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); RaftRequest.Outbound retryRequest = context.assertSentVoteRequest(epoch, 0, 0L, 1);
// We will ignore the timed out response if it arrives late // We will ignore the timed out response if it arrives late
context.deliverResponse(correlationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); context.deliverResponse(
request.correlationId(),
request.destination(),
context.voteResponse(true, Optional.empty(), 1)
);
context.client.poll(); context.client.poll();
context.assertVotedCandidate(epoch, localId); context.assertVotedCandidate(epoch, localId);
// Become leader after receiving the retry response // Become leader after receiving the retry response
context.deliverResponse(retryCorrelationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); context.deliverResponse(
retryRequest.correlationId(),
retryRequest.destination(),
context.voteResponse(true, Optional.empty(), 1)
);
context.client.poll(); context.client.poll();
context.assertElectedLeader(epoch, localId); context.assertElectedLeader(epoch, localId);
} }
@ -1338,8 +1366,12 @@ public class KafkaRaftClientTest {
context.assertVotedCandidate(epoch, localId); context.assertVotedCandidate(epoch, localId);
// Quorum size is two. If the other member rejects, then we need to schedule a revote. // Quorum size is two. If the other member rejects, then we need to schedule a revote.
int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); RaftRequest.Outbound request = context.assertSentVoteRequest(epoch, 0, 0L, 1);
context.deliverResponse(correlationId, otherNodeId, context.voteResponse(false, Optional.empty(), 1)); context.deliverResponse(
request.correlationId(),
request.destination(),
context.voteResponse(false, Optional.empty(), 1)
);
context.client.poll(); context.client.poll();
@ -1434,11 +1466,14 @@ public class KafkaRaftClientTest {
context.pollUntilRequest(); context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0); context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll(); context.client.poll();
context.assertElectedLeader(epoch, leaderId); context.assertElectedLeader(epoch, leaderId);
@ -1450,27 +1485,39 @@ public class KafkaRaftClientTest {
int leaderId = 1; int leaderId = 1;
int epoch = 5; int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId); Set<Integer> voters = Utils.mkSet(leaderId);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
context.pollUntilRequest(); context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0); context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), context.deliverResponse(
context.fetchResponse(-1, -1, MemoryRecords.EMPTY, -1, Errors.UNKNOWN_SERVER_ERROR)); fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(-1, -1, MemoryRecords.EMPTY, -1, Errors.UNKNOWN_SERVER_ERROR)
);
context.client.poll(); context.client.poll();
context.time.sleep(context.retryBackoffMs); context.time.sleep(context.retryBackoffMs);
context.pollUntilRequest(); context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest(); fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0); context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll(); context.client.poll();
context.assertElectedLeader(epoch, leaderId); context.assertElectedLeader(epoch, leaderId);
@ -1483,27 +1530,169 @@ public class KafkaRaftClientTest {
int otherNodeId = 2; int otherNodeId = 2;
int epoch = 5; int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId); Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
context.pollUntilRequest(); context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0); context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); fetchRequest.correlationId(),
context.client.poll(); fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId); context.assertElectedLeader(epoch, leaderId);
context.time.sleep(context.fetchTimeoutMs); context.time.sleep(context.fetchTimeoutMs);
context.pollUntilRequest(); context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest(); fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); assertNotEquals(leaderId, fetchRequest.destination().id());
assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
} }
@Test
public void testObserverHandleRetryFetchtToBootstrapServer() throws Exception {
// This test tries to check that KRaft is able to handle a retrying Fetch request to
// a boostrap server after a Fetch request to the leader.
int localId = 0;
int leaderId = 1;
int otherNodeId = 2;
int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
// Expect a fetch request to one of the bootstrap servers
context.pollUntilRequest();
RaftRequest.Outbound discoveryFetchRequest = context.assertSentFetchRequest();
assertFalse(voters.contains(discoveryFetchRequest.destination().id()));
assertTrue(context.bootstrapIds.contains(discoveryFetchRequest.destination().id()));
context.assertFetchRequestData(discoveryFetchRequest, 0, 0L, 0);
// Send a response with the leader and epoch
context.deliverResponse(
discoveryFetchRequest.correlationId(),
discoveryFetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId);
// Expect a fetch request to the leader
context.pollUntilRequest();
RaftRequest.Outbound toLeaderFetchRequest = context.assertSentFetchRequest();
assertEquals(leaderId, toLeaderFetchRequest.destination().id());
context.assertFetchRequestData(toLeaderFetchRequest, epoch, 0L, 0);
context.time.sleep(context.requestTimeoutMs());
// After the fetch timeout expect a request to a bootstrap server
context.pollUntilRequest();
RaftRequest.Outbound retryToBootstrapServerFetchRequest = context.assertSentFetchRequest();
assertFalse(voters.contains(retryToBootstrapServerFetchRequest.destination().id()));
assertTrue(context.bootstrapIds.contains(retryToBootstrapServerFetchRequest.destination().id()));
context.assertFetchRequestData(retryToBootstrapServerFetchRequest, epoch, 0L, 0);
// Deliver the delayed responses from the leader
Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
context.deliverResponse(
toLeaderFetchRequest.correlationId(),
toLeaderFetchRequest.destination(),
context.fetchResponse(epoch, leaderId, records, 0L, Errors.NONE)
);
context.client.poll();
// Deliver the same delayed responses from the bootstrap server and assume that it is the leader
records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
context.deliverResponse(
retryToBootstrapServerFetchRequest.correlationId(),
retryToBootstrapServerFetchRequest.destination(),
context.fetchResponse(epoch, leaderId, records, 0L, Errors.NONE)
);
// This poll should not fail when handling the duplicate response from the bootstrap server
context.client.poll();
}
@Test
public void testObserverHandleRetryFetchToLeader() throws Exception {
// This test tries to check that KRaft is able to handle a retrying Fetch request to
// the leader after a Fetch request to the bootstrap server.
int localId = 0;
int leaderId = 1;
int otherNodeId = 2;
int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
// Expect a fetch request to one of the bootstrap servers
context.pollUntilRequest();
RaftRequest.Outbound discoveryFetchRequest = context.assertSentFetchRequest();
assertFalse(voters.contains(discoveryFetchRequest.destination().id()));
assertTrue(context.bootstrapIds.contains(discoveryFetchRequest.destination().id()));
context.assertFetchRequestData(discoveryFetchRequest, 0, 0L, 0);
// Send a response with the leader and epoch
context.deliverResponse(
discoveryFetchRequest.correlationId(),
discoveryFetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId);
// Expect a fetch request to the leader
context.pollUntilRequest();
RaftRequest.Outbound toLeaderFetchRequest = context.assertSentFetchRequest();
assertEquals(leaderId, toLeaderFetchRequest.destination().id());
context.assertFetchRequestData(toLeaderFetchRequest, epoch, 0L, 0);
context.time.sleep(context.requestTimeoutMs());
// After the fetch timeout expect a request to a bootstrap server
context.pollUntilRequest();
RaftRequest.Outbound retryToBootstrapServerFetchRequest = context.assertSentFetchRequest();
assertFalse(voters.contains(retryToBootstrapServerFetchRequest.destination().id()));
assertTrue(context.bootstrapIds.contains(retryToBootstrapServerFetchRequest.destination().id()));
context.assertFetchRequestData(retryToBootstrapServerFetchRequest, epoch, 0L, 0);
// At this point toLeaderFetchRequest has timed out but retryToBootstrapServerFetchRequest
// is still waiting for a response.
// Confirm that no new fetch request has been sent
context.client.poll();
assertFalse(context.channel.hasSentRequests());
}
@Test @Test
public void testInvalidFetchRequest() throws Exception { public void testInvalidFetchRequest() throws Exception {
int localId = 0; int localId = 0;
@ -1828,7 +2017,7 @@ public class KafkaRaftClientTest {
// Wait until we have a Fetch inflight to the leader // Wait until we have a Fetch inflight to the leader
context.pollUntilRequest(); context.pollUntilRequest();
int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(epoch, 0L, 0);
// Now await the fetch timeout and become a candidate // Now await the fetch timeout and become a candidate
context.time.sleep(context.fetchTimeoutMs); context.time.sleep(context.fetchTimeoutMs);
@ -1837,8 +2026,11 @@ public class KafkaRaftClientTest {
// The fetch response from the old leader returns, but it should be ignored // The fetch response from the old leader returns, but it should be ignored
Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
context.deliverResponse(fetchCorrelationId, otherNodeId, context.deliverResponse(
context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE)); fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE)
);
context.client.poll(); context.client.poll();
assertEquals(0, context.log.endOffset().offset); assertEquals(0, context.log.endOffset().offset);
@ -1862,7 +2054,7 @@ public class KafkaRaftClientTest {
// Wait until we have a Fetch inflight to the leader // Wait until we have a Fetch inflight to the leader
context.pollUntilRequest(); context.pollUntilRequest();
int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(epoch, 0L, 0);
// Now receive a BeginEpoch from `voter3` // Now receive a BeginEpoch from `voter3`
context.deliverRequest(context.beginEpochRequest(epoch + 1, voter3)); context.deliverRequest(context.beginEpochRequest(epoch + 1, voter3));
@ -1872,7 +2064,11 @@ public class KafkaRaftClientTest {
// The fetch response from the old leader returns, but it should be ignored // The fetch response from the old leader returns, but it should be ignored
Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
FetchResponseData response = context.fetchResponse(epoch, voter2, records, 0L, Errors.NONE); FetchResponseData response = context.fetchResponse(epoch, voter2, records, 0L, Errors.NONE);
context.deliverResponse(fetchCorrelationId, voter2, response); context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
response
);
context.client.poll(); context.client.poll();
assertEquals(0, context.log.endOffset().offset); assertEquals(0, context.log.endOffset().offset);
@ -1909,10 +2105,18 @@ public class KafkaRaftClientTest {
// The vote requests now return and should be ignored // The vote requests now return and should be ignored
VoteResponseData voteResponse1 = context.voteResponse(false, Optional.empty(), epoch); VoteResponseData voteResponse1 = context.voteResponse(false, Optional.empty(), epoch);
context.deliverResponse(voteRequests.get(0).correlationId, voter2, voteResponse1); context.deliverResponse(
voteRequests.get(0).correlationId(),
voteRequests.get(0).destination(),
voteResponse1
);
VoteResponseData voteResponse2 = context.voteResponse(false, Optional.of(voter3), epoch); VoteResponseData voteResponse2 = context.voteResponse(false, Optional.of(voter3), epoch);
context.deliverResponse(voteRequests.get(1).correlationId, voter3, voteResponse2); context.deliverResponse(
voteRequests.get(1).correlationId(),
voteRequests.get(1).destination(),
voteResponse2
);
context.client.poll(); context.client.poll();
context.assertElectedLeader(epoch, voter3); context.assertElectedLeader(epoch, voter3);
@ -1925,31 +2129,43 @@ public class KafkaRaftClientTest {
int otherNodeId = 2; int otherNodeId = 2;
int epoch = 5; int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId); Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
context.discoverLeaderAsObserver(leaderId, epoch); context.discoverLeaderAsObserver(leaderId, epoch);
context.pollUntilRequest(); context.pollUntilRequest();
RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest();
assertEquals(leaderId, fetchRequest1.destinationId()); assertEquals(leaderId, fetchRequest1.destination().id());
context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0);
context.deliverResponse(fetchRequest1.correlationId, fetchRequest1.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, -1, MemoryRecords.EMPTY, -1, Errors.BROKER_NOT_AVAILABLE)); fetchRequest1.correlationId(),
fetchRequest1.destination(),
context.fetchResponse(epoch, -1, MemoryRecords.EMPTY, -1, Errors.BROKER_NOT_AVAILABLE)
);
context.pollUntilRequest(); context.pollUntilRequest();
// We should retry the Fetch against the other voter since the original // We should retry the Fetch against the other voter since the original
// voter connection will be backing off. // voter connection will be backing off.
RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest();
assertNotEquals(leaderId, fetchRequest2.destinationId()); assertNotEquals(leaderId, fetchRequest2.destination().id());
assertTrue(voters.contains(fetchRequest2.destinationId())); assertTrue(context.bootstrapIds.contains(fetchRequest2.destination().id()));
context.assertFetchRequestData(fetchRequest2, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest2, epoch, 0L, 0);
Errors error = fetchRequest2.destinationId() == leaderId ? Errors error = fetchRequest2.destination().id() == leaderId ?
Errors.NONE : Errors.NOT_LEADER_OR_FOLLOWER; Errors.NONE : Errors.NOT_LEADER_OR_FOLLOWER;
context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, error)); fetchRequest2.correlationId(),
fetchRequest2.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, error)
);
context.client.poll(); context.client.poll();
context.assertElectedLeader(epoch, leaderId); context.assertElectedLeader(epoch, leaderId);
@ -1962,14 +2178,20 @@ public class KafkaRaftClientTest {
int otherNodeId = 2; int otherNodeId = 2;
int epoch = 5; int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId); Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
context.discoverLeaderAsObserver(leaderId, epoch); context.discoverLeaderAsObserver(leaderId, epoch);
context.pollUntilRequest(); context.pollUntilRequest();
RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest();
assertEquals(leaderId, fetchRequest1.destinationId()); assertEquals(leaderId, fetchRequest1.destination().id());
context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0);
context.time.sleep(context.requestTimeoutMs()); context.time.sleep(context.requestTimeoutMs());
@ -1978,12 +2200,15 @@ public class KafkaRaftClientTest {
// We should retry the Fetch against the other voter since the original // We should retry the Fetch against the other voter since the original
// voter connection will be backing off. // voter connection will be backing off.
RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest();
assertNotEquals(leaderId, fetchRequest2.destinationId()); assertNotEquals(leaderId, fetchRequest2.destination().id());
assertTrue(voters.contains(fetchRequest2.destinationId())); assertTrue(context.bootstrapIds.contains(fetchRequest2.destination().id()));
context.assertFetchRequestData(fetchRequest2, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest2, epoch, 0L, 0);
context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); fetchRequest2.correlationId(),
fetchRequest2.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll(); context.client.poll();
context.assertElectedLeader(epoch, leaderId); context.assertElectedLeader(epoch, leaderId);
@ -2273,10 +2498,14 @@ public class KafkaRaftClientTest {
context.pollUntilRequest(); context.pollUntilRequest();
int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); RaftRequest.Outbound fetchQuorumRequest = context.assertSentFetchRequest(epoch, 0L, 0);
Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
FetchResponseData response = context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE); FetchResponseData response = context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE);
context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, response); context.deliverResponse(
fetchQuorumRequest.correlationId(),
fetchQuorumRequest.destination(),
response
);
context.client.poll(); context.client.poll();
assertEquals(2L, context.log.endOffset().offset); assertEquals(2L, context.log.endOffset().offset);
@ -2297,10 +2526,19 @@ public class KafkaRaftClientTest {
// Receive an empty fetch response // Receive an empty fetch response
context.pollUntilRequest(); context.pollUntilRequest();
int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); RaftRequest.Outbound fetchQuorumRequest = context.assertSentFetchRequest(epoch, 0L, 0);
FetchResponseData fetchResponse = context.fetchResponse(epoch, otherNodeId, FetchResponseData fetchResponse = context.fetchResponse(
MemoryRecords.EMPTY, 0L, Errors.NONE); epoch,
context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse); otherNodeId,
MemoryRecords.EMPTY,
0L,
Errors.NONE
);
context.deliverResponse(
fetchQuorumRequest.correlationId(),
fetchQuorumRequest.destination(),
fetchResponse
);
context.client.poll(); context.client.poll();
assertEquals(0L, context.log.endOffset().offset); assertEquals(0L, context.log.endOffset().offset);
assertEquals(OptionalLong.of(0L), context.client.highWatermark()); assertEquals(OptionalLong.of(0L), context.client.highWatermark());
@ -2308,20 +2546,32 @@ public class KafkaRaftClientTest {
// Receive some records in the next poll, but do not advance high watermark // Receive some records in the next poll, but do not advance high watermark
context.pollUntilRequest(); context.pollUntilRequest();
Records records = context.buildBatch(0L, epoch, Arrays.asList("a", "b")); Records records = context.buildBatch(0L, epoch, Arrays.asList("a", "b"));
fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); fetchQuorumRequest = context.assertSentFetchRequest(epoch, 0L, 0);
fetchResponse = context.fetchResponse(epoch, otherNodeId, fetchResponse = context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE);
records, 0L, Errors.NONE); context.deliverResponse(
context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse); fetchQuorumRequest.correlationId(),
fetchQuorumRequest.destination(),
fetchResponse
);
context.client.poll(); context.client.poll();
assertEquals(2L, context.log.endOffset().offset); assertEquals(2L, context.log.endOffset().offset);
assertEquals(OptionalLong.of(0L), context.client.highWatermark()); assertEquals(OptionalLong.of(0L), context.client.highWatermark());
// The next fetch response is empty, but should still advance the high watermark // The next fetch response is empty, but should still advance the high watermark
context.pollUntilRequest(); context.pollUntilRequest();
fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 2L, epoch); fetchQuorumRequest = context.assertSentFetchRequest(epoch, 2L, epoch);
fetchResponse = context.fetchResponse(epoch, otherNodeId, fetchResponse = context.fetchResponse(
MemoryRecords.EMPTY, 2L, Errors.NONE); epoch,
context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse); otherNodeId,
MemoryRecords.EMPTY,
2L,
Errors.NONE
);
context.deliverResponse(
fetchQuorumRequest.correlationId(),
fetchQuorumRequest.destination(),
fetchResponse
);
context.client.poll(); context.client.poll();
assertEquals(2L, context.log.endOffset().offset); assertEquals(2L, context.log.endOffset().offset);
assertEquals(OptionalLong.of(2L), context.client.highWatermark()); assertEquals(OptionalLong.of(2L), context.client.highWatermark());
@ -2454,11 +2704,11 @@ public class KafkaRaftClientTest {
context.pollUntilRequest(); context.pollUntilRequest();
int correlationId = context.assertSentFetchRequest(epoch, 3L, lastEpoch); RaftRequest.Outbound request = context.assertSentFetchRequest(epoch, 3L, lastEpoch);
FetchResponseData response = context.divergingFetchResponse(epoch, otherNodeId, 2L, FetchResponseData response = context.divergingFetchResponse(epoch, otherNodeId, 2L,
lastEpoch, 1L); lastEpoch, 1L);
context.deliverResponse(correlationId, otherNodeId, response); context.deliverResponse(request.correlationId(), request.destination(), response);
// Poll again to complete truncation // Poll again to complete truncation
context.client.poll(); context.client.poll();
@ -2530,10 +2780,14 @@ public class KafkaRaftClientTest {
context.pollUntilRequest(); context.pollUntilRequest();
int correlationId = context.assertSentFetchRequest(epoch, 0, 0); RaftRequest.Outbound request = context.assertSentFetchRequest(epoch, 0, 0);
FetchResponseData response = new FetchResponseData() FetchResponseData response = new FetchResponseData()
.setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code());
context.deliverResponse(correlationId, otherNodeId, response); context.deliverResponse(
request.correlationId(),
request.destination(),
response
);
assertThrows(ClusterAuthorizationException.class, context.client::poll); assertThrows(ClusterAuthorizationException.class, context.client::poll);
} }
@ -2553,11 +2807,11 @@ public class KafkaRaftClientTest {
context.expectAndGrantVotes(epoch); context.expectAndGrantVotes(epoch);
context.pollUntilRequest(); context.pollUntilRequest();
int correlationId = context.assertSentBeginQuorumEpochRequest(epoch, 1); RaftRequest.Outbound request = context.assertSentBeginQuorumEpochRequest(epoch, 1);
BeginQuorumEpochResponseData response = new BeginQuorumEpochResponseData() BeginQuorumEpochResponseData response = new BeginQuorumEpochResponseData()
.setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code());
context.deliverResponse(correlationId, otherNodeId, response); context.deliverResponse(request.correlationId(), request.destination(), response);
assertThrows(ClusterAuthorizationException.class, context.client::poll); assertThrows(ClusterAuthorizationException.class, context.client::poll);
} }
@ -2577,11 +2831,11 @@ public class KafkaRaftClientTest {
context.pollUntilRequest(); context.pollUntilRequest();
context.assertVotedCandidate(epoch, localId); context.assertVotedCandidate(epoch, localId);
int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); RaftRequest.Outbound request = context.assertSentVoteRequest(epoch, 0, 0L, 1);
VoteResponseData response = new VoteResponseData() VoteResponseData response = new VoteResponseData()
.setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code());
context.deliverResponse(correlationId, otherNodeId, response); context.deliverResponse(request.correlationId(), request.destination(), response);
assertThrows(ClusterAuthorizationException.class, context.client::poll); assertThrows(ClusterAuthorizationException.class, context.client::poll);
} }
@ -2597,11 +2851,11 @@ public class KafkaRaftClientTest {
context.client.shutdown(5000); context.client.shutdown(5000);
context.pollUntilRequest(); context.pollUntilRequest();
int correlationId = context.assertSentEndQuorumEpochRequest(epoch, otherNodeId); RaftRequest.Outbound request = context.assertSentEndQuorumEpochRequest(epoch, otherNodeId);
EndQuorumEpochResponseData response = new EndQuorumEpochResponseData() EndQuorumEpochResponseData response = new EndQuorumEpochResponseData()
.setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code());
context.deliverResponse(correlationId, otherNodeId, response); context.deliverResponse(request.correlationId(), request.destination(), response);
assertThrows(ClusterAuthorizationException.class, context.client::poll); assertThrows(ClusterAuthorizationException.class, context.client::poll);
} }
@ -2810,14 +3064,17 @@ public class KafkaRaftClientTest {
// Poll for our first fetch request // Poll for our first fetch request
context.pollUntilRequest(); context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
// The response does not advance the high watermark // The response does not advance the high watermark
List<String> records1 = Arrays.asList("a", "b", "c"); List<String> records1 = Arrays.asList("a", "b", "c");
MemoryRecords batch1 = context.buildBatch(0L, 3, records1); MemoryRecords batch1 = context.buildBatch(0L, 3, records1);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, otherNodeId, batch1, 0L, Errors.NONE)); fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, otherNodeId, batch1, 0L, Errors.NONE)
);
context.client.poll(); context.client.poll();
// The listener should not have seen any data // The listener should not have seen any data
@ -2828,14 +3085,17 @@ public class KafkaRaftClientTest {
// Now look for the next fetch request // Now look for the next fetch request
context.pollUntilRequest(); context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest(); fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, epoch, 3L, 3); context.assertFetchRequestData(fetchRequest, epoch, 3L, 3);
// The high watermark advances to include the first batch we fetched // The high watermark advances to include the first batch we fetched
List<String> records2 = Arrays.asList("d", "e", "f"); List<String> records2 = Arrays.asList("d", "e", "f");
MemoryRecords batch2 = context.buildBatch(3L, 3, records2); MemoryRecords batch2 = context.buildBatch(3L, 3, records2);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), context.deliverResponse(
context.fetchResponse(epoch, otherNodeId, batch2, 3L, Errors.NONE)); fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, otherNodeId, batch2, 3L, Errors.NONE)
);
context.client.poll(); context.client.poll();
// The listener should have seen only the data from the first batch // The listener should have seen only the data from the first batch
@ -3012,21 +3272,30 @@ public class KafkaRaftClientTest {
// This is designed for tooling/debugging use cases. // This is designed for tooling/debugging use cases.
Set<Integer> voters = Utils.mkSet(1, 2); Set<Integer> voters = Utils.mkSet(1, 2);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(OptionalInt.empty(), voters) RaftClientTestContext context = new RaftClientTestContext.Builder(OptionalInt.empty(), voters)
.withBootstrapServers(bootstrapServers)
.build(); .build();
// First fetch discovers the current leader and epoch // First fetch discovers the current leader and epoch
context.pollUntilRequest(); context.pollUntilRequest();
RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest1.destinationId())); assertTrue(context.bootstrapIds.contains(fetchRequest1.destination().id()));
context.assertFetchRequestData(fetchRequest1, 0, 0L, 0); context.assertFetchRequestData(fetchRequest1, 0, 0L, 0);
int leaderEpoch = 5; int leaderEpoch = 5;
int leaderId = 1; int leaderId = 1;
context.deliverResponse(fetchRequest1.correlationId, fetchRequest1.destinationId(), context.deliverResponse(
context.fetchResponse(5, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); fetchRequest1.correlationId(),
fetchRequest1.destination(),
context.fetchResponse(5, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll(); context.client.poll();
context.assertElectedLeader(leaderEpoch, leaderId); context.assertElectedLeader(leaderEpoch, leaderId);
@ -3034,13 +3303,16 @@ public class KafkaRaftClientTest {
context.pollUntilRequest(); context.pollUntilRequest();
RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest(); RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest();
assertEquals(leaderId, fetchRequest2.destinationId()); assertEquals(leaderId, fetchRequest2.destination().id());
context.assertFetchRequestData(fetchRequest2, leaderEpoch, 0L, 0); context.assertFetchRequestData(fetchRequest2, leaderEpoch, 0L, 0);
List<String> records = Arrays.asList("a", "b", "c"); List<String> records = Arrays.asList("a", "b", "c");
MemoryRecords batch1 = context.buildBatch(0L, 3, records); MemoryRecords batch1 = context.buildBatch(0L, 3, records);
context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(), context.deliverResponse(
context.fetchResponse(leaderEpoch, leaderId, batch1, 0L, Errors.NONE)); fetchRequest2.correlationId(),
fetchRequest2.destination(),
context.fetchResponse(leaderEpoch, leaderId, batch1, 0L, Errors.NONE)
);
context.client.poll(); context.client.poll();
assertEquals(3L, context.log.endOffset().offset); assertEquals(3L, context.log.endOffset().offset);
assertEquals(3, context.log.lastFetchedEpoch()); assertEquals(3, context.log.lastFetchedEpoch());

View File

@ -16,31 +16,29 @@
*/ */
package org.apache.kafka.raft; package org.apache.kafka.raft;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiKeys;
import java.net.InetSocketAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
public class MockNetworkChannel implements NetworkChannel { public class MockNetworkChannel implements NetworkChannel {
private final AtomicInteger correlationIdCounter; private final AtomicInteger correlationIdCounter;
private final Set<Integer> nodeCache;
private final List<RaftRequest.Outbound> sendQueue = new ArrayList<>(); private final List<RaftRequest.Outbound> sendQueue = new ArrayList<>();
private final Map<Integer, RaftRequest.Outbound> awaitingResponse = new HashMap<>(); private final Map<Integer, RaftRequest.Outbound> awaitingResponse = new HashMap<>();
private final ListenerName listenerName = ListenerName.normalised("CONTROLLER");
public MockNetworkChannel(AtomicInteger correlationIdCounter, Set<Integer> destinationIds) { public MockNetworkChannel(AtomicInteger correlationIdCounter) {
this.correlationIdCounter = correlationIdCounter; this.correlationIdCounter = correlationIdCounter;
this.nodeCache = destinationIds;
} }
public MockNetworkChannel(Set<Integer> destinationIds) { public MockNetworkChannel() {
this(new AtomicInteger(0), destinationIds); this(new AtomicInteger(0));
} }
@Override @Override
@ -50,16 +48,12 @@ public class MockNetworkChannel implements NetworkChannel {
@Override @Override
public void send(RaftRequest.Outbound request) { public void send(RaftRequest.Outbound request) {
if (!nodeCache.contains(request.destinationId())) {
throw new IllegalArgumentException("Attempted to send to destination " +
request.destinationId() + ", but its address is not yet known");
}
sendQueue.add(request); sendQueue.add(request);
} }
@Override @Override
public void updateEndpoint(int id, InetSocketAddress address) { public ListenerName listenerName() {
// empty return listenerName;
} }
public List<RaftRequest.Outbound> drainSendQueue() { public List<RaftRequest.Outbound> drainSendQueue() {
@ -72,7 +66,7 @@ public class MockNetworkChannel implements NetworkChannel {
while (iterator.hasNext()) { while (iterator.hasNext()) {
RaftRequest.Outbound request = iterator.next(); RaftRequest.Outbound request = iterator.next();
if (!apiKeyFilter.isPresent() || request.data().apiKey() == apiKeyFilter.get().id) { if (!apiKeyFilter.isPresent() || request.data().apiKey() == apiKeyFilter.get().id) {
awaitingResponse.put(request.correlationId, request); awaitingResponse.put(request.correlationId(), request);
requests.add(request); requests.add(request);
iterator.remove(); iterator.remove();
} }
@ -80,17 +74,15 @@ public class MockNetworkChannel implements NetworkChannel {
return requests; return requests;
} }
public boolean hasSentRequests() { public boolean hasSentRequests() {
return !sendQueue.isEmpty(); return !sendQueue.isEmpty();
} }
public void mockReceive(RaftResponse.Inbound response) { public void mockReceive(RaftResponse.Inbound response) {
RaftRequest.Outbound request = awaitingResponse.get(response.correlationId); RaftRequest.Outbound request = awaitingResponse.get(response.correlationId());
if (request == null) { if (request == null) {
throw new IllegalStateException("Received response for a request which is not being awaited"); throw new IllegalStateException("Received response for a request which is not being awaited");
} }
request.completion.complete(response); request.completion.complete(response);
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -16,6 +16,7 @@
*/ */
package org.apache.kafka.raft; package org.apache.kafka.raft;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid; import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.compress.Compression; import org.apache.kafka.common.compress.Compression;
@ -79,6 +80,7 @@ import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.apache.kafka.raft.LeaderState.CHECK_QUORUM_TIMEOUT_FACTOR; import static org.apache.kafka.raft.LeaderState.CHECK_QUORUM_TIMEOUT_FACTOR;
import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition; import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition;
@ -114,6 +116,7 @@ public final class RaftClientTestContext {
final MockTime time; final MockTime time;
final MockListener listener; final MockListener listener;
final Set<Integer> voters; final Set<Integer> voters;
final Set<Integer> bootstrapIds;
private final List<RaftResponse.Outbound> sentResponses = new ArrayList<>(); private final List<RaftResponse.Outbound> sentResponses = new ArrayList<>();
@ -146,6 +149,7 @@ public final class RaftClientTestContext {
private int electionTimeoutMs = DEFAULT_ELECTION_TIMEOUT_MS; private int electionTimeoutMs = DEFAULT_ELECTION_TIMEOUT_MS;
private int appendLingerMs = DEFAULT_APPEND_LINGER_MS; private int appendLingerMs = DEFAULT_APPEND_LINGER_MS;
private MemoryPool memoryPool = MemoryPool.NONE; private MemoryPool memoryPool = MemoryPool.NONE;
private List<InetSocketAddress> bootstrapServers = Collections.emptyList();
public Builder(int localId, Set<Integer> voters) { public Builder(int localId, Set<Integer> voters) {
this(OptionalInt.of(localId), voters); this(OptionalInt.of(localId), voters);
@ -240,9 +244,14 @@ public final class RaftClientTestContext {
return this; return this;
} }
Builder withBootstrapServers(List<InetSocketAddress> bootstrapServers) {
this.bootstrapServers = bootstrapServers;
return this;
}
public RaftClientTestContext build() throws IOException { public RaftClientTestContext build() throws IOException {
Metrics metrics = new Metrics(time); Metrics metrics = new Metrics(time);
MockNetworkChannel channel = new MockNetworkChannel(voters); MockNetworkChannel channel = new MockNetworkChannel();
MockListener listener = new MockListener(localId); MockListener listener = new MockListener(localId);
Map<Integer, InetSocketAddress> voterAddressMap = voters Map<Integer, InetSocketAddress> voterAddressMap = voters
.stream() .stream()
@ -269,6 +278,7 @@ public final class RaftClientTestContext {
new MockExpirationService(time), new MockExpirationService(time),
FETCH_MAX_WAIT_MS, FETCH_MAX_WAIT_MS,
clusterId.toString(), clusterId.toString(),
bootstrapServers,
logContext, logContext,
random, random,
quorumConfig quorumConfig
@ -277,7 +287,6 @@ public final class RaftClientTestContext {
client.register(listener); client.register(listener);
client.initialize( client.initialize(
voterAddressMap, voterAddressMap,
"CONTROLLER",
quorumStateStore, quorumStateStore,
metrics metrics
); );
@ -292,6 +301,11 @@ public final class RaftClientTestContext {
time, time,
quorumStateStore, quorumStateStore,
voters, voters,
IntStream
.iterate(-2, id -> id - 1)
.limit(bootstrapServers.size())
.boxed()
.collect(Collectors.toSet()),
metrics, metrics,
listener listener
); );
@ -314,6 +328,7 @@ public final class RaftClientTestContext {
MockTime time, MockTime time,
QuorumStateStore quorumStateStore, QuorumStateStore quorumStateStore,
Set<Integer> voters, Set<Integer> voters,
Set<Integer> bootstrapIds,
Metrics metrics, Metrics metrics,
MockListener listener MockListener listener
) { ) {
@ -326,6 +341,7 @@ public final class RaftClientTestContext {
this.time = time; this.time = time;
this.quorumStateStore = quorumStateStore; this.quorumStateStore = quorumStateStore;
this.voters = voters; this.voters = voters;
this.bootstrapIds = bootstrapIds;
this.metrics = metrics; this.metrics = metrics;
this.listener = listener; this.listener = listener;
} }
@ -417,7 +433,7 @@ public final class RaftClientTestContext {
for (RaftRequest.Outbound request : voteRequests) { for (RaftRequest.Outbound request : voteRequests) {
VoteResponseData voteResponse = voteResponse(true, Optional.empty(), epoch); VoteResponseData voteResponse = voteResponse(true, Optional.empty(), epoch);
deliverResponse(request.correlationId, request.destinationId(), voteResponse); deliverResponse(request.correlationId(), request.destination(), voteResponse);
} }
client.poll(); client.poll();
@ -432,7 +448,7 @@ public final class RaftClientTestContext {
pollUntilRequest(); pollUntilRequest();
for (RaftRequest.Outbound request : collectBeginEpochRequests(epoch)) { for (RaftRequest.Outbound request : collectBeginEpochRequests(epoch)) {
BeginQuorumEpochResponseData beginEpochResponse = beginEpochResponse(epoch, localIdOrThrow()); BeginQuorumEpochResponseData beginEpochResponse = beginEpochResponse(epoch, localIdOrThrow());
deliverResponse(request.correlationId, request.destinationId(), beginEpochResponse); deliverResponse(request.correlationId(), request.destination(), beginEpochResponse);
} }
client.poll(); client.poll();
} }
@ -519,10 +535,10 @@ public final class RaftClientTestContext {
assertEquals(expectedResponse, response); assertEquals(expectedResponse, response);
} }
int assertSentVoteRequest(int epoch, int lastEpoch, long lastEpochOffset, int numVoteReceivers) { RaftRequest.Outbound assertSentVoteRequest(int epoch, int lastEpoch, long lastEpochOffset, int numVoteReceivers) {
List<RaftRequest.Outbound> voteRequests = collectVoteRequests(epoch, lastEpoch, lastEpochOffset); List<RaftRequest.Outbound> voteRequests = collectVoteRequests(epoch, lastEpoch, lastEpochOffset);
assertEquals(numVoteReceivers, voteRequests.size()); assertEquals(numVoteReceivers, voteRequests.size());
return voteRequests.iterator().next().correlationId(); return voteRequests.iterator().next();
} }
void assertSentVoteResponse(Errors error) { void assertSentVoteResponse(Errors error) {
@ -590,14 +606,14 @@ public final class RaftClientTestContext {
client.handle(inboundRequest); client.handle(inboundRequest);
} }
void deliverResponse(int correlationId, int sourceId, ApiMessage response) { void deliverResponse(int correlationId, Node source, ApiMessage response) {
channel.mockReceive(new RaftResponse.Inbound(correlationId, response, sourceId)); channel.mockReceive(new RaftResponse.Inbound(correlationId, response, source));
} }
int assertSentBeginQuorumEpochRequest(int epoch, int numBeginEpochRequests) { RaftRequest.Outbound assertSentBeginQuorumEpochRequest(int epoch, int numBeginEpochRequests) {
List<RaftRequest.Outbound> requests = collectBeginEpochRequests(epoch); List<RaftRequest.Outbound> requests = collectBeginEpochRequests(epoch);
assertEquals(numBeginEpochRequests, requests.size()); assertEquals(numBeginEpochRequests, requests.size());
return requests.get(0).correlationId; return requests.get(0);
} }
private List<RaftResponse.Outbound> drainSentResponses( private List<RaftResponse.Outbound> drainSentResponses(
@ -607,7 +623,7 @@ public final class RaftClientTestContext {
Iterator<RaftResponse.Outbound> iterator = sentResponses.iterator(); Iterator<RaftResponse.Outbound> iterator = sentResponses.iterator();
while (iterator.hasNext()) { while (iterator.hasNext()) {
RaftResponse.Outbound response = iterator.next(); RaftResponse.Outbound response = iterator.next();
if (response.data.apiKey() == apiKey.id) { if (response.data().apiKey() == apiKey.id) {
res.add(response); res.add(response);
iterator.remove(); iterator.remove();
} }
@ -646,11 +662,14 @@ public final class RaftClientTestContext {
assertEquals(partitionError, Errors.forCode(partitionResponse.errorCode())); assertEquals(partitionError, Errors.forCode(partitionResponse.errorCode()));
} }
int assertSentEndQuorumEpochRequest(int epoch, int destinationId) { RaftRequest.Outbound assertSentEndQuorumEpochRequest(int epoch, int destinationId) {
List<RaftRequest.Outbound> endQuorumRequests = collectEndQuorumRequests( List<RaftRequest.Outbound> endQuorumRequests = collectEndQuorumRequests(
epoch, Collections.singleton(destinationId), Optional.empty()); epoch,
Collections.singleton(destinationId),
Optional.empty()
);
assertEquals(1, endQuorumRequests.size()); assertEquals(1, endQuorumRequests.size());
return endQuorumRequests.get(0).correlationId(); return endQuorumRequests.get(0);
} }
void assertSentEndQuorumEpochResponse( void assertSentEndQuorumEpochResponse(
@ -690,7 +709,7 @@ public final class RaftClientTestContext {
return sentRequests.get(0); return sentRequests.get(0);
} }
int assertSentFetchRequest( RaftRequest.Outbound assertSentFetchRequest(
int epoch, int epoch,
long fetchOffset, long fetchOffset,
int lastFetchedEpoch int lastFetchedEpoch
@ -700,7 +719,7 @@ public final class RaftClientTestContext {
RaftRequest.Outbound raftRequest = sentMessages.get(0); RaftRequest.Outbound raftRequest = sentMessages.get(0);
assertFetchRequestData(raftRequest, epoch, fetchOffset, lastFetchedEpoch); assertFetchRequestData(raftRequest, epoch, fetchOffset, lastFetchedEpoch);
return raftRequest.correlationId(); return raftRequest;
} }
FetchResponseData.PartitionData assertSentFetchPartitionResponse() { FetchResponseData.PartitionData assertSentFetchPartitionResponse() {
@ -708,7 +727,7 @@ public final class RaftClientTestContext {
assertEquals( assertEquals(
1, sentMessages.size(), "Found unexpected sent messages " + sentMessages); 1, sentMessages.size(), "Found unexpected sent messages " + sentMessages);
RaftResponse.Outbound raftMessage = sentMessages.get(0); RaftResponse.Outbound raftMessage = sentMessages.get(0);
assertEquals(ApiKeys.FETCH.id, raftMessage.data.apiKey()); assertEquals(ApiKeys.FETCH.id, raftMessage.data().apiKey());
FetchResponseData response = (FetchResponseData) raftMessage.data(); FetchResponseData response = (FetchResponseData) raftMessage.data();
assertEquals(Errors.NONE, Errors.forCode(response.errorCode())); assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
@ -723,7 +742,7 @@ public final class RaftClientTestContext {
assertEquals( assertEquals(
1, sentMessages.size(), "Found unexpected sent messages " + sentMessages); 1, sentMessages.size(), "Found unexpected sent messages " + sentMessages);
RaftResponse.Outbound raftMessage = sentMessages.get(0); RaftResponse.Outbound raftMessage = sentMessages.get(0);
assertEquals(ApiKeys.FETCH.id, raftMessage.data.apiKey()); assertEquals(ApiKeys.FETCH.id, raftMessage.data().apiKey());
FetchResponseData response = (FetchResponseData) raftMessage.data(); FetchResponseData response = (FetchResponseData) raftMessage.data();
assertEquals(topLevelError, Errors.forCode(response.errorCode())); assertEquals(topLevelError, Errors.forCode(response.errorCode()));
} }
@ -811,7 +830,7 @@ public final class RaftClientTestContext {
assertEquals(preferredSuccessors, partitionRequest.preferredSuccessors()); assertEquals(preferredSuccessors, partitionRequest.preferredSuccessors());
}); });
collectedDestinationIdSet.add(raftMessage.destinationId()); collectedDestinationIdSet.add(raftMessage.destination().id());
endQuorumRequests.add(raftMessage); endQuorumRequests.add(raftMessage);
} }
} }
@ -825,11 +844,18 @@ public final class RaftClientTestContext {
) throws Exception { ) throws Exception {
pollUntilRequest(); pollUntilRequest();
RaftRequest.Outbound fetchRequest = assertSentFetchRequest(); RaftRequest.Outbound fetchRequest = assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId())); int destinationId = fetchRequest.destination().id();
assertTrue(
voters.contains(destinationId) || bootstrapIds.contains(destinationId),
String.format("id %d is not in sets %s or %s", destinationId, voters, bootstrapIds)
);
assertFetchRequestData(fetchRequest, 0, 0L, 0); assertFetchRequestData(fetchRequest, 0, 0L, 0);
deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), deliverResponse(
fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.NONE)); fetchRequest.correlationId(),
fetchRequest.destination(),
fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.NONE)
);
client.poll(); client.poll();
assertElectedLeader(epoch, leaderId); assertElectedLeader(epoch, leaderId);
} }
@ -850,7 +876,7 @@ public final class RaftClientTestContext {
return requests; return requests;
} }
private static InetSocketAddress mockAddress(int id) { public static InetSocketAddress mockAddress(int id) {
return new InetSocketAddress("localhost", 9990 + id); return new InetSocketAddress("localhost", 9990 + id);
} }

View File

@ -21,6 +21,7 @@ import net.jqwik.api.ForAll;
import net.jqwik.api.Property; import net.jqwik.api.Property;
import net.jqwik.api.Tag; import net.jqwik.api.Tag;
import net.jqwik.api.constraints.IntRange; import net.jqwik.api.constraints.IntRange;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid; import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.memory.MemoryPool; import org.apache.kafka.common.memory.MemoryPool;
@ -45,6 +46,7 @@ import java.net.InetSocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
@ -59,7 +61,6 @@ import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -189,7 +190,7 @@ public class RaftEventSimulationTest {
// they are able to elect a leader and continue making progress // they are able to elect a leader and continue making progress
cluster.killAll(); cluster.killAll();
Iterator<Integer> nodeIdsIterator = cluster.nodes().iterator(); Iterator<Integer> nodeIdsIterator = cluster.nodeIds().iterator();
for (int i = 0; i < cluster.majoritySize(); i++) { for (int i = 0; i < cluster.majoritySize(); i++) {
Integer nodeId = nodeIdsIterator.next(); Integer nodeId = nodeIdsIterator.next();
cluster.start(nodeId); cluster.start(nodeId);
@ -224,7 +225,7 @@ public class RaftEventSimulationTest {
); );
router.filter(leaderId, new DropAllTraffic()); router.filter(leaderId, new DropAllTraffic());
Set<Integer> nonPartitionedNodes = new HashSet<>(cluster.nodes()); Set<Integer> nonPartitionedNodes = new HashSet<>(cluster.nodeIds());
nonPartitionedNodes.remove(leaderId); nonPartitionedNodes.remove(leaderId);
scheduler.runUntil(() -> cluster.allReachedHighWatermark(20, nonPartitionedNodes)); scheduler.runUntil(() -> cluster.allReachedHighWatermark(20, nonPartitionedNodes));
@ -252,11 +253,17 @@ public class RaftEventSimulationTest {
// Partition the nodes into two sets. Nodes are reachable within each set, // Partition the nodes into two sets. Nodes are reachable within each set,
// but the two sets cannot communicate with each other. We should be able // but the two sets cannot communicate with each other. We should be able
// to make progress even if an election is needed in the larger set. // to make progress even if an election is needed in the larger set.
router.filter(0, new DropOutboundRequestsFrom(Utils.mkSet(2, 3, 4))); router.filter(
router.filter(1, new DropOutboundRequestsFrom(Utils.mkSet(2, 3, 4))); 0,
router.filter(2, new DropOutboundRequestsFrom(Utils.mkSet(0, 1))); new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(2, 3, 4)))
router.filter(3, new DropOutboundRequestsFrom(Utils.mkSet(0, 1))); );
router.filter(4, new DropOutboundRequestsFrom(Utils.mkSet(0, 1))); router.filter(
1,
new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(2, 3, 4)))
);
router.filter(2, new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(0, 1))));
router.filter(3, new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(0, 1))));
router.filter(4, new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(0, 1))));
long partitionLogEndOffset = cluster.maxLogEndOffset(); long partitionLogEndOffset = cluster.maxLogEndOffset();
scheduler.runUntil(() -> cluster.anyReachedHighWatermark(2 * partitionLogEndOffset)); scheduler.runUntil(() -> cluster.anyReachedHighWatermark(2 * partitionLogEndOffset));
@ -374,7 +381,7 @@ public class RaftEventSimulationTest {
int pollIntervalMs, int pollIntervalMs,
int pollJitterMs) { int pollJitterMs) {
int delayMs = 0; int delayMs = 0;
for (int nodeId : cluster.nodes()) { for (int nodeId : cluster.nodeIds()) {
scheduler.schedule(() -> cluster.pollIfRunning(nodeId), delayMs, pollIntervalMs, pollJitterMs); scheduler.schedule(() -> cluster.pollIfRunning(nodeId), delayMs, pollIntervalMs, pollJitterMs);
delayMs++; delayMs++;
} }
@ -527,25 +534,37 @@ public class RaftEventSimulationTest {
final AtomicInteger correlationIdCounter = new AtomicInteger(); final AtomicInteger correlationIdCounter = new AtomicInteger();
final MockTime time = new MockTime(); final MockTime time = new MockTime();
final Uuid clusterId = Uuid.randomUuid(); final Uuid clusterId = Uuid.randomUuid();
final Set<Integer> voters = new HashSet<>(); final Map<Integer, Node> voters = new HashMap<>();
final Map<Integer, PersistentState> nodes = new HashMap<>(); final Map<Integer, PersistentState> nodes = new HashMap<>();
final Map<Integer, RaftNode> running = new HashMap<>(); final Map<Integer, RaftNode> running = new HashMap<>();
private Cluster(int numVoters, int numObservers, Random random) { private Cluster(int numVoters, int numObservers, Random random) {
this.random = random; this.random = random;
int nodeId = 0; for (int nodeId = 0; nodeId < numVoters; nodeId++) {
for (; nodeId < numVoters; nodeId++) { voters.put(
voters.add(nodeId); nodeId,
new Node(nodeId, String.format("host-node-%d", nodeId), 1234)
);
nodes.put(nodeId, new PersistentState(nodeId)); nodes.put(nodeId, new PersistentState(nodeId));
} }
for (; nodeId < numVoters + numObservers; nodeId++) { for (int nodeIdDelta = 0; nodeIdDelta < numObservers; nodeIdDelta++) {
int nodeId = numVoters + nodeIdDelta;
nodes.put(nodeId, new PersistentState(nodeId)); nodes.put(nodeId, new PersistentState(nodeId));
} }
} }
Set<Integer> nodes() { Set<InetSocketAddress> endpointsFromIds(Set<Integer> nodeIds) {
return voters
.values()
.stream()
.filter(node -> nodeIds.contains(node.id()))
.map(Cluster::nodeAddress)
.collect(Collectors.toSet());
}
Set<Integer> nodeIds() {
return nodes.keySet(); return nodes.keySet();
} }
@ -710,18 +729,19 @@ public class RaftEventSimulationTest {
nodes.put(nodeId, new PersistentState(nodeId)); nodes.put(nodeId, new PersistentState(nodeId));
} }
private static InetSocketAddress nodeAddress(int id) { private static InetSocketAddress nodeAddress(Node node) {
return new InetSocketAddress("localhost", 9990 + id); return InetSocketAddress.createUnresolved(node.host(), node.port());
} }
void start(int nodeId) { void start(int nodeId) {
LogContext logContext = new LogContext("[Node " + nodeId + "] "); LogContext logContext = new LogContext("[Node " + nodeId + "] ");
PersistentState persistentState = nodes.get(nodeId); PersistentState persistentState = nodes.get(nodeId);
MockNetworkChannel channel = new MockNetworkChannel(correlationIdCounter, voters); MockNetworkChannel channel = new MockNetworkChannel(correlationIdCounter);
MockMessageQueue messageQueue = new MockMessageQueue(); MockMessageQueue messageQueue = new MockMessageQueue();
Map<Integer, InetSocketAddress> voterAddressMap = voters Map<Integer, InetSocketAddress> voterAddressMap = voters
.values()
.stream() .stream()
.collect(Collectors.toMap(Function.identity(), Cluster::nodeAddress)); .collect(Collectors.toMap(Node::id, Cluster::nodeAddress));
QuorumConfig quorumConfig = new QuorumConfig( QuorumConfig quorumConfig = new QuorumConfig(
REQUEST_TIMEOUT_MS, REQUEST_TIMEOUT_MS,
@ -750,6 +770,7 @@ public class RaftEventSimulationTest {
new MockExpirationService(time), new MockExpirationService(time),
FETCH_MAX_WAIT_MS, FETCH_MAX_WAIT_MS,
clusterId.toString(), clusterId.toString(),
Collections.emptyList(),
logContext, logContext,
random, random,
quorumConfig quorumConfig
@ -808,7 +829,6 @@ public class RaftEventSimulationTest {
client.register(counter); client.register(counter);
client.initialize( client.initialize(
voterAddresses, voterAddresses,
"CONTROLLER",
store, store,
metrics metrics
); );
@ -847,9 +867,11 @@ public class RaftEventSimulationTest {
private static class InflightRequest { private static class InflightRequest {
final int sourceId; final int sourceId;
final Node destination;
private InflightRequest(int sourceId) { private InflightRequest(int sourceId, Node destination) {
this.sourceId = sourceId; this.sourceId = sourceId;
this.destination = destination;
} }
} }
@ -884,11 +906,15 @@ public class RaftEventSimulationTest {
} }
} }
private static class DropOutboundRequestsFrom implements NetworkFilter { private static class DropOutboundRequestsTo implements NetworkFilter {
private final Set<InetSocketAddress> unreachable;
private final Set<Integer> unreachable; /**
* This network filter drops any outbound message sent to the {@code unreachable} nodes.
private DropOutboundRequestsFrom(Set<Integer> unreachable) { *
* @param unreachable the set of destination address which are not reachable
*/
private DropOutboundRequestsTo(Set<InetSocketAddress> unreachable) {
this.unreachable = unreachable; this.unreachable = unreachable;
} }
@ -897,11 +923,25 @@ public class RaftEventSimulationTest {
return true; return true;
} }
/**
* Returns if the message should be sent to the destination.
*
* Returns false when outbound request messages contains a destination {@code Node} that
* matches the set of unreaable {@code InetSocketAddress}. Note that the {@code Node.id()}
* and {@code Node.rack()} are not compared.
*
* @param message the raft message
* @return true if the message should be delivered, otherwise false
*/
@Override @Override
public boolean acceptOutbound(RaftMessage message) { public boolean acceptOutbound(RaftMessage message) {
if (message instanceof RaftRequest.Outbound) { if (message instanceof RaftRequest.Outbound) {
RaftRequest.Outbound request = (RaftRequest.Outbound) message; RaftRequest.Outbound request = (RaftRequest.Outbound) message;
return !unreachable.contains(request.destinationId()); InetSocketAddress destination = InetSocketAddress.createUnresolved(
request.destination().host(),
request.destination().port()
);
return !unreachable.contains(destination);
} }
return true; return true;
} }
@ -955,7 +995,7 @@ public class RaftEventSimulationTest {
public void verify() { public void verify() {
cluster.leaderHighWatermark().ifPresent(highWatermark -> { cluster.leaderHighWatermark().ifPresent(highWatermark -> {
long numReachedHighWatermark = cluster.nodes.entrySet().stream() long numReachedHighWatermark = cluster.nodes.entrySet().stream()
.filter(entry -> cluster.voters.contains(entry.getKey())) .filter(entry -> cluster.voters.containsKey(entry.getKey()))
.filter(entry -> entry.getValue().log.endOffset().offset >= highWatermark) .filter(entry -> entry.getValue().log.endOffset().offset >= highWatermark)
.count(); .count();
assertTrue( assertTrue(
@ -1194,19 +1234,19 @@ public class RaftEventSimulationTest {
return; return;
int correlationId = outbound.correlationId(); int correlationId = outbound.correlationId();
int destinationId = outbound.destinationId(); Node destination = outbound.destination();
RaftRequest.Inbound inbound = new RaftRequest.Inbound(correlationId, outbound.data(), RaftRequest.Inbound inbound = new RaftRequest.Inbound(correlationId, outbound.data(),
cluster.time.milliseconds()); cluster.time.milliseconds());
if (!filters.get(destinationId).acceptInbound(inbound)) if (!filters.get(destination.id()).acceptInbound(inbound))
return; return;
cluster.nodeIfRunning(destinationId).ifPresent(node -> { cluster.nodeIfRunning(destination.id()).ifPresent(node -> {
inflight.put(correlationId, new InflightRequest(senderId)); inflight.put(correlationId, new InflightRequest(senderId, destination));
inbound.completion.whenComplete((response, exception) -> { inbound.completion.whenComplete((response, exception) -> {
if (response != null && filters.get(destinationId).acceptOutbound(response)) { if (response != null && filters.get(destination.id()).acceptOutbound(response)) {
deliver(destinationId, response); deliver(response);
} }
}); });
@ -1214,11 +1254,17 @@ public class RaftEventSimulationTest {
}); });
} }
void deliver(int senderId, RaftResponse.Outbound outbound) { void deliver(RaftResponse.Outbound outbound) {
int correlationId = outbound.correlationId(); int correlationId = outbound.correlationId();
RaftResponse.Inbound inbound = new RaftResponse.Inbound(correlationId, outbound.data(), senderId);
InflightRequest inflightRequest = inflight.remove(correlationId); InflightRequest inflightRequest = inflight.remove(correlationId);
RaftResponse.Inbound inbound = new RaftResponse.Inbound(
correlationId,
outbound.data(),
// The source of the response is the destination of the request
inflightRequest.destination
);
if (!filters.get(inflightRequest.sourceId).acceptInbound(inbound)) if (!filters.get(inflightRequest.sourceId).acceptInbound(inbound))
return; return;

View File

@ -16,14 +16,20 @@
*/ */
package org.apache.kafka.raft; package org.apache.kafka.raft;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
import java.util.Random; import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
public class RequestManagerTest { public class RequestManagerTest {
private final MockTime time = new MockTime(); private final MockTime time = new MockTime();
@ -33,105 +39,247 @@ public class RequestManagerTest {
@Test @Test
public void testResetAllConnections() { public void testResetAllConnections() {
Node node1 = new Node(1, "mock-host-1", 4321);
Node node2 = new Node(2, "mock-host-2", 4321);
RequestManager cache = new RequestManager( RequestManager cache = new RequestManager(
Utils.mkSet(1, 2, 3), makeBootstrapList(3),
retryBackoffMs, retryBackoffMs,
requestTimeoutMs, requestTimeoutMs,
random); random
);
// One host has an inflight request // One host has an inflight request
RequestManager.ConnectionState connectionState1 = cache.getOrCreate(1); cache.onRequestSent(node1, 1, time.milliseconds());
connectionState1.onRequestSent(1, time.milliseconds()); assertFalse(cache.isReady(node1, time.milliseconds()));
assertFalse(connectionState1.isReady(time.milliseconds()));
// Another is backing off // Another is backing off
RequestManager.ConnectionState connectionState2 = cache.getOrCreate(2); cache.onRequestSent(node2, 2, time.milliseconds());
connectionState2.onRequestSent(2, time.milliseconds()); cache.onResponseResult(node2, 2, false, time.milliseconds());
connectionState2.onResponseError(2, time.milliseconds()); assertFalse(cache.isReady(node2, time.milliseconds()));
assertFalse(connectionState2.isReady(time.milliseconds()));
cache.resetAll(); cache.resetAll();
// Now both should be ready // Now both should be ready
assertTrue(connectionState1.isReady(time.milliseconds())); assertTrue(cache.isReady(node1, time.milliseconds()));
assertTrue(connectionState2.isReady(time.milliseconds())); assertTrue(cache.isReady(node2, time.milliseconds()));
} }
@Test @Test
public void testBackoffAfterFailure() { public void testBackoffAfterFailure() {
Node node = new Node(1, "mock-host-1", 4321);
RequestManager cache = new RequestManager( RequestManager cache = new RequestManager(
Utils.mkSet(1, 2, 3), makeBootstrapList(3),
retryBackoffMs, retryBackoffMs,
requestTimeoutMs, requestTimeoutMs,
random); random
);
RequestManager.ConnectionState connectionState = cache.getOrCreate(1); assertTrue(cache.isReady(node, time.milliseconds()));
assertTrue(connectionState.isReady(time.milliseconds()));
long correlationId = 1; long correlationId = 1;
connectionState.onRequestSent(correlationId, time.milliseconds()); cache.onRequestSent(node, correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds())); assertFalse(cache.isReady(node, time.milliseconds()));
connectionState.onResponseError(correlationId, time.milliseconds()); cache.onResponseResult(node, correlationId, false, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds())); assertFalse(cache.isReady(node, time.milliseconds()));
time.sleep(retryBackoffMs); time.sleep(retryBackoffMs);
assertTrue(connectionState.isReady(time.milliseconds())); assertTrue(cache.isReady(node, time.milliseconds()));
} }
@Test @Test
public void testSuccessfulResponse() { public void testSuccessfulResponse() {
Node node = new Node(1, "mock-host-1", 4321);
RequestManager cache = new RequestManager( RequestManager cache = new RequestManager(
Utils.mkSet(1, 2, 3), makeBootstrapList(3),
retryBackoffMs, retryBackoffMs,
requestTimeoutMs, requestTimeoutMs,
random); random
);
RequestManager.ConnectionState connectionState = cache.getOrCreate(1);
long correlationId = 1; long correlationId = 1;
connectionState.onRequestSent(correlationId, time.milliseconds()); cache.onRequestSent(node, correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds())); assertFalse(cache.isReady(node, time.milliseconds()));
connectionState.onResponseReceived(correlationId); cache.onResponseResult(node, correlationId, true, time.milliseconds());
assertTrue(connectionState.isReady(time.milliseconds())); assertTrue(cache.isReady(node, time.milliseconds()));
} }
@Test @Test
public void testIgnoreUnexpectedResponse() { public void testIgnoreUnexpectedResponse() {
Node node = new Node(1, "mock-host-1", 4321);
RequestManager cache = new RequestManager( RequestManager cache = new RequestManager(
Utils.mkSet(1, 2, 3), makeBootstrapList(3),
retryBackoffMs, retryBackoffMs,
requestTimeoutMs, requestTimeoutMs,
random); random
);
RequestManager.ConnectionState connectionState = cache.getOrCreate(1);
long correlationId = 1; long correlationId = 1;
connectionState.onRequestSent(correlationId, time.milliseconds()); cache.onRequestSent(node, correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds())); assertFalse(cache.isReady(node, time.milliseconds()));
connectionState.onResponseReceived(correlationId + 1); cache.onResponseResult(node, correlationId + 1, true, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds())); assertFalse(cache.isReady(node, time.milliseconds()));
} }
@Test @Test
public void testRequestTimeout() { public void testRequestTimeout() {
Node node = new Node(1, "mock-host-1", 4321);
RequestManager cache = new RequestManager( RequestManager cache = new RequestManager(
Utils.mkSet(1, 2, 3), makeBootstrapList(3),
retryBackoffMs, retryBackoffMs,
requestTimeoutMs, requestTimeoutMs,
random); random
);
RequestManager.ConnectionState connectionState = cache.getOrCreate(1);
long correlationId = 1; long correlationId = 1;
connectionState.onRequestSent(correlationId, time.milliseconds()); cache.onRequestSent(node, correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds())); assertFalse(cache.isReady(node, time.milliseconds()));
time.sleep(requestTimeoutMs - 1); time.sleep(requestTimeoutMs - 1);
assertFalse(connectionState.isReady(time.milliseconds())); assertFalse(cache.isReady(node, time.milliseconds()));
time.sleep(1); time.sleep(1);
assertTrue(connectionState.isReady(time.milliseconds())); assertTrue(cache.isReady(node, time.milliseconds()));
} }
@Test
public void testRequestToBootstrapList() {
List<Node> bootstrapList = makeBootstrapList(2);
RequestManager cache = new RequestManager(
bootstrapList,
retryBackoffMs,
requestTimeoutMs,
random
);
// Find a ready node with the starting state
Node bootstrapNode1 = cache.findReadyBootstrapServer(time.milliseconds()).get();
assertTrue(
bootstrapList.contains(bootstrapNode1),
String.format("%s is not in %s", bootstrapNode1, bootstrapList)
);
assertEquals(0, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds()));
// Send a request and check the cache state
cache.onRequestSent(bootstrapNode1, 1, time.milliseconds());
assertEquals(
Optional.empty(),
cache.findReadyBootstrapServer(time.milliseconds())
);
assertEquals(requestTimeoutMs, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds()));
// Fail the request
time.sleep(100);
cache.onResponseResult(bootstrapNode1, 1, false, time.milliseconds());
Node bootstrapNode2 = cache.findReadyBootstrapServer(time.milliseconds()).get();
assertNotEquals(bootstrapNode1, bootstrapNode2);
assertEquals(0, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds()));
// Send a request to the second node and check the state
cache.onRequestSent(bootstrapNode2, 2, time.milliseconds());
assertEquals(
Optional.empty(),
cache.findReadyBootstrapServer(time.milliseconds())
);
assertEquals(requestTimeoutMs, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds()));
// Fail the second request before the request timeout
time.sleep(retryBackoffMs - 1);
cache.onResponseResult(bootstrapNode2, 2, false, time.milliseconds());
assertEquals(
Optional.empty(),
cache.findReadyBootstrapServer(time.milliseconds())
);
assertEquals(1, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds()));
// Timeout the first backoff and show that that node is ready
time.sleep(1);
Node bootstrapNode3 = cache.findReadyBootstrapServer(time.milliseconds()).get();
assertEquals(bootstrapNode1, bootstrapNode3);
assertEquals(0, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds()));
}
@Test
public void testFindReadyWithInflightRequest() {
Node otherNode = new Node(1, "other-node", 1234);
List<Node> bootstrapList = makeBootstrapList(3);
RequestManager cache = new RequestManager(
bootstrapList,
retryBackoffMs,
requestTimeoutMs,
random
);
// Send request to a node that is not in the bootstrap list
cache.onRequestSent(otherNode, 1, time.milliseconds());
assertEquals(Optional.empty(), cache.findReadyBootstrapServer(time.milliseconds()));
}
@Test
public void testFindReadyWithRequestTimedout() {
Node otherNode = new Node(1, "other-node", 1234);
List<Node> bootstrapList = makeBootstrapList(3);
RequestManager cache = new RequestManager(
bootstrapList,
retryBackoffMs,
requestTimeoutMs,
random
);
// Send request to a node that is not in the bootstrap list
cache.onRequestSent(otherNode, 1, time.milliseconds());
assertTrue(cache.isResponseExpected(otherNode, 1));
assertEquals(Optional.empty(), cache.findReadyBootstrapServer(time.milliseconds()));
// Timeout the request
time.sleep(requestTimeoutMs);
Node bootstrapNode = cache.findReadyBootstrapServer(time.milliseconds()).get();
assertTrue(bootstrapList.contains(bootstrapNode));
assertFalse(cache.isResponseExpected(otherNode, 1));
}
@Test
public void testAnyInflightRequestWithAnyRequest() {
Node otherNode = new Node(1, "other-node", 1234);
List<Node> bootstrapList = makeBootstrapList(3);
RequestManager cache = new RequestManager(
bootstrapList,
retryBackoffMs,
requestTimeoutMs,
random
);
assertFalse(cache.hasAnyInflightRequest(time.milliseconds()));
// Send a request and check state
cache.onRequestSent(otherNode, 11, time.milliseconds());
assertTrue(cache.hasAnyInflightRequest(time.milliseconds()));
// Wait until the request times out
time.sleep(requestTimeoutMs);
assertFalse(cache.hasAnyInflightRequest(time.milliseconds()));
// Send another request and fail it
cache.onRequestSent(otherNode, 12, time.milliseconds());
cache.onResponseResult(otherNode, 12, false, time.milliseconds());
assertFalse(cache.hasAnyInflightRequest(time.milliseconds()));
// Send another request and mark it successful
cache.onRequestSent(otherNode, 12, time.milliseconds());
cache.onResponseResult(otherNode, 12, true, time.milliseconds());
assertFalse(cache.hasAnyInflightRequest(time.milliseconds()));
}
private List<Node> makeBootstrapList(int numberOfNodes) {
return IntStream.iterate(-2, id -> id - 1)
.limit(numberOfNodes)
.mapToObj(id -> new Node(id, String.format("mock-boot-host%d", id), 1234))
.collect(Collectors.toList());
}
} }

View File

@ -16,8 +16,8 @@
*/ */
package org.apache.kafka.raft.internals; package org.apache.kafka.raft.internals;
import java.util.Arrays;
import java.util.Optional; import java.util.Optional;
import java.util.stream.IntStream;
import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid; import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.message.KRaftVersionRecord; import org.apache.kafka.common.message.KRaftVersionRecord;
@ -52,7 +52,7 @@ final class KRaftControlRecordStateMachineTest {
@Test @Test
void testEmptyPartition() { void testEmptyPartition() {
MockLog log = buildLog(); MockLog log = buildLog();
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
KRaftControlRecordStateMachine partitionState = buildPartitionListener(log, Optional.of(voterSet)); KRaftControlRecordStateMachine partitionState = buildPartitionListener(log, Optional.of(voterSet));
@ -65,7 +65,7 @@ final class KRaftControlRecordStateMachineTest {
@Test @Test
void testUpdateWithoutSnapshot() { void testUpdateWithoutSnapshot() {
MockLog log = buildLog(); MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING; BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
int epoch = 1; int epoch = 1;
@ -85,7 +85,7 @@ final class KRaftControlRecordStateMachineTest {
); );
// Append the voter set control record // Append the voter set control record
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
log.appendAsLeader( log.appendAsLeader(
MemoryRecords.withVotersRecord( MemoryRecords.withVotersRecord(
log.endOffset().offset, log.endOffset().offset,
@ -108,7 +108,7 @@ final class KRaftControlRecordStateMachineTest {
@Test @Test
void testUpdateWithEmptySnapshot() { void testUpdateWithEmptySnapshot() {
MockLog log = buildLog(); MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING; BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
int epoch = 1; int epoch = 1;
@ -136,7 +136,7 @@ final class KRaftControlRecordStateMachineTest {
); );
// Append the voter set control record // Append the voter set control record
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
log.appendAsLeader( log.appendAsLeader(
MemoryRecords.withVotersRecord( MemoryRecords.withVotersRecord(
log.endOffset().offset, log.endOffset().offset,
@ -159,14 +159,14 @@ final class KRaftControlRecordStateMachineTest {
@Test @Test
void testUpdateWithSnapshot() { void testUpdateWithSnapshot() {
MockLog log = buildLog(); MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
int epoch = 1; int epoch = 1;
KRaftControlRecordStateMachine partitionState = buildPartitionListener(log, Optional.of(staticVoterSet)); KRaftControlRecordStateMachine partitionState = buildPartitionListener(log, Optional.of(staticVoterSet));
// Create a snapshot that has kraft.version and voter set control records // Create a snapshot that has kraft.version and voter set control records
short kraftVersion = 1; short kraftVersion = 1;
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder() RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
.setRawSnapshotWriter(log.createNewSnapshotUnchecked(new OffsetAndEpoch(10, epoch)).get()) .setRawSnapshotWriter(log.createNewSnapshotUnchecked(new OffsetAndEpoch(10, epoch)).get())
@ -188,7 +188,7 @@ final class KRaftControlRecordStateMachineTest {
@Test @Test
void testUpdateWithSnapshotAndLogOverride() { void testUpdateWithSnapshotAndLogOverride() {
MockLog log = buildLog(); MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING; BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
int epoch = 1; int epoch = 1;
@ -196,7 +196,7 @@ final class KRaftControlRecordStateMachineTest {
// Create a snapshot that has kraft.version and voter set control records // Create a snapshot that has kraft.version and voter set control records
short kraftVersion = 1; short kraftVersion = 1;
VoterSet snapshotVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); VoterSet snapshotVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
OffsetAndEpoch snapshotId = new OffsetAndEpoch(10, epoch); OffsetAndEpoch snapshotId = new OffsetAndEpoch(10, epoch);
RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder() RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
@ -235,7 +235,7 @@ final class KRaftControlRecordStateMachineTest {
@Test @Test
void testTruncateTo() { void testTruncateTo() {
MockLog log = buildLog(); MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING; BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
int epoch = 1; int epoch = 1;
@ -256,7 +256,7 @@ final class KRaftControlRecordStateMachineTest {
// Append the voter set control record // Append the voter set control record
long firstVoterSetOffset = log.endOffset().offset; long firstVoterSetOffset = log.endOffset().offset;
VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
log.appendAsLeader( log.appendAsLeader(
MemoryRecords.withVotersRecord( MemoryRecords.withVotersRecord(
firstVoterSetOffset, firstVoterSetOffset,
@ -303,7 +303,7 @@ final class KRaftControlRecordStateMachineTest {
@Test @Test
void testTrimPrefixTo() { void testTrimPrefixTo() {
MockLog log = buildLog(); MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING; BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
int epoch = 1; int epoch = 1;
@ -325,7 +325,7 @@ final class KRaftControlRecordStateMachineTest {
// Append the voter set control record // Append the voter set control record
long firstVoterSetOffset = log.endOffset().offset; long firstVoterSetOffset = log.endOffset().offset;
VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
log.appendAsLeader( log.appendAsLeader(
MemoryRecords.withVotersRecord( MemoryRecords.withVotersRecord(
firstVoterSetOffset, firstVoterSetOffset,

View File

@ -22,7 +22,6 @@ import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.raft.LogOffsetMetadata; import org.apache.kafka.raft.LogOffsetMetadata;
import org.apache.kafka.raft.MockQuorumStateStore; import org.apache.kafka.raft.MockQuorumStateStore;
import org.apache.kafka.raft.OffsetAndEpoch; import org.apache.kafka.raft.OffsetAndEpoch;
@ -32,13 +31,13 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mockito; import org.mockito.Mockito;
import java.util.Map;
import java.util.Collections; import java.util.Collections;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.OptionalInt; import java.util.OptionalInt;
import java.util.OptionalLong; import java.util.OptionalLong;
import java.util.Random; import java.util.Random;
import java.util.Set; import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@ -64,19 +63,11 @@ public class KafkaRaftMetricsTest {
metrics.close(); metrics.close();
} }
private QuorumState buildQuorumState(Set<Integer> voters, short kraftVersion) {
boolean withDirectoryId = kraftVersion > 0;
return buildQuorumState(
VoterSetTest.voterSet(VoterSetTest.voterMap(voters, withDirectoryId)),
kraftVersion
);
}
private QuorumState buildQuorumState(VoterSet voterSet, short kraftVersion) { private QuorumState buildQuorumState(VoterSet voterSet, short kraftVersion) {
return new QuorumState( return new QuorumState(
OptionalInt.of(localId), OptionalInt.of(localId),
localDirectoryId, localDirectoryId,
VoterSetTest.DEFAULT_LISTENER_NAME,
() -> voterSet, () -> voterSet,
() -> kraftVersion, () -> kraftVersion,
electionTimeoutMs, electionTimeoutMs,
@ -88,11 +79,26 @@ public class KafkaRaftMetricsTest {
); );
} }
private VoterSet localStandaloneVoterSet(short kraftVersion) {
boolean withDirectoryId = kraftVersion > 0;
return VoterSetTest.voterSet(
Collections.singletonMap(
localId,
VoterSetTest.voterNode(
ReplicaKey.of(
localId,
withDirectoryId ? Optional.of(localDirectoryId) : Optional.empty()
)
)
)
);
}
@ParameterizedTest @ParameterizedTest
@ValueSource(shorts = {0, 1}) @ValueSource(shorts = {0, 1})
public void shouldRecordVoterQuorumState(short kraftVersion) { public void shouldRecordVoterQuorumState(short kraftVersion) {
boolean withDirectoryId = kraftVersion > 0; boolean withDirectoryId = kraftVersion > 0;
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Utils.mkSet(1, 2), withDirectoryId); Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2), withDirectoryId);
voterMap.put( voterMap.put(
localId, localId,
VoterSetTest.voterNode( VoterSetTest.voterNode(
@ -102,7 +108,8 @@ public class KafkaRaftMetricsTest {
) )
) )
); );
QuorumState state = buildQuorumState(VoterSetTest.voterSet(voterMap), kraftVersion); VoterSet voters = VoterSetTest.voterSet(voterMap);
QuorumState state = buildQuorumState(voters, kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0)); state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -144,7 +151,7 @@ public class KafkaRaftMetricsTest {
state.leaderStateOrThrow().updateReplicaState(1, 0, new LogOffsetMetadata(5L)); state.leaderStateOrThrow().updateReplicaState(1, 0, new LogOffsetMetadata(5L));
assertEquals((double) 5L, getMetric(metrics, "high-watermark").metricValue()); assertEquals((double) 5L, getMetric(metrics, "high-watermark").metricValue());
state.transitionToFollower(2, 1); state.transitionToFollower(2, voters.voterNode(1, VoterSetTest.DEFAULT_LISTENER_NAME).get());
assertEquals("follower", getMetric(metrics, "current-state").metricValue()); assertEquals("follower", getMetric(metrics, "current-state").metricValue());
assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue()); assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue());
assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue()); assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue());
@ -184,7 +191,11 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(shorts = {0, 1}) @ValueSource(shorts = {0, 1})
public void shouldRecordNonVoterQuorumState(short kraftVersion) { public void shouldRecordNonVoterQuorumState(short kraftVersion) {
QuorumState state = buildQuorumState(Utils.mkSet(1, 2, 3), kraftVersion); boolean withDirectoryId = kraftVersion > 0;
VoterSet voters = VoterSetTest.voterSet(
VoterSetTest.voterMap(IntStream.of(1, 2, 3), withDirectoryId)
);
QuorumState state = buildQuorumState(voters, kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0)); state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -198,7 +209,7 @@ public class KafkaRaftMetricsTest {
assertEquals((double) 0, getMetric(metrics, "current-epoch").metricValue()); assertEquals((double) 0, getMetric(metrics, "current-epoch").metricValue());
assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue()); assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue());
state.transitionToFollower(2, 1); state.transitionToFollower(2, voters.voterNode(1, VoterSetTest.DEFAULT_LISTENER_NAME).get());
assertEquals("observer", getMetric(metrics, "current-state").metricValue()); assertEquals("observer", getMetric(metrics, "current-state").metricValue());
assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue()); assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue());
assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue()); assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue());
@ -227,7 +238,7 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(shorts = {0, 1}) @ValueSource(shorts = {0, 1})
public void shouldRecordLogEnd(short kraftVersion) { public void shouldRecordLogEnd(short kraftVersion) {
QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion); QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0)); state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -243,7 +254,7 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(shorts = {0, 1}) @ValueSource(shorts = {0, 1})
public void shouldRecordNumUnknownVoterConnections(short kraftVersion) { public void shouldRecordNumUnknownVoterConnections(short kraftVersion) {
QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion); QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0)); state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -257,7 +268,7 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(shorts = {0, 1}) @ValueSource(shorts = {0, 1})
public void shouldRecordPollIdleRatio(short kraftVersion) { public void shouldRecordPollIdleRatio(short kraftVersion) {
QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion); QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0)); state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -330,7 +341,7 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(shorts = {0, 1}) @ValueSource(shorts = {0, 1})
public void shouldRecordLatency(short kraftVersion) { public void shouldRecordLatency(short kraftVersion) {
QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion); QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0)); state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -362,7 +373,7 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(shorts = {0, 1}) @ValueSource(shorts = {0, 1})
public void shouldRecordRate(short kraftVersion) { public void shouldRecordRate(short kraftVersion) {
QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion); QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0)); state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);

View File

@ -21,7 +21,6 @@ import java.nio.ByteBuffer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap; import java.util.IdentityHashMap;
import java.util.List; import java.util.List;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
@ -31,6 +30,7 @@ import java.util.Random;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream; import java.util.stream.Stream;
import net.jqwik.api.ForAll; import net.jqwik.api.ForAll;
import net.jqwik.api.Property; import net.jqwik.api.Property;
@ -204,7 +204,7 @@ public final class RecordsIteratorTest {
public void testControlRecordIterationWithKraftVersion1() { public void testControlRecordIterationWithKraftVersion1() {
AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null); AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null);
VoterSet voterSet = new VoterSet( VoterSet voterSet = new VoterSet(
new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)) VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)
); );
RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder() RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
.setTime(new MockTime()) .setTime(new MockTime())

View File

@ -16,10 +16,10 @@
*/ */
package org.apache.kafka.raft.internals; package org.apache.kafka.raft.internals;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.stream.IntStream;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@ -27,7 +27,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
final public class VoterSetHistoryTest { final public class VoterSetHistoryTest {
@Test @Test
void testStaticVoterSet() { void testStaticVoterSet() {
VoterSet staticVoterSet = new VoterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); VoterSet staticVoterSet = new VoterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
assertEquals(Optional.empty(), votersHistory.valueAtOrBefore(0)); assertEquals(Optional.empty(), votersHistory.valueAtOrBefore(0));
@ -58,13 +58,13 @@ final public class VoterSetHistoryTest {
@Test @Test
void testAddAt() { void testAddAt() {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap)); VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
assertThrows( assertThrows(
IllegalArgumentException.class, IllegalArgumentException.class,
() -> votersHistory.addAt(-1, new VoterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))) () -> votersHistory.addAt(-1, new VoterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)))
); );
assertEquals(staticVoterSet, votersHistory.lastValue()); assertEquals(staticVoterSet, votersHistory.lastValue());
@ -90,7 +90,7 @@ final public class VoterSetHistoryTest {
void testAddAtNonOverlapping() { void testAddAtNonOverlapping() {
VoterSetHistory votersHistory = new VoterSetHistory(Optional.empty()); VoterSetHistory votersHistory = new VoterSetHistory(Optional.empty());
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet voterSet = new VoterSet(new HashMap<>(voterMap)); VoterSet voterSet = new VoterSet(new HashMap<>(voterMap));
// Add a starting voter to the history // Add a starting voter to the history
@ -122,7 +122,7 @@ final public class VoterSetHistoryTest {
@Test @Test
void testNonoverlappingFromStaticVoterSet() { void testNonoverlappingFromStaticVoterSet() {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap)); VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.empty()); VoterSetHistory votersHistory = new VoterSetHistory(Optional.empty());
@ -137,7 +137,7 @@ final public class VoterSetHistoryTest {
@Test @Test
void testTruncateTo() { void testTruncateTo() {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap)); VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
@ -163,7 +163,7 @@ final public class VoterSetHistoryTest {
@Test @Test
void testTrimPrefixTo() { void testTrimPrefixTo() {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap)); VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
@ -196,7 +196,7 @@ final public class VoterSetHistoryTest {
@Test @Test
void testClear() { void testClear() {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap)); VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));

View File

@ -18,7 +18,6 @@ package org.apache.kafka.raft.internals;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
@ -26,8 +25,13 @@ import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.Uuid; import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.feature.SupportedVersionRange; import org.apache.kafka.common.feature.SupportedVersionRange;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.utils.Utils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
@ -41,22 +45,45 @@ final public class VoterSetTest {
} }
@Test @Test
void testVoterAddress() { void testVoterNode() {
VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true)); VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true));
assertEquals(Optional.of(new InetSocketAddress("replica-1", 1234)), voterSet.voterAddress(1, "LISTENER")); assertEquals(
assertEquals(Optional.empty(), voterSet.voterAddress(1, "MISSING")); Optional.of(new Node(1, "replica-1", 1234)),
assertEquals(Optional.empty(), voterSet.voterAddress(4, "LISTENER")); voterSet.voterNode(1, DEFAULT_LISTENER_NAME)
);
assertEquals(Optional.empty(), voterSet.voterNode(1, ListenerName.normalised("MISSING")));
assertEquals(Optional.empty(), voterSet.voterNode(4, DEFAULT_LISTENER_NAME));
}
@Test
void testVoterNodes() {
VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true));
assertEquals(
Utils.mkSet(new Node(1, "replica-1", 1234), new Node(2, "replica-2", 1234)),
voterSet.voterNodes(IntStream.of(1, 2).boxed(), DEFAULT_LISTENER_NAME)
);
assertThrows(
IllegalArgumentException.class,
() -> voterSet.voterNodes(IntStream.of(1, 2).boxed(), ListenerName.normalised("MISSING"))
);
assertThrows(
IllegalArgumentException.class,
() -> voterSet.voterNodes(IntStream.of(1, 4).boxed(), DEFAULT_LISTENER_NAME)
);
} }
@Test @Test
void testVoterIds() { void testVoterIds() {
VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true)); VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true));
assertEquals(new HashSet<>(Arrays.asList(1, 2, 3)), voterSet.voterIds()); assertEquals(new HashSet<>(Arrays.asList(1, 2, 3)), voterSet.voterIds());
} }
@Test @Test
void testAddVoter() { void testAddVoter() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), true); Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1, 2, 3), true);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertEquals(Optional.empty(), voterSet.addVoter(voterNode(1, true))); assertEquals(Optional.empty(), voterSet.addVoter(voterNode(1, true)));
@ -68,7 +95,7 @@ final public class VoterSetTest {
@Test @Test
void testRemoveVoter() { void testRemoveVoter() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), true); Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1, 2, 3), true);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertEquals(Optional.empty(), voterSet.removeVoter(ReplicaKey.of(4, Optional.empty()))); assertEquals(Optional.empty(), voterSet.removeVoter(ReplicaKey.of(4, Optional.empty())));
@ -83,7 +110,7 @@ final public class VoterSetTest {
@Test @Test
void testIsVoterWithDirectoryId() { void testIsVoterWithDirectoryId() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), true); Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1, 2, 3), true);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertTrue(voterSet.isVoter(aVoterMap.get(1).voterKey())); assertTrue(voterSet.isVoter(aVoterMap.get(1).voterKey()));
@ -100,7 +127,7 @@ final public class VoterSetTest {
@Test @Test
void testIsVoterWithoutDirectoryId() { void testIsVoterWithoutDirectoryId() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), false); Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1, 2, 3), false);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertTrue(voterSet.isVoter(ReplicaKey.of(1, Optional.empty()))); assertTrue(voterSet.isVoter(ReplicaKey.of(1, Optional.empty())));
@ -111,7 +138,7 @@ final public class VoterSetTest {
@Test @Test
void testIsOnlyVoterInStandalone() { void testIsOnlyVoterInStandalone() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1), true); Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1), true);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertTrue(voterSet.isOnlyVoter(aVoterMap.get(1).voterKey())); assertTrue(voterSet.isOnlyVoter(aVoterMap.get(1).voterKey()));
@ -125,7 +152,7 @@ final public class VoterSetTest {
@Test @Test
void testIsOnlyVoterInNotStandalone() { void testIsOnlyVoterInNotStandalone() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2), true); Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1, 2), true);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertFalse(voterSet.isOnlyVoter(aVoterMap.get(1).voterKey())); assertFalse(voterSet.isOnlyVoter(aVoterMap.get(1).voterKey()));
@ -142,14 +169,14 @@ final public class VoterSetTest {
@Test @Test
void testRecordRoundTrip() { void testRecordRoundTrip() {
VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true)); VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true));
assertEquals(voterSet, VoterSet.fromVotersRecord(voterSet.toVotersRecord((short) 0))); assertEquals(voterSet, VoterSet.fromVotersRecord(voterSet.toVotersRecord((short) 0)));
} }
@Test @Test
void testOverlappingMajority() { void testOverlappingMajority() {
Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(Arrays.asList(1, 2, 3), true); Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(IntStream.of(1, 2, 3), true);
VoterSet startingVoterSet = voterSet(startingVoterMap); VoterSet startingVoterSet = voterSet(startingVoterMap);
VoterSet biggerVoterSet = startingVoterSet VoterSet biggerVoterSet = startingVoterSet
@ -172,7 +199,7 @@ final public class VoterSetTest {
@Test @Test
void testNonoverlappingMajority() { void testNonoverlappingMajority() {
Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(Arrays.asList(1, 2, 3, 4, 5), true); Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(IntStream.of(1, 2, 3, 4, 5), true);
VoterSet startingVoterSet = voterSet(startingVoterMap); VoterSet startingVoterSet = voterSet(startingVoterMap);
// Two additions don't have an overlapping majority // Two additions don't have an overlapping majority
@ -217,20 +244,27 @@ final public class VoterSetTest {
); );
} }
public static final ListenerName DEFAULT_LISTENER_NAME = ListenerName.normalised("LISTENER");
public static Map<Integer, VoterSet.VoterNode> voterMap( public static Map<Integer, VoterSet.VoterNode> voterMap(
Collection<Integer> replicas, IntStream replicas,
boolean withDirectoryId boolean withDirectoryId
) { ) {
return replicas return replicas
.stream() .boxed()
.collect( .collect(
Collectors.toMap( Collectors.toMap(
Function.identity(), Function.identity(),
id -> VoterSetTest.voterNode(id, withDirectoryId) id -> voterNode(id, withDirectoryId)
) )
); );
} }
public static Map<Integer, VoterSet.VoterNode> voterMap(Stream<ReplicaKey> replicas) {
return replicas
.collect(Collectors.toMap(ReplicaKey::id, VoterSetTest::voterNode));
}
public static VoterSet.VoterNode voterNode(int id, boolean withDirectoryId) { public static VoterSet.VoterNode voterNode(int id, boolean withDirectoryId) {
return voterNode( return voterNode(
ReplicaKey.of( ReplicaKey.of(
@ -244,7 +278,7 @@ final public class VoterSetTest {
return new VoterSet.VoterNode( return new VoterSet.VoterNode(
replicaKey, replicaKey,
Collections.singletonMap( Collections.singletonMap(
"LISTENER", DEFAULT_LISTENER_NAME,
InetSocketAddress.createUnresolved( InetSocketAddress.createUnresolved(
String.format("replica-%d", replicaKey.id()), String.format("replica-%d", replicaKey.id()),
1234 1234
@ -257,4 +291,8 @@ final public class VoterSetTest {
public static VoterSet voterSet(Map<Integer, VoterSet.VoterNode> voters) { public static VoterSet voterSet(Map<Integer, VoterSet.VoterNode> voters) {
return new VoterSet(voters); return new VoterSet(voters);
} }
public static VoterSet voterSet(Stream<ReplicaKey> voterKeys) {
return voterSet(voterMap(voterKeys));
}
} }

View File

@ -17,12 +17,11 @@
package org.apache.kafka.snapshot; package org.apache.kafka.snapshot;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;
import org.apache.kafka.common.message.KRaftVersionRecord; import org.apache.kafka.common.message.KRaftVersionRecord;
import org.apache.kafka.common.message.SnapshotFooterRecord; import org.apache.kafka.common.message.SnapshotFooterRecord;
import org.apache.kafka.common.message.SnapshotHeaderRecord; import org.apache.kafka.common.message.SnapshotHeaderRecord;
@ -97,7 +96,7 @@ final class RecordsSnapshotWriterTest {
OffsetAndEpoch snapshotId = new OffsetAndEpoch(100, 10); OffsetAndEpoch snapshotId = new OffsetAndEpoch(100, 10);
int maxBatchSize = 1024; int maxBatchSize = 1024;
VoterSet voterSet = VoterSetTest.voterSet( VoterSet voterSet = VoterSetTest.voterSet(
new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)) new HashMap<>(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true))
); );
AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null); AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null);
RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder() RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
@ -117,7 +116,7 @@ final class RecordsSnapshotWriterTest {
OffsetAndEpoch snapshotId = new OffsetAndEpoch(100, 10); OffsetAndEpoch snapshotId = new OffsetAndEpoch(100, 10);
int maxBatchSize = 1024; int maxBatchSize = 1024;
VoterSet voterSet = VoterSetTest.voterSet( VoterSet voterSet = VoterSetTest.voterSet(
new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)) new HashMap<>(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true))
); );
AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null); AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null);
RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder() RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()