diff --git a/checkstyle/import-control.xml b/checkstyle/import-control.xml index ab6177961f5..2f90548ffa9 100644 --- a/checkstyle/import-control.xml +++ b/checkstyle/import-control.xml @@ -441,6 +441,7 @@ + diff --git a/core/src/main/scala/kafka/raft/RaftManager.scala b/core/src/main/scala/kafka/raft/RaftManager.scala index 6bf8bd893ba..0c45734593b 100644 --- a/core/src/main/scala/kafka/raft/RaftManager.scala +++ b/core/src/main/scala/kafka/raft/RaftManager.scala @@ -23,6 +23,7 @@ import java.nio.file.Paths import java.util.OptionalInt import java.util.concurrent.CompletableFuture import java.util.{Map => JMap} +import java.util.{Collection => JCollection} import kafka.log.LogManager import kafka.log.UnifiedLog import kafka.server.KafkaConfig @@ -133,7 +134,7 @@ trait RaftManager[T] { def replicatedLog: ReplicatedLog - def voterNode(id: Int, listener: String): Option[Node] + def voterNode(id: Int, listener: ListenerName): Option[Node] } class KafkaRaftManager[T]( @@ -147,6 +148,7 @@ class KafkaRaftManager[T]( metrics: Metrics, threadNamePrefixOpt: Option[String], val controllerQuorumVotersFuture: CompletableFuture[JMap[Integer, InetSocketAddress]], + bootstrapServers: JCollection[InetSocketAddress], fatalFaultHandler: FaultHandler ) extends RaftManager[T] with Logging { @@ -185,7 +187,6 @@ class KafkaRaftManager[T]( def startup(): Unit = { client.initialize( controllerQuorumVotersFuture.get(), - config.controllerListenerNames.head, new FileQuorumStateStore(new File(dataDir, FileQuorumStateStore.DEFAULT_FILE_NAME)), metrics ) @@ -228,14 +229,15 @@ class KafkaRaftManager[T]( expirationService, logContext, clusterId, + bootstrapServers, raftConfig ) client } private def buildNetworkChannel(): KafkaNetworkChannel = { - val netClient = buildNetworkClient() - new KafkaNetworkChannel(time, netClient, config.quorumRequestTimeoutMs, threadNamePrefix) + val (listenerName, netClient) = buildNetworkClient() + new KafkaNetworkChannel(time, listenerName, netClient, config.quorumRequestTimeoutMs, threadNamePrefix) } private def createDataDir(): File = { @@ -254,7 +256,7 @@ class KafkaRaftManager[T]( ) } - private def buildNetworkClient(): NetworkClient = { + private def buildNetworkClient(): (ListenerName, NetworkClient) = { val controllerListenerName = new ListenerName(config.controllerListenerNames.head) val controllerSecurityProtocol = config.effectiveListenerSecurityProtocolMap.getOrElse( controllerListenerName, @@ -292,7 +294,7 @@ class KafkaRaftManager[T]( val reconnectBackoffMsMs = 500 val discoverBrokerVersions = true - new NetworkClient( + val networkClient = new NetworkClient( selector, new ManualMetadataUpdater(), clientId, @@ -309,13 +311,15 @@ class KafkaRaftManager[T]( apiVersions, logContext ) + + (controllerListenerName, networkClient) } override def leaderAndEpoch: LeaderAndEpoch = { client.leaderAndEpoch } - override def voterNode(id: Int, listener: String): Option[Node] = { + override def voterNode(id: Int, listener: ListenerName): Option[Node] = { client.voterNode(id, listener).toScala } } diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala index 94a7b349af9..45ec15b1008 100755 --- a/core/src/main/scala/kafka/server/KafkaConfig.scala +++ b/core/src/main/scala/kafka/server/KafkaConfig.scala @@ -439,6 +439,7 @@ object KafkaConfig { /** ********* Raft Quorum Configuration *********/ .define(QuorumConfig.QUORUM_VOTERS_CONFIG, LIST, QuorumConfig.DEFAULT_QUORUM_VOTERS, new QuorumConfig.ControllerQuorumVotersValidator(), HIGH, 
QuorumConfig.QUORUM_VOTERS_DOC) + .define(QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG, LIST, QuorumConfig.DEFAULT_QUORUM_BOOTSTRAP_SERVERS, new QuorumConfig.ControllerQuorumBootstrapServersValidator(), HIGH, QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_DOC) .define(QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_ELECTION_TIMEOUT_MS, null, HIGH, QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_DOC) .define(QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_FETCH_TIMEOUT_MS, null, HIGH, QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_DOC) .define(QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_ELECTION_BACKOFF_MAX_MS, null, HIGH, QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_DOC) @@ -1055,6 +1056,7 @@ class KafkaConfig private(doLog: Boolean, val props: java.util.Map[_, _], dynami /** ********* Raft Quorum Configuration *********/ val quorumVoters = getList(QuorumConfig.QUORUM_VOTERS_CONFIG) + val quorumBootstrapServers = getList(QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG) val quorumElectionTimeoutMs = getInt(QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG) val quorumFetchTimeoutMs = getInt(QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG) val quorumElectionBackoffMs = getInt(QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG) diff --git a/core/src/main/scala/kafka/server/KafkaRaftServer.scala b/core/src/main/scala/kafka/server/KafkaRaftServer.scala index d3200149f7a..ecb757c1a89 100644 --- a/core/src/main/scala/kafka/server/KafkaRaftServer.scala +++ b/core/src/main/scala/kafka/server/KafkaRaftServer.scala @@ -71,6 +71,7 @@ class KafkaRaftServer( time, metrics, CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)), + QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers), new StandardFaultHandlerFactory(), ) diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala index 738adab0fb0..9c807c79da5 100755 --- a/core/src/main/scala/kafka/server/KafkaServer.scala +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -70,9 +70,9 @@ import java.net.{InetAddress, SocketTimeoutException} import java.nio.file.{Files, Paths} import java.time.Duration import java.util -import java.util.{Optional, OptionalInt, OptionalLong} import java.util.concurrent._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.{Optional, OptionalInt, OptionalLong} import scala.collection.{Map, Seq} import scala.compat.java8.OptionConverters.RichOptionForJava8 import scala.jdk.CollectionConverters._ @@ -439,6 +439,7 @@ class KafkaServer( metrics, threadNamePrefix, CompletableFuture.completedFuture(quorumVoters), + QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers), fatalFaultHandler = new LoggingFaultHandler("raftManager", () => shutdown()) ) quorumControllerNodeProvider = RaftControllerNodeProvider(raftManager, config) diff --git a/core/src/main/scala/kafka/server/NodeToControllerChannelManager.scala b/core/src/main/scala/kafka/server/NodeToControllerChannelManager.scala index 0017a5876af..a0e4bbbc463 100644 --- a/core/src/main/scala/kafka/server/NodeToControllerChannelManager.scala +++ b/core/src/main/scala/kafka/server/NodeToControllerChannelManager.scala @@ -112,7 +112,7 @@ class RaftControllerNodeProvider( val saslMechanism: String ) extends ControllerNodeProvider with Logging { - private def idToNode(id: Int): Option[Node] = raftManager.voterNode(id, listenerName.value()) + private def idToNode(id: 
Int): Option[Node] = raftManager.voterNode(id, listenerName) override def getControllerInfo(): ControllerInformation = ControllerInformation(raftManager.leaderAndEpoch.leaderId.asScala.flatMap(idToNode), diff --git a/core/src/main/scala/kafka/server/SharedServer.scala b/core/src/main/scala/kafka/server/SharedServer.scala index 215208f9f63..ea92dd61f5f 100644 --- a/core/src/main/scala/kafka/server/SharedServer.scala +++ b/core/src/main/scala/kafka/server/SharedServer.scala @@ -41,6 +41,7 @@ import java.util.Arrays import java.util.Optional import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.{CompletableFuture, TimeUnit} +import java.util.{Collection => JCollection} import java.util.{Map => JMap} @@ -94,6 +95,7 @@ class SharedServer( val time: Time, private val _metrics: Metrics, val controllerQuorumVotersFuture: CompletableFuture[JMap[Integer, InetSocketAddress]], + val bootstrapServers: JCollection[InetSocketAddress], val faultHandlerFactory: FaultHandlerFactory ) extends Logging { private val logContext: LogContext = new LogContext(s"[SharedServer id=${sharedServerConfig.nodeId}] ") @@ -265,6 +267,7 @@ class SharedServer( metrics, Some(s"kafka-${sharedServerConfig.nodeId}-raft"), // No dash expected at the end controllerQuorumVotersFuture, + bootstrapServers, raftManagerFaultHandler ) raftManager = _raftManager diff --git a/core/src/main/scala/kafka/tools/StorageTool.scala b/core/src/main/scala/kafka/tools/StorageTool.scala index c79548761d0..8481f8468b9 100644 --- a/core/src/main/scala/kafka/tools/StorageTool.scala +++ b/core/src/main/scala/kafka/tools/StorageTool.scala @@ -502,7 +502,7 @@ object StorageTool extends Logging { metaPropertiesEnsemble.verify(metaProperties.clusterId(), metaProperties.nodeId(), util.EnumSet.noneOf(classOf[VerificationFlag])) - System.out.println(s"metaPropertiesEnsemble=$metaPropertiesEnsemble") + stream.println(s"metaPropertiesEnsemble=$metaPropertiesEnsemble") val copier = new MetaPropertiesEnsemble.Copier(metaPropertiesEnsemble) if (!(ignoreFormatted || copier.logDirProps().isEmpty)) { val firstLogDir = copier.logDirProps().keySet().iterator().next() diff --git a/core/src/main/scala/kafka/tools/TestRaftServer.scala b/core/src/main/scala/kafka/tools/TestRaftServer.scala index d357ad0bd56..0acae6c5dc3 100644 --- a/core/src/main/scala/kafka/tools/TestRaftServer.scala +++ b/core/src/main/scala/kafka/tools/TestRaftServer.scala @@ -95,6 +95,7 @@ class TestRaftServer( metrics, Some(threadNamePrefix), CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)), + QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers), new ProcessTerminatingFaultHandler.Builder().build() ) diff --git a/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java b/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java index 5365652a5fc..94d94dc7173 100644 --- a/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java +++ b/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java @@ -239,12 +239,15 @@ public class KafkaClusterTestKit implements AutoCloseable { ThreadUtils.createThreadFactory("kafka-cluster-test-kit-executor-%d", false)); for (ControllerNode node : nodes.controllerNodes().values()) { setupNodeDirectories(baseDirectory, node.metadataDirectory(), Collections.emptyList()); - SharedServer sharedServer = new SharedServer(createNodeConfig(node), - node.initialMetaPropertiesEnsemble(), - Time.SYSTEM, - new Metrics(), - connectFutureManager.future, - faultHandlerFactory); + SharedServer sharedServer = 
new SharedServer( + createNodeConfig(node), + node.initialMetaPropertiesEnsemble(), + Time.SYSTEM, + new Metrics(), + connectFutureManager.future, + Collections.emptyList(), + faultHandlerFactory + ); ControllerServer controller = null; try { controller = new ControllerServer( @@ -267,13 +270,18 @@ public class KafkaClusterTestKit implements AutoCloseable { jointServers.put(node.id(), sharedServer); } for (BrokerNode node : nodes.brokerNodes().values()) { - SharedServer sharedServer = jointServers.computeIfAbsent(node.id(), - id -> new SharedServer(createNodeConfig(node), + SharedServer sharedServer = jointServers.computeIfAbsent( + node.id(), + id -> new SharedServer( + createNodeConfig(node), node.initialMetaPropertiesEnsemble(), Time.SYSTEM, new Metrics(), connectFutureManager.future, - faultHandlerFactory)); + Collections.emptyList(), + faultHandlerFactory + ) + ); BrokerServer broker = null; try { broker = new BrokerServer(sharedServer); diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index f7fb7364a3c..b265ee9cdaa 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -21,6 +21,5 @@ log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n log4j.logger.kafka=WARN log4j.logger.org.apache.kafka=WARN - # zkclient can be verbose, during debugging it is common to adjust it separately log4j.logger.org.apache.zookeeper=WARN diff --git a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala index 7d32d1bd3e0..b4617af1503 100644 --- a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala +++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala @@ -21,12 +21,12 @@ import kafka.utils.{TestInfoUtils, TestUtils} import org.apache.kafka.clients.admin.{NewPartitions, NewTopic} import org.apache.kafka.clients.consumer._ import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord} -import org.apache.kafka.common.{KafkaException, MetricName, TopicPartition} import org.apache.kafka.common.config.TopicConfig import org.apache.kafka.common.errors.{InvalidGroupIdException, InvalidTopicException, TimeoutException, WakeupException} import org.apache.kafka.common.header.Headers import org.apache.kafka.common.record.{CompressionType, TimestampType} import org.apache.kafka.common.serialization._ +import org.apache.kafka.common.{KafkaException, MetricName, TopicPartition} import org.apache.kafka.test.{MockConsumerInterceptor, MockProducerInterceptor} import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Timeout diff --git a/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala b/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala index 9f787a1b168..82b5b4cfd1e 100755 --- a/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala +++ b/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala @@ -124,12 +124,15 @@ class KRaftQuorumImplementation( metaPropertiesEnsemble.verify(Optional.of(clusterId), OptionalInt.of(config.nodeId), util.EnumSet.of(REQUIRE_AT_LEAST_ONE_VALID, REQUIRE_METADATA_LOG_DIR)) - val sharedServer = new SharedServer(config, + val sharedServer = new SharedServer( + config, metaPropertiesEnsemble, time, new Metrics(), controllerQuorumVotersFuture, - faultHandlerFactory) + controllerQuorumVotersFuture.get().values(), + faultHandlerFactory + ) var broker: BrokerServer = null try { 
broker = new BrokerServer(sharedServer) @@ -371,12 +374,15 @@ abstract class QuorumTestHarness extends Logging { metaPropertiesEnsemble.verify(Optional.of(metaProperties.clusterId().get()), OptionalInt.of(nodeId), util.EnumSet.of(REQUIRE_AT_LEAST_ONE_VALID, REQUIRE_METADATA_LOG_DIR)) - val sharedServer = new SharedServer(config, + val sharedServer = new SharedServer( + config, metaPropertiesEnsemble, Time.SYSTEM, new Metrics(), controllerQuorumVotersFuture, - faultHandlerFactory) + Collections.emptyList(), + faultHandlerFactory + ) var controllerServer: ControllerServer = null try { controllerServer = new ControllerServer( diff --git a/core/src/test/scala/unit/kafka/KafkaConfigTest.scala b/core/src/test/scala/unit/kafka/KafkaConfigTest.scala index 457326cd19a..3a1fc2e4bda 100644 --- a/core/src/test/scala/unit/kafka/KafkaConfigTest.scala +++ b/core/src/test/scala/unit/kafka/KafkaConfigTest.scala @@ -86,7 +86,7 @@ class KafkaConfigTest { @Test def testBrokerRoleNodeIdValidation(): Unit = { - // Ensure that validation is happening at startup to check that brokers do not use their node.id as a voter in controller.quorum.voters + // Ensure that validation is happening at startup to check that brokers do not use their node.id as a voter in controller.quorum.voters val propertiesFile = new Properties propertiesFile.setProperty(KRaftConfigs.PROCESS_ROLES_CONFIG, "broker") propertiesFile.setProperty(KRaftConfigs.NODE_ID_CONFIG, "1") @@ -102,7 +102,7 @@ class KafkaConfigTest { @Test def testControllerRoleNodeIdValidation(): Unit = { - // Ensure that validation is happening at startup to check that controllers use their node.id as a voter in controller.quorum.voters + // Ensure that validation is happening at startup to check that controllers use their node.id as a voter in controller.quorum.voters val propertiesFile = new Properties propertiesFile.setProperty(KRaftConfigs.PROCESS_ROLES_CONFIG, "controller") propertiesFile.setProperty(KRaftConfigs.NODE_ID_CONFIG, "1") diff --git a/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala b/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala index 3416ffe65b6..da9d29304e5 100644 --- a/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala +++ b/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala @@ -118,6 +118,7 @@ class RaftManagerTest { new Metrics(Time.SYSTEM), Option.empty, CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)), + QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers), mock(classOf[FaultHandler]) ) } diff --git a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala index 933f514df59..266b64560fb 100755 --- a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala @@ -19,7 +19,7 @@ package kafka.server import java.net.InetSocketAddress import java.util -import java.util.{Collections, Properties} +import java.util.{Arrays, Collections, Properties} import kafka.cluster.EndPoint import kafka.security.authorizer.AclAuthorizer import kafka.utils.TestUtils.assertBadConfigContainingMessage @@ -1032,6 +1032,7 @@ class KafkaConfigTest { // Raft Quorum Configs case QuorumConfig.QUORUM_VOTERS_CONFIG => // ignore string + case QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG => // ignore string case QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") case 
QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") case QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") @@ -1402,6 +1403,23 @@ class KafkaConfigTest { assertEquals(expectedVoters, addresses) } + @Test + def testParseQuorumBootstrapServers(): Unit = { + val expected = Arrays.asList( + InetSocketAddress.createUnresolved("kafka1", 9092), + InetSocketAddress.createUnresolved("kafka2", 9092) + ) + + val props = TestUtils.createBrokerConfig(0, null) + props.setProperty(QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG, "kafka1:9092,kafka2:9092") + + val addresses = QuorumConfig.parseBootstrapServers( + KafkaConfig.fromProps(props).quorumBootstrapServers + ) + + assertEquals(expected, addresses) + } + @Test def testAcceptsLargeNodeIdForRaftBasedCase(): Unit = { // Generation of Broker IDs is not supported when using Raft-based controller quorums, diff --git a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala index b8764f5fae3..c625dc6e968 100644 --- a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala +++ b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala @@ -22,8 +22,8 @@ import java.nio.ByteBuffer import java.util import java.util.Collections import java.util.Optional -import java.util.Arrays import java.util.Properties +import java.util.stream.IntStream import kafka.log.{LogTestUtils, UnifiedLog} import kafka.raft.{KafkaMetadataLog, MetadataLogConfig} import kafka.server.{BrokerTopicStats, KafkaRaftServer} @@ -338,7 +338,7 @@ class DumpLogSegmentsTest { .setLastContainedLogTimestamp(lastContainedLogTimestamp) .setRawSnapshotWriter(metadataLog.createNewSnapshot(new OffsetAndEpoch(0, 0)).get) .setKraftVersion(1) - .setVoterSet(Optional.of(VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)))) + .setVoterSet(Optional.of(VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)))) .build(MetadataRecordSerde.INSTANCE) ) { snapshotWriter => snapshotWriter.append(metadataRecords.asJava) diff --git a/raft/src/main/java/org/apache/kafka/raft/ElectionState.java b/raft/src/main/java/org/apache/kafka/raft/ElectionState.java index 005ff23a4f9..825acf7df69 100644 --- a/raft/src/main/java/org/apache/kafka/raft/ElectionState.java +++ b/raft/src/main/java/org/apache/kafka/raft/ElectionState.java @@ -30,9 +30,9 @@ import org.apache.kafka.raft.internals.ReplicaKey; * Encapsulate election state stored on disk after every state change. 
*/ final public class ElectionState { - private static int unknownLeaderId = -1; - private static int notVoted = -1; - private static Uuid noVotedDirectoryId = Uuid.ZERO_UUID; + private static final int UNKNOWN_LEADER_ID = -1; + private static final int NOT_VOTED = -1; + private static final Uuid NO_VOTED_DIRECTORY_ID = Uuid.ZERO_UUID; private final int epoch; private final OptionalInt leaderId; @@ -95,7 +95,7 @@ final public class ElectionState { } public int leaderIdOrSentinel() { - return leaderId.orElse(unknownLeaderId); + return leaderId.orElse(UNKNOWN_LEADER_ID); } public OptionalInt optionalLeaderId() { @@ -126,7 +126,7 @@ final public class ElectionState { QuorumStateData data = new QuorumStateData() .setLeaderEpoch(epoch) .setLeaderId(leaderIdOrSentinel()) - .setVotedId(votedKey.map(ReplicaKey::id).orElse(notVoted)); + .setVotedId(votedKey.map(ReplicaKey::id).orElse(NOT_VOTED)); if (version == 0) { List dataVoters = voters @@ -135,7 +135,7 @@ final public class ElectionState { .collect(Collectors.toList()); data.setCurrentVoters(dataVoters); } else if (version == 1) { - data.setVotedDirectoryId(votedKey.flatMap(ReplicaKey::directoryId).orElse(noVotedDirectoryId)); + data.setVotedDirectoryId(votedKey.flatMap(ReplicaKey::directoryId).orElse(NO_VOTED_DIRECTORY_ID)); } else { throw new IllegalStateException( String.format( @@ -198,17 +198,17 @@ final public class ElectionState { } public static ElectionState fromQuorumStateData(QuorumStateData data) { - Optional votedDirectoryId = data.votedDirectoryId().equals(noVotedDirectoryId) ? + Optional votedDirectoryId = data.votedDirectoryId().equals(NO_VOTED_DIRECTORY_ID) ? Optional.empty() : Optional.of(data.votedDirectoryId()); - Optional votedKey = data.votedId() == notVoted ? + Optional votedKey = data.votedId() == NOT_VOTED ? Optional.empty() : Optional.of(ReplicaKey.of(data.votedId(), votedDirectoryId)); return new ElectionState( data.leaderEpoch(), - data.leaderId() == unknownLeaderId ? OptionalInt.empty() : OptionalInt.of(data.leaderId()), + data.leaderId() == UNKNOWN_LEADER_ID ? OptionalInt.empty() : OptionalInt.of(data.leaderId()), votedKey, data.currentVoters().stream().map(QuorumStateData.Voter::voterId).collect(Collectors.toSet()) ); diff --git a/raft/src/main/java/org/apache/kafka/raft/FollowerState.java b/raft/src/main/java/org/apache/kafka/raft/FollowerState.java index 49bfaff181e..0491689505e 100644 --- a/raft/src/main/java/org/apache/kafka/raft/FollowerState.java +++ b/raft/src/main/java/org/apache/kafka/raft/FollowerState.java @@ -19,6 +19,7 @@ package org.apache.kafka.raft; import java.util.Optional; import java.util.OptionalLong; import java.util.Set; +import org.apache.kafka.common.Node; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Timer; @@ -29,7 +30,7 @@ import org.slf4j.Logger; public class FollowerState implements EpochState { private final int fetchTimeoutMs; private final int epoch; - private final int leaderId; + private final Node leader; private final Set voters; // Used for tracking the expiration of both the Fetch and FetchSnapshot requests private final Timer fetchTimer; @@ -37,14 +38,14 @@ public class FollowerState implements EpochState { /* Used to track the currently fetching snapshot. 
When fetching snapshot regular * Fetch request are paused */ - private Optional fetchingSnapshot; + private Optional fetchingSnapshot = Optional.empty(); private final Logger log; public FollowerState( Time time, int epoch, - int leaderId, + Node leader, Set voters, Optional highWatermark, int fetchTimeoutMs, @@ -52,17 +53,16 @@ public class FollowerState implements EpochState { ) { this.fetchTimeoutMs = fetchTimeoutMs; this.epoch = epoch; - this.leaderId = leaderId; + this.leader = leader; this.voters = voters; this.fetchTimer = time.timer(fetchTimeoutMs); this.highWatermark = highWatermark; - this.fetchingSnapshot = Optional.empty(); this.log = logContext.logger(FollowerState.class); } @Override public ElectionState election() { - return ElectionState.withElectedLeader(epoch, leaderId, voters); + return ElectionState.withElectedLeader(epoch, leader.id(), voters); } @Override @@ -80,8 +80,8 @@ public class FollowerState implements EpochState { return fetchTimer.remainingMs(); } - public int leaderId() { - return leaderId; + public Node leader() { + return leader; } public boolean hasFetchTimeoutExpired(long currentTimeMs) { @@ -156,7 +156,7 @@ public class FollowerState implements EpochState { log.debug( "Rejecting vote request from candidate ({}) since we already have a leader {} in epoch {}", candidateKey, - leaderId(), + leader, epoch ); return false; @@ -164,14 +164,16 @@ public class FollowerState implements EpochState { @Override public String toString() { - return "FollowerState(" + - "fetchTimeoutMs=" + fetchTimeoutMs + - ", epoch=" + epoch + - ", leaderId=" + leaderId + - ", voters=" + voters + - ", highWatermark=" + highWatermark + - ", fetchingSnapshot=" + fetchingSnapshot + - ')'; + return String.format( + "FollowerState(fetchTimeoutMs=%d, epoch=%d, leader=%s voters=%s, highWatermark=%s, " + + "fetchingSnapshot=%s)", + fetchTimeoutMs, + epoch, + leader, + voters, + highWatermark, + fetchingSnapshot + ); } @Override diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java b/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java index 5ec91752cbe..f6341b76e7e 100644 --- a/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java +++ b/raft/src/main/java/org/apache/kafka/raft/KafkaNetworkChannel.java @@ -24,6 +24,7 @@ import org.apache.kafka.common.message.EndQuorumEpochRequestData; import org.apache.kafka.common.message.FetchRequestData; import org.apache.kafka.common.message.FetchSnapshotRequestData; import org.apache.kafka.common.message.VoteRequestData; +import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.Errors; @@ -39,12 +40,9 @@ import org.apache.kafka.server.util.RequestAndCompletionHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; @@ -83,9 +81,17 @@ public class KafkaNetworkChannel implements NetworkChannel { private final SendThread requestThread; private final AtomicInteger correlationIdCounter = new AtomicInteger(0); - private final Map endpoints = new HashMap<>(); - public KafkaNetworkChannel(Time time, KafkaClient client, int requestTimeoutMs, String threadNamePrefix) { + private final 
ListenerName listenerName;
+
+    public KafkaNetworkChannel(
+        Time time,
+        ListenerName listenerName,
+        KafkaClient client,
+        int requestTimeoutMs,
+        String threadNamePrefix
+    ) {
+        this.listenerName = listenerName;
         this.requestThread = new SendThread(
             threadNamePrefix + "-outbound-request-thread",
             client,
@@ -102,23 +108,23 @@ public class KafkaNetworkChannel implements NetworkChannel {
     @Override
     public void send(RaftRequest.Outbound request) {
-        Node node = endpoints.get(request.destinationId());
+        Node node = request.destination();
         if (node != null) {
             requestThread.sendRequest(new RequestAndCompletionHandler(
-                request.createdTimeMs,
+                request.createdTimeMs(),
                 node,
-                buildRequest(request.data),
+                buildRequest(request.data()),
                 response -> sendOnComplete(request, response)
             ));
         } else
-            sendCompleteFuture(request, errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE));
+            sendCompleteFuture(request, errorResponse(request.data(), Errors.BROKER_NOT_AVAILABLE));
     }

     private void sendCompleteFuture(RaftRequest.Outbound request, ApiMessage message) {
         RaftResponse.Inbound response = new RaftResponse.Inbound(
-            request.correlationId,
+            request.correlationId(),
             message,
-            request.destinationId()
+            request.destination()
         );
         request.completion.complete(response);
     }
@@ -127,16 +133,16 @@ public class KafkaNetworkChannel implements NetworkChannel {
         ApiMessage response;
         if (clientResponse.versionMismatch() != null) {
             log.error("Request {} failed due to unsupported version error", request, clientResponse.versionMismatch());
-            response = errorResponse(request.data, Errors.UNSUPPORTED_VERSION);
+            response = errorResponse(request.data(), Errors.UNSUPPORTED_VERSION);
         } else if (clientResponse.authenticationException() != null) {
             // For now we treat authentication errors as retriable. We use the
             // `NETWORK_EXCEPTION` error code for lack of a good alternative.
             // Note that `NodeToControllerChannelManager` will still log the
             // authentication errors so that users have a chance to fix the problem.
log.error("Request {} failed due to authentication error", request, clientResponse.authenticationException()); - response = errorResponse(request.data, Errors.NETWORK_EXCEPTION); + response = errorResponse(request.data(), Errors.NETWORK_EXCEPTION); } else if (clientResponse.wasDisconnected()) { - response = errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE); + response = errorResponse(request.data(), Errors.BROKER_NOT_AVAILABLE); } else { response = clientResponse.responseBody().data(); } @@ -149,9 +155,8 @@ public class KafkaNetworkChannel implements NetworkChannel { } @Override - public void updateEndpoint(int id, InetSocketAddress address) { - Node node = new Node(id, address.getHostString(), address.getPort()); - endpoints.put(id, node); + public ListenerName listenerName() { + return listenerName; } public void start() { diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java index 10910c3db79..6a14b37cfff 100644 --- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java +++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java @@ -37,6 +37,7 @@ import org.apache.kafka.common.message.FetchSnapshotResponseData; import org.apache.kafka.common.message.VoteRequestData; import org.apache.kafka.common.message.VoteResponseData; import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.Errors; @@ -60,7 +61,6 @@ import org.apache.kafka.common.utils.BufferSupplier; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Timer; -import org.apache.kafka.raft.RequestManager.ConnectionState; import org.apache.kafka.raft.errors.NotLeaderException; import org.apache.kafka.raft.internals.BatchAccumulator; import org.apache.kafka.raft.internals.BatchMemoryPool; @@ -85,6 +85,7 @@ import org.apache.kafka.snapshot.SnapshotWriter; import org.slf4j.Logger; import java.net.InetSocketAddress; +import java.util.Collection; import java.util.Collections; import java.util.IdentityHashMap; import java.util.Iterator; @@ -100,8 +101,10 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import java.util.stream.Collectors; import static java.util.concurrent.CompletableFuture.completedFuture; import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition; @@ -209,6 +212,7 @@ final public class KafkaRaftClient implements RaftClient { ExpirationService expirationService, LogContext logContext, String clusterId, + Collection bootstrapServers, QuorumConfig quorumConfig ) { this( @@ -223,6 +227,7 @@ final public class KafkaRaftClient implements RaftClient { expirationService, MAX_FETCH_WAIT_MS, clusterId, + bootstrapServers, logContext, new Random(), quorumConfig @@ -241,6 +246,7 @@ final public class KafkaRaftClient implements RaftClient { ExpirationService expirationService, int fetchMaxWaitMs, String clusterId, + Collection bootstrapServers, LogContext logContext, Random random, QuorumConfig quorumConfig @@ -262,6 +268,30 @@ final public class KafkaRaftClient implements 
RaftClient { this.random = random; this.quorumConfig = quorumConfig; this.snapshotCleaner = new RaftMetadataLogCleanerManager(logger, time, 60000, log::maybeClean); + + if (!bootstrapServers.isEmpty()) { + // generate Node objects from network addresses by using decreasing negative ids + AtomicInteger id = new AtomicInteger(-2); + List bootstrapNodes = bootstrapServers + .stream() + .map(address -> + new Node( + id.getAndDecrement(), + address.getHostString(), + address.getPort() + ) + ) + .collect(Collectors.toList()); + + logger.info("Starting request manager with bootstrap servers: {}", bootstrapNodes); + + requestManager = new RequestManager( + bootstrapNodes, + quorumConfig.retryBackoffMs(), + quorumConfig.requestTimeoutMs(), + random + ); + } } private void updateFollowerHighWatermark( @@ -378,12 +408,11 @@ final public class KafkaRaftClient implements RaftClient { public void initialize( Map voterAddresses, - String listenerName, QuorumStateStore quorumStateStore, Metrics metrics ) { partitionState = new KRaftControlRecordStateMachine( - Optional.of(VoterSet.fromInetSocketAddresses(listenerName, voterAddresses)), + Optional.of(VoterSet.fromInetSocketAddresses(channel.listenerName(), voterAddresses)), log, serde, BufferSupplier.create(), @@ -394,17 +423,35 @@ final public class KafkaRaftClient implements RaftClient { logger.info("Reading KRaft snapshot and log as part of the initialization"); partitionState.updateState(); - VoterSet lastVoterSet = partitionState.lastVoterSet(); - requestManager = new RequestManager( - lastVoterSet.voterIds(), - quorumConfig.retryBackoffMs(), - quorumConfig.requestTimeoutMs(), - random - ); + if (requestManager == null) { + // The request manager wasn't created using the bootstrap servers + // create it using the voters static configuration + List bootstrapNodes = voterAddresses + .entrySet() + .stream() + .map(entry -> + new Node( + entry.getKey(), + entry.getValue().getHostString(), + entry.getValue().getPort() + ) + ) + .collect(Collectors.toList()); + + logger.info("Starting request manager with static voters: {}", bootstrapNodes); + + requestManager = new RequestManager( + bootstrapNodes, + quorumConfig.retryBackoffMs(), + quorumConfig.requestTimeoutMs(), + random + ); + } quorum = new QuorumState( nodeId, nodeDirectoryId, + channel.listenerName(), partitionState::lastVoterSet, partitionState::lastKraftVersion, quorumConfig.electionTimeoutMs(), @@ -420,10 +467,6 @@ final public class KafkaRaftClient implements RaftClient { // so there are no unknown voter connections. Report this metric as 0. 
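For readers of this hunk: when the bootstrap server list is non-empty, each configured address is turned into a placeholder `Node` with a decreasing negative id (starting at -2) so these entries can never collide with real, non-negative voter ids. A minimal standalone sketch of that mapping; the class and method names below are illustrative, only `org.apache.kafka.common.Node` and the JDK types are real:

```java
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import org.apache.kafka.common.Node;

public class BootstrapNodesSketch {
    // Mirrors the id assignment in the hunk above: start at -2 and decrement,
    // so bootstrap placeholders never overlap with non-negative voter ids.
    static List<Node> toBootstrapNodes(List<InetSocketAddress> addresses) {
        AtomicInteger id = new AtomicInteger(-2);
        return addresses.stream()
            .map(address -> new Node(id.getAndDecrement(), address.getHostString(), address.getPort()))
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<InetSocketAddress> bootstrapServers = Arrays.asList(
            InetSocketAddress.createUnresolved("controller1", 9093),
            InetSocketAddress.createUnresolved("controller2", 9093)
        );
        // Prints two nodes with ids -2 and -3
        toBootstrapNodes(bootstrapServers).forEach(System.out::println);
    }
}
```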
kafkaRaftMetrics.updateNumUnknownVoterConnections(0); - for (Integer voterId : lastVoterSet.voterIds()) { - channel.updateEndpoint(voterId, lastVoterSet.voterAddress(voterId, listenerName).get()); - } - quorum.initialize(new OffsetAndEpoch(log.endOffset().offset, log.lastFetchedEpoch())); long currentTimeMs = time.milliseconds(); @@ -569,10 +612,10 @@ final public class KafkaRaftClient implements RaftClient { private void transitionToFollower( int epoch, - int leaderId, + Node leader, long currentTimeMs ) { - quorum.transitionToFollower(epoch, leaderId); + quorum.transitionToFollower(epoch, leader); maybeFireLeaderChange(); onBecomeFollower(currentTimeMs); } @@ -601,7 +644,7 @@ final public class KafkaRaftClient implements RaftClient { private VoteResponseData handleVoteRequest( RaftRequest.Inbound requestMetadata ) { - VoteRequestData request = (VoteRequestData) requestMetadata.data; + VoteRequestData request = (VoteRequestData) requestMetadata.data(); if (!hasValidClusterId(request.clusterId())) { return new VoteResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); @@ -652,8 +695,8 @@ final public class KafkaRaftClient implements RaftClient { RaftResponse.Inbound responseMetadata, long currentTimeMs ) { - int remoteNodeId = responseMetadata.sourceId(); - VoteResponseData response = (VoteResponseData) responseMetadata.data; + int remoteNodeId = responseMetadata.source().id(); + VoteResponseData response = (VoteResponseData) responseMetadata.data(); Errors topLevelError = Errors.forCode(response.errorCode()); if (topLevelError != Errors.NONE) { return handleTopLevelError(topLevelError, responseMetadata); @@ -751,7 +794,7 @@ final public class KafkaRaftClient implements RaftClient { RaftRequest.Inbound requestMetadata, long currentTimeMs ) { - BeginQuorumEpochRequestData request = (BeginQuorumEpochRequestData) requestMetadata.data; + BeginQuorumEpochRequestData request = (BeginQuorumEpochRequestData) requestMetadata.data(); if (!hasValidClusterId(request.clusterId())) { return new BeginQuorumEpochResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); @@ -773,7 +816,11 @@ final public class KafkaRaftClient implements RaftClient { return buildBeginQuorumEpochResponse(errorOpt.get()); } - maybeTransition(OptionalInt.of(requestLeaderId), requestEpoch, currentTimeMs); + maybeTransition( + partitionState.lastVoterSet().voterNode(requestLeaderId, channel.listenerName()), + requestEpoch, + currentTimeMs + ); return buildBeginQuorumEpochResponse(Errors.NONE); } @@ -781,8 +828,8 @@ final public class KafkaRaftClient implements RaftClient { RaftResponse.Inbound responseMetadata, long currentTimeMs ) { - int remoteNodeId = responseMetadata.sourceId(); - BeginQuorumEpochResponseData response = (BeginQuorumEpochResponseData) responseMetadata.data; + int remoteNodeId = responseMetadata.source().id(); + BeginQuorumEpochResponseData response = (BeginQuorumEpochResponseData) responseMetadata.data(); Errors topLevelError = Errors.forCode(response.errorCode()); if (topLevelError != Errors.NONE) { return handleTopLevelError(topLevelError, responseMetadata); @@ -840,7 +887,7 @@ final public class KafkaRaftClient implements RaftClient { RaftRequest.Inbound requestMetadata, long currentTimeMs ) { - EndQuorumEpochRequestData request = (EndQuorumEpochRequestData) requestMetadata.data; + EndQuorumEpochRequestData request = (EndQuorumEpochRequestData) requestMetadata.data(); if (!hasValidClusterId(request.clusterId())) { return new 
EndQuorumEpochResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); @@ -861,11 +908,15 @@ final public class KafkaRaftClient implements RaftClient { if (errorOpt.isPresent()) { return buildEndQuorumEpochResponse(errorOpt.get()); } - maybeTransition(OptionalInt.of(requestLeaderId), requestEpoch, currentTimeMs); + maybeTransition( + partitionState.lastVoterSet().voterNode(requestLeaderId, channel.listenerName()), + requestEpoch, + currentTimeMs + ); if (quorum.isFollower()) { FollowerState state = quorum.followerStateOrThrow(); - if (state.leaderId() == requestLeaderId) { + if (state.leader().id() == requestLeaderId) { List preferredSuccessors = partitionRequest.preferredSuccessors(); long electionBackoffMs = endEpochElectionBackoff(preferredSuccessors); logger.debug("Overriding follower fetch timeout to {} after receiving " + @@ -894,7 +945,7 @@ final public class KafkaRaftClient implements RaftClient { RaftResponse.Inbound responseMetadata, long currentTimeMs ) { - EndQuorumEpochResponseData response = (EndQuorumEpochResponseData) responseMetadata.data; + EndQuorumEpochResponseData response = (EndQuorumEpochResponseData) responseMetadata.data(); Errors topLevelError = Errors.forCode(response.errorCode()); if (topLevelError != Errors.NONE) { return handleTopLevelError(topLevelError, responseMetadata); @@ -917,7 +968,7 @@ final public class KafkaRaftClient implements RaftClient { return handled.get(); } else if (partitionError == Errors.NONE) { ResignedState resignedState = quorum.resignedStateOrThrow(); - resignedState.acknowledgeResignation(responseMetadata.sourceId()); + resignedState.acknowledgeResignation(responseMetadata.source().id()); return true; } else { return handleUnexpectedError(partitionError, responseMetadata); @@ -998,7 +1049,7 @@ final public class KafkaRaftClient implements RaftClient { RaftRequest.Inbound requestMetadata, long currentTimeMs ) { - FetchRequestData request = (FetchRequestData) requestMetadata.data; + FetchRequestData request = (FetchRequestData) requestMetadata.data(); if (!hasValidClusterId(request.clusterId())) { return completedFuture(new FetchResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code())); @@ -1147,13 +1198,13 @@ final public class KafkaRaftClient implements RaftClient { RaftResponse.Inbound responseMetadata, long currentTimeMs ) { - FetchResponseData response = (FetchResponseData) responseMetadata.data; + FetchResponseData response = (FetchResponseData) responseMetadata.data(); Errors topLevelError = Errors.forCode(response.errorCode()); if (topLevelError != Errors.NONE) { return handleTopLevelError(topLevelError, responseMetadata); } - if (!RaftUtil.hasValidTopicPartition(response, log.topicPartition(), log.topicId())) { + if (!hasValidTopicPartition(response, log.topicPartition(), log.topicId())) { return false; } // If the ID is valid, we can set the topic name. 
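The handlers above now resolve the sender's leader id into a full `Node` on the channel's listener before calling `maybeTransition`, instead of passing a bare `OptionalInt` through. A small sketch of that resolution step, using a plain `Map` as a stand-in for the `VoterSet` lookup (the map, class and method names are illustrative, not the patch's API):

```java
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;

import org.apache.kafka.common.Node;

public class LeaderResolutionSketch {
    // Stand-in for a voterNode(id, listenerName) lookup: voter id -> Node on the controller listener.
    static Optional<Node> leaderNode(OptionalInt leaderId, Map<Integer, Node> votersOnListener) {
        if (!leaderId.isPresent()) {
            return Optional.empty();
        }
        return Optional.ofNullable(votersOnListener.get(leaderId.getAsInt()));
    }

    public static void main(String[] args) {
        Map<Integer, Node> voters = Map.of(
            1, new Node(1, "controller1", 9093),
            2, new Node(2, "controller2", 9093)
        );
        // A known leader id resolves to a routable Node; an absent or unknown id yields Optional.empty().
        System.out.println(leaderNode(OptionalInt.of(2), voters));
        System.out.println(leaderNode(OptionalInt.empty(), voters));
    }
}
```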
@@ -1286,7 +1337,7 @@ final public class KafkaRaftClient implements RaftClient { RaftRequest.Inbound requestMetadata, long currentTimeMs ) { - DescribeQuorumRequestData describeQuorumRequestData = (DescribeQuorumRequestData) requestMetadata.data; + DescribeQuorumRequestData describeQuorumRequestData = (DescribeQuorumRequestData) requestMetadata.data(); if (!hasValidTopicPartition(describeQuorumRequestData, log.topicPartition())) { return DescribeQuorumRequest.getPartitionLevelErrorResponse( describeQuorumRequestData, Errors.UNKNOWN_TOPIC_OR_PARTITION); @@ -1325,7 +1376,7 @@ final public class KafkaRaftClient implements RaftClient { RaftRequest.Inbound requestMetadata, long currentTimeMs ) { - FetchSnapshotRequestData data = (FetchSnapshotRequestData) requestMetadata.data; + FetchSnapshotRequestData data = (FetchSnapshotRequestData) requestMetadata.data(); if (!hasValidClusterId(data.clusterId())) { return new FetchSnapshotResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); @@ -1429,7 +1480,7 @@ final public class KafkaRaftClient implements RaftClient { RaftResponse.Inbound responseMetadata, long currentTimeMs ) { - FetchSnapshotResponseData data = (FetchSnapshotResponseData) responseMetadata.data; + FetchSnapshotResponseData data = (FetchSnapshotResponseData) responseMetadata.data(); Errors topLevelError = Errors.forCode(data.errorCode()); if (topLevelError != Errors.NONE) { return handleTopLevelError(topLevelError, responseMetadata); @@ -1593,6 +1644,12 @@ final public class KafkaRaftClient implements RaftClient { int epoch, long currentTimeMs ) { + Optional leader = leaderId.isPresent() ? + partitionState + .lastVoterSet() + .voterNode(leaderId.getAsInt(), channel.listenerName()) : + Optional.empty(); + if (epoch < quorum.epoch() || error == Errors.UNKNOWN_LEADER_EPOCH) { // We have a larger epoch, so the response is no longer relevant return Optional.of(true); @@ -1602,10 +1659,10 @@ final public class KafkaRaftClient implements RaftClient { // The response indicates that the request had a stale epoch, but we need // to validate the epoch from the response against our current state. - maybeTransition(leaderId, epoch, currentTimeMs); + maybeTransition(leader, epoch, currentTimeMs); return Optional.of(true); } else if (epoch == quorum.epoch() - && leaderId.isPresent() + && leader.isPresent() && !quorum.hasLeader()) { // Since we are transitioning to Follower, we will only forward the @@ -1613,7 +1670,7 @@ final public class KafkaRaftClient implements RaftClient { // the request be retried immediately (if needed) after the transition. // This handling allows an observer to discover the leader and append // to the log in the same Fetch request. - transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs); + transitionToFollower(epoch, leader.get(), currentTimeMs); if (error == Errors.NONE) { return Optional.empty(); } else { @@ -1635,24 +1692,28 @@ final public class KafkaRaftClient implements RaftClient { } private void maybeTransition( - OptionalInt leaderId, + Optional leader, int epoch, long currentTimeMs ) { + OptionalInt leaderId = leader.isPresent() ? 
+ OptionalInt.of(leader.get().id()) : + OptionalInt.empty(); + if (!hasConsistentLeader(epoch, leaderId)) { - throw new IllegalStateException("Received request or response with leader " + leaderId + + throw new IllegalStateException("Received request or response with leader " + leader + " and epoch " + epoch + " which is inconsistent with current leader " + quorum.leaderId() + " and epoch " + quorum.epoch()); } else if (epoch > quorum.epoch()) { - if (leaderId.isPresent()) { - transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs); + if (leader.isPresent()) { + transitionToFollower(epoch, leader.get(), currentTimeMs); } else { transitionToUnattached(epoch); } - } else if (leaderId.isPresent() && !quorum.hasLeader()) { + } else if (leader.isPresent() && !quorum.hasLeader()) { // The request or response indicates the leader of the current epoch, // which is currently unknown - transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs); + transitionToFollower(epoch, leader.get(), currentTimeMs); } } @@ -1668,13 +1729,13 @@ final public class KafkaRaftClient implements RaftClient { private boolean handleUnexpectedError(Errors error, RaftResponse.Inbound response) { logger.error("Unexpected error {} in {} response: {}", - error, ApiKeys.forId(response.data.apiKey()), response); + error, ApiKeys.forId(response.data().apiKey()), response); return false; } private void handleResponse(RaftResponse.Inbound response, long currentTimeMs) { // The response epoch matches the local epoch, so we can handle the response - ApiKeys apiKey = ApiKeys.forId(response.data.apiKey()); + ApiKeys apiKey = ApiKeys.forId(response.data().apiKey()); final boolean handledSuccessfully; switch (apiKey) { @@ -1702,12 +1763,12 @@ final public class KafkaRaftClient implements RaftClient { throw new IllegalArgumentException("Received unexpected response type: " + apiKey); } - ConnectionState connection = requestManager.getOrCreate(response.sourceId()); - if (handledSuccessfully) { - connection.onResponseReceived(response.correlationId); - } else { - connection.onResponseError(response.correlationId, currentTimeMs); - } + requestManager.onResponseResult( + response.source(), + response.correlationId(), + handledSuccessfully, + currentTimeMs + ); } /** @@ -1749,7 +1810,7 @@ final public class KafkaRaftClient implements RaftClient { } private void handleRequest(RaftRequest.Inbound request, long currentTimeMs) { - ApiKeys apiKey = ApiKeys.forId(request.data.apiKey()); + ApiKeys apiKey = ApiKeys.forId(request.data().apiKey()); final CompletableFuture responseFuture; switch (apiKey) { @@ -1803,8 +1864,7 @@ final public class KafkaRaftClient implements RaftClient { handleRequest(request, currentTimeMs); } else if (message instanceof RaftResponse.Inbound) { RaftResponse.Inbound response = (RaftResponse.Inbound) message; - ConnectionState connection = requestManager.getOrCreate(response.sourceId()); - if (connection.isResponseExpected(response.correlationId)) { + if (requestManager.isResponseExpected(response.source(), response.correlationId())) { handleResponse(response, currentTimeMs); } else { logger.debug("Ignoring response {} since it is no longer needed", response); @@ -1819,25 +1879,23 @@ final public class KafkaRaftClient implements RaftClient { */ private long maybeSendRequest( long currentTimeMs, - int destinationId, + Node destination, Supplier requestSupplier ) { - ConnectionState connection = requestManager.getOrCreate(destinationId); - - if (connection.isBackingOff(currentTimeMs)) { - long remainingBackoffMs 
= connection.remainingBackoffMs(currentTimeMs); - logger.debug("Connection for {} is backing off for {} ms", destinationId, remainingBackoffMs); + if (requestManager.isBackingOff(destination, currentTimeMs)) { + long remainingBackoffMs = requestManager.remainingBackoffMs(destination, currentTimeMs); + logger.debug("Connection for {} is backing off for {} ms", destination, remainingBackoffMs); return remainingBackoffMs; } - if (connection.isReady(currentTimeMs)) { + if (requestManager.isReady(destination, currentTimeMs)) { int correlationId = channel.newCorrelationId(); ApiMessage request = requestSupplier.get(); RaftRequest.Outbound requestMessage = new RaftRequest.Outbound( correlationId, request, - destinationId, + destination, currentTimeMs ); @@ -1850,20 +1908,19 @@ final public class KafkaRaftClient implements RaftClient { response = new RaftResponse.Inbound( correlationId, errorResponse, - destinationId + destination ); } messageQueue.add(response); }); + requestManager.onRequestSent(destination, correlationId, currentTimeMs); channel.send(requestMessage); logger.trace("Sent outbound request: {}", requestMessage); - connection.onRequestSent(correlationId, currentTimeMs); - return Long.MAX_VALUE; } - return connection.remainingRequestTimeMs(currentTimeMs); + return requestManager.remainingRequestTimeMs(destination, currentTimeMs); } private EndQuorumEpochRequestData buildEndQuorumEpochRequest( @@ -1880,12 +1937,12 @@ final public class KafkaRaftClient implements RaftClient { private long maybeSendRequests( long currentTimeMs, - Set destinationIds, + Set destinations, Supplier requestSupplier ) { long minBackoffMs = Long.MAX_VALUE; - for (Integer destinationId : destinationIds) { - long backoffMs = maybeSendRequest(currentTimeMs, destinationId, requestSupplier); + for (Node destination : destinations) { + long backoffMs = maybeSendRequest(currentTimeMs, destination, requestSupplier); if (backoffMs < minBackoffMs) { minBackoffMs = backoffMs; } @@ -1929,15 +1986,15 @@ final public class KafkaRaftClient implements RaftClient { } private long maybeSendAnyVoterFetch(long currentTimeMs) { - OptionalInt readyVoterIdOpt = requestManager.findReadyVoter(currentTimeMs); - if (readyVoterIdOpt.isPresent()) { + Optional readyNode = requestManager.findReadyBootstrapServer(currentTimeMs); + if (readyNode.isPresent()) { return maybeSendRequest( currentTimeMs, - readyVoterIdOpt.getAsInt(), + readyNode.get(), this::buildFetchRequest ); } else { - return requestManager.backoffBeforeAvailableVoter(currentTimeMs); + return requestManager.backoffBeforeAvailableBootstrapServer(currentTimeMs); } } @@ -2038,7 +2095,9 @@ final public class KafkaRaftClient implements RaftClient { ResignedState state = quorum.resignedStateOrThrow(); long endQuorumBackoffMs = maybeSendRequests( currentTimeMs, - state.unackedVoters(), + partitionState + .lastVoterSet() + .voterNodes(state.unackedVoters().stream(), channel.listenerName()), () -> buildEndQuorumEpochRequest(state) ); @@ -2075,7 +2134,9 @@ final public class KafkaRaftClient implements RaftClient { long timeUntilSend = maybeSendRequests( currentTimeMs, - state.nonAcknowledgingVoters(), + partitionState + .lastVoterSet() + .voterNodes(state.nonAcknowledgingVoters().stream(), channel.listenerName()), this::buildBeginQuorumEpochRequest ); @@ -2090,7 +2151,9 @@ final public class KafkaRaftClient implements RaftClient { if (!state.isVoteRejected()) { return maybeSendRequests( currentTimeMs, - state.unrecordedVoters(), + partitionState + .lastVoterSet() + 
.voterNodes(state.unrecordedVoters().stream(), channel.listenerName()), this::buildVoteRequest ); } @@ -2163,14 +2226,16 @@ final public class KafkaRaftClient implements RaftClient { // If the current leader is backing off due to some failure or if the // request has timed out, then we attempt to send the Fetch to another // voter in order to discover if there has been a leader change. - ConnectionState connection = requestManager.getOrCreate(state.leaderId()); - if (connection.hasRequestTimedOut(currentTimeMs)) { + if (requestManager.hasRequestTimedOut(state.leader(), currentTimeMs)) { + // Once the request has timed out backoff the connection + requestManager.reset(state.leader()); backoffMs = maybeSendAnyVoterFetch(currentTimeMs); - connection.reset(); - } else if (connection.isBackingOff(currentTimeMs)) { + } else if (requestManager.isBackingOff(state.leader(), currentTimeMs)) { backoffMs = maybeSendAnyVoterFetch(currentTimeMs); - } else { + } else if (!requestManager.hasAnyInflightRequest(currentTimeMs)) { backoffMs = maybeSendFetchOrFetchSnapshot(state, currentTimeMs); + } else { + backoffMs = requestManager.backoffBeforeAvailableBootstrapServer(currentTimeMs); } return Math.min(backoffMs, state.remainingFetchTimeMs(currentTimeMs)); @@ -2189,7 +2254,7 @@ final public class KafkaRaftClient implements RaftClient { requestSupplier = this::buildFetchRequest; } - return maybeSendRequest(currentTimeMs, state.leaderId(), requestSupplier); + return maybeSendRequest(currentTimeMs, state.leader(), requestSupplier); } private long pollVoted(long currentTimeMs) { @@ -2549,8 +2614,8 @@ final public class KafkaRaftClient implements RaftClient { } } - public Optional voterNode(int id, String listener) { - return partitionState.lastVoterSet().voterNode(id, listener); + public Optional voterNode(int id, ListenerName listenerName) { + return partitionState.lastVoterSet().voterNode(id, listenerName); } // Visible only for test diff --git a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java index 6c715b37887..20043ac4f97 100644 --- a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java +++ b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java @@ -16,7 +16,7 @@ */ package org.apache.kafka.raft; -import java.net.InetSocketAddress; +import org.apache.kafka.common.network.ListenerName; /** * A simple network interface with few assumptions. We do not assume ordering @@ -37,10 +37,11 @@ public interface NetworkChannel extends AutoCloseable { void send(RaftRequest.Outbound request); /** - * Update connection information for the given id. + * The name of listener used when sending requests. 
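The follower and candidate polling above now drives the reworked `RequestManager`, which is keyed by `Node` (voters or bootstrap servers) rather than by voter id. A hedged sketch of the request lifecycle using only the methods and constructor argument order visible in this patch (`findReadyBootstrapServer`, `onRequestSent`, `isResponseExpected`, `onResponseResult`); the wrapper class is illustrative, and it is placed in the raft package in case `RequestManager` is not public:

```java
package org.apache.kafka.raft;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Random;

import org.apache.kafka.common.Node;

public class RequestManagerSketch {
    public static void main(String[] args) {
        List<Node> bootstrapNodes = Arrays.asList(
            new Node(-2, "controller1", 9093),
            new Node(-3, "controller2", 9093)
        );
        int retryBackoffMs = 50;
        int requestTimeoutMs = 2_000;
        RequestManager requestManager =
            new RequestManager(bootstrapNodes, retryBackoffMs, requestTimeoutMs, new Random());

        long now = System.currentTimeMillis();
        // Pick any bootstrap server that is ready for a Fetch, as pollFollowerAsObserver does above.
        Optional<Node> ready = requestManager.findReadyBootstrapServer(now);
        ready.ifPresent(destination -> {
            int correlationId = 1; // in the client this comes from channel.newCorrelationId()
            requestManager.onRequestSent(destination, correlationId, now);
            // Responses are matched by (source node, correlation id) instead of a per-id ConnectionState.
            if (requestManager.isResponseExpected(destination, correlationId)) {
                requestManager.onResponseResult(destination, correlationId, true, now + 10);
            }
        });
    }
}
```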
+ * + * @return the name of the listener */ - void updateEndpoint(int id, InetSocketAddress address); + ListenerName listenerName(); default void close() throws InterruptedException {} - } diff --git a/raft/src/main/java/org/apache/kafka/raft/QuorumConfig.java b/raft/src/main/java/org/apache/kafka/raft/QuorumConfig.java index 5c9c20b763b..d7b18ba0840 100644 --- a/raft/src/main/java/org/apache/kafka/raft/QuorumConfig.java +++ b/raft/src/main/java/org/apache/kafka/raft/QuorumConfig.java @@ -54,6 +54,13 @@ public class QuorumConfig { "For example: 1@localhost:9092,2@localhost:9093,3@localhost:9094"; public static final List DEFAULT_QUORUM_VOTERS = Collections.emptyList(); + public static final String QUORUM_BOOTSTRAP_SERVERS_CONFIG = QUORUM_PREFIX + "bootstrap.servers"; + public static final String QUORUM_BOOTSTRAP_SERVERS_DOC = "List of endpoints to use for " + + "bootstrapping the cluster metadata. The endpoints are specified in comma-separated list " + + "of {host}:{port} entries. For example: " + + "localhost:9092,localhost:9093,localhost:9094."; + public static final List DEFAULT_QUORUM_BOOTSTRAP_SERVERS = Collections.emptyList(); + public static final String QUORUM_ELECTION_TIMEOUT_MS_CONFIG = QUORUM_PREFIX + "election.timeout.ms"; public static final String QUORUM_ELECTION_TIMEOUT_MS_DOC = "Maximum time in milliseconds to wait " + "without being able to fetch from the leader before triggering a new election"; @@ -163,7 +170,7 @@ public class QuorumConfig { List voterEntries, boolean requireRoutableAddresses ) { - Map voterMap = new HashMap<>(); + Map voterMap = new HashMap<>(voterEntries.size()); for (String voterMapEntry : voterEntries) { String[] idAndAddress = voterMapEntry.split("@"); if (idAndAddress.length != 2) { @@ -173,7 +180,7 @@ public class QuorumConfig { Integer voterId = parseVoterId(idAndAddress[0]); String host = Utils.getHost(idAndAddress[1]); - if (host == null) { + if (host == null || !Utils.validHostPattern(host)) { throw new ConfigException("Failed to parse host name from entry " + voterMapEntry + " for the configuration " + QUORUM_VOTERS_CONFIG + ". Each entry should be in the form `{id}@{host}:{port}`."); @@ -199,6 +206,41 @@ public class QuorumConfig { return voterMap; } + public static List parseBootstrapServers(List bootstrapServers) { + return bootstrapServers + .stream() + .map(QuorumConfig::parseBootstrapServer) + .collect(Collectors.toList()); + } + + private static InetSocketAddress parseBootstrapServer(String bootstrapServer) { + String host = Utils.getHost(bootstrapServer); + if (host == null || !Utils.validHostPattern(host)) { + throw new ConfigException( + String.format( + "Failed to parse host name from {} for the configuration {}. Each " + + "entry should be in the form \"{host}:{port}\"", + bootstrapServer, + QUORUM_BOOTSTRAP_SERVERS_CONFIG + ) + ); + } + + Integer port = Utils.getPort(bootstrapServer); + if (port == null) { + throw new ConfigException( + String.format( + "Failed to parse host port from {} for the configuration {}. 
Each " + + "entry should be in the form \"{host}:{port}\"", + bootstrapServer, + QUORUM_BOOTSTRAP_SERVERS_CONFIG + ) + ); + } + + return InetSocketAddress.createUnresolved(host, port); + } + public static List quorumVoterStringsToNodes(List voters) { return voterConnectionsToNodes(parseVoterConnections(voters)); } @@ -231,4 +273,26 @@ public class QuorumConfig { return "non-empty list"; } } + + public static class ControllerQuorumBootstrapServersValidator implements ConfigDef.Validator { + @Override + public void ensureValid(String name, Object value) { + if (value == null) { + throw new ConfigException(name, null); + } + + @SuppressWarnings("unchecked") + List entries = (List) value; + + // Attempt to parse the connect strings + for (String entry : entries) { + parseBootstrapServer(entry); + } + } + + @Override + public String toString() { + return "non-empty list"; + } + } } diff --git a/raft/src/main/java/org/apache/kafka/raft/QuorumState.java b/raft/src/main/java/org/apache/kafka/raft/QuorumState.java index 522b7080504..b9b17c5f99b 100644 --- a/raft/src/main/java/org/apache/kafka/raft/QuorumState.java +++ b/raft/src/main/java/org/apache/kafka/raft/QuorumState.java @@ -25,7 +25,9 @@ import java.util.OptionalInt; import java.util.Random; import java.util.function.Supplier; +import org.apache.kafka.common.Node; import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.Time; import org.apache.kafka.raft.internals.BatchAccumulator; @@ -81,6 +83,7 @@ public class QuorumState { private final Time time; private final Logger log; private final QuorumStateStore store; + private final ListenerName listenerName; private final Supplier latestVoterSet; private final Supplier latestKraftVersion; private final Random random; @@ -93,6 +96,7 @@ public class QuorumState { public QuorumState( OptionalInt localId, Uuid localDirectoryId, + ListenerName listenerName, Supplier latestVoterSet, Supplier latestKraftVersion, int electionTimeoutMs, @@ -104,6 +108,7 @@ public class QuorumState { ) { this.localId = localId; this.localDirectoryId = localDirectoryId; + this.listenerName = listenerName; this.latestVoterSet = latestVoterSet; this.latestKraftVersion = latestKraftVersion; this.electionTimeoutMs = electionTimeoutMs; @@ -115,16 +120,21 @@ public class QuorumState { this.logContext = logContext; } - public void initialize(OffsetAndEpoch logEndOffsetAndEpoch) throws IllegalStateException { - // We initialize in whatever state we were in on shutdown. If we were a leader - // or candidate, probably an election was held, but we will find out about it - // when we send Vote or BeginEpoch requests. - + private ElectionState readElectionState() { ElectionState election; election = store .readElectionState() .orElseGet(() -> ElectionState.withUnknownLeader(0, latestVoterSet.get().voterIds())); + return election; + } + + public void initialize(OffsetAndEpoch logEndOffsetAndEpoch) throws IllegalStateException { + // We initialize in whatever state we were in on shutdown. If we were a leader + // or candidate, probably an election was held, but we will find out about it + // when we send Vote or BeginEpoch requests. 
+ ElectionState election = readElectionState(); + final EpochState initialState; if (election.hasVoted() && !localId.isPresent()) { throw new IllegalStateException( @@ -191,10 +201,26 @@ public class QuorumState { logContext ); } else if (election.hasLeader()) { + /* KAFKA-16529 is going to change this so that the leader is not required to be in the set + * of voters. In other words, don't throw an IllegalStateException if the leader is not in + * the set of voters. + */ + Node leader = latestVoterSet + .get() + .voterNode(election.leaderId(), listenerName) + .orElseThrow(() -> + new IllegalStateException( + String.format( + "Leader %s must be in the voter set %s", + election.leaderId(), + latestVoterSet.get() + ) + ) + ); initialState = new FollowerState( time, election.epoch(), - election.leaderId(), + leader, latestVoterSet.get().voterIds(), Optional.empty(), fetchTimeoutMs, @@ -400,28 +426,24 @@ public class QuorumState { /** * Become a follower of an elected leader so that we can begin fetching. */ - public void transitionToFollower( - int epoch, - int leaderId - ) { + public void transitionToFollower(int epoch, Node leader) { int currentEpoch = state.epoch(); - if (localId.isPresent() && leaderId == localId.getAsInt()) { - throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId + - " and epoch=" + epoch + " since it matches the local broker.id=" + localId); + if (localId.isPresent() && leader.id() == localId.getAsInt()) { + throw new IllegalStateException("Cannot transition to Follower with leader " + leader + + " and epoch " + epoch + " since it matches the local broker.id " + localId); } else if (epoch < currentEpoch) { - throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId + - " and epoch=" + epoch + " since the current epoch " + currentEpoch + " is larger"); - } else if (epoch == currentEpoch - && (isFollower() || isLeader())) { - throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId + - " and epoch=" + epoch + " from state " + state); + throw new IllegalStateException("Cannot transition to Follower with leader " + leader + + " and epoch " + epoch + " since the current epoch " + currentEpoch + " is larger"); + } else if (epoch == currentEpoch && (isFollower() || isLeader())) { + throw new IllegalStateException("Cannot transition to Follower with leader " + leader + + " and epoch " + epoch + " from state " + state); } durableTransitionTo( new FollowerState( time, epoch, - leaderId, + leader, latestVoterSet.get().voterIds(), state.highWatermark(), fetchTimeoutMs, diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java b/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java index 28e63c14ce6..bf590f56ab1 100644 --- a/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java +++ b/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java @@ -17,13 +17,14 @@ package org.apache.kafka.raft; import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.common.Node; import java.util.concurrent.CompletableFuture; public abstract class RaftRequest implements RaftMessage { - protected final int correlationId; - protected final ApiMessage data; - protected final long createdTimeMs; + private final int correlationId; + private final ApiMessage data; + private final long createdTimeMs; public RaftRequest(int correlationId, ApiMessage data, long createdTimeMs) { this.correlationId = correlationId; @@ -45,7 +46,7 @@ public abstract class RaftRequest 
implements RaftMessage { return createdTimeMs; } - public static class Inbound extends RaftRequest { + public final static class Inbound extends RaftRequest { public final CompletableFuture completion = new CompletableFuture<>(); public Inbound(int correlationId, ApiMessage data, long createdTimeMs) { @@ -54,35 +55,37 @@ public abstract class RaftRequest implements RaftMessage { @Override public String toString() { - return "InboundRequest(" + - "correlationId=" + correlationId + - ", data=" + data + - ", createdTimeMs=" + createdTimeMs + - ')'; + return String.format( + "InboundRequest(correlationId=%d, data=%s, createdTimeMs=%d)", + correlationId(), + data(), + createdTimeMs() + ); } } - public static class Outbound extends RaftRequest { - private final int destinationId; + public final static class Outbound extends RaftRequest { + private final Node destination; public final CompletableFuture completion = new CompletableFuture<>(); - public Outbound(int correlationId, ApiMessage data, int destinationId, long createdTimeMs) { + public Outbound(int correlationId, ApiMessage data, Node destination, long createdTimeMs) { super(correlationId, data, createdTimeMs); - this.destinationId = destinationId; + this.destination = destination; } - public int destinationId() { - return destinationId; + public Node destination() { + return destination; } @Override public String toString() { - return "OutboundRequest(" + - "correlationId=" + correlationId + - ", data=" + data + - ", createdTimeMs=" + createdTimeMs + - ", destinationId=" + destinationId + - ')'; + return String.format( + "OutboundRequest(correlationId=%d, data=%s, createdTimeMs=%d, destination=%s)", + correlationId(), + data(), + createdTimeMs(), + destination + ); } } } diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftResponse.java b/raft/src/main/java/org/apache/kafka/raft/RaftResponse.java index 71101a63bf2..9c5047ca92d 100644 --- a/raft/src/main/java/org/apache/kafka/raft/RaftResponse.java +++ b/raft/src/main/java/org/apache/kafka/raft/RaftResponse.java @@ -16,11 +16,12 @@ */ package org.apache.kafka.raft; +import org.apache.kafka.common.Node; import org.apache.kafka.common.protocol.ApiMessage; public abstract class RaftResponse implements RaftMessage { - protected final int correlationId; - protected final ApiMessage data; + private final int correlationId; + private final ApiMessage data; protected RaftResponse(int correlationId, ApiMessage data) { this.correlationId = correlationId; @@ -37,39 +38,41 @@ public abstract class RaftResponse implements RaftMessage { return data; } - public static class Inbound extends RaftResponse { - private final int sourceId; + public final static class Inbound extends RaftResponse { + private final Node source; - public Inbound(int correlationId, ApiMessage data, int sourceId) { + public Inbound(int correlationId, ApiMessage data, Node source) { super(correlationId, data); - this.sourceId = sourceId; + this.source = source; } - public int sourceId() { - return sourceId; + public Node source() { + return source; } @Override public String toString() { - return "InboundResponse(" + - "correlationId=" + correlationId + - ", data=" + data + - ", sourceId=" + sourceId + - ')'; + return String.format( + "InboundResponse(correlationId=%d, data=%s, source=%s)", + correlationId(), + data(), + source + ); } } - public static class Outbound extends RaftResponse { + public final static class Outbound extends RaftResponse { public Outbound(int requestId, ApiMessage data) { super(requestId, data); } 
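RaftRequest.Outbound now carries the full destination Node, and the matching RaftResponse.Inbound carries a source Node, so callers and tests can compare request.destination() against response.source() directly. A rough usage sketch, assuming the constructor signature shown in this patch and using a FetchRequestData payload as the ApiMessage:

```java
import org.apache.kafka.common.Node;
import org.apache.kafka.common.message.FetchRequestData;
import org.apache.kafka.raft.RaftRequest;

public class OutboundRequestSketch {
    public static void main(String[] args) {
        // The destination is a concrete Node (id, host, port) looked up from the
        // voter set or the bootstrap server list before the request is built.
        Node destination = new Node(2, "controller-2.example", 9093);

        RaftRequest.Outbound request = new RaftRequest.Outbound(
            1,                        // correlationId
            new FetchRequestData(),   // ApiMessage payload
            destination,
            System.currentTimeMillis()
        );

        System.out.println(request.destination() + " -> " + request);
    }
}
```

A destination without a known broker id (for example a bootstrap server) can be represented with a negative id, which is the convention the updated KafkaNetworkChannelTest below follows.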
@Override public String toString() { - return "OutboundResponse(" + - "correlationId=" + correlationId + - ", data=" + data + - ')'; + return String.format( + "OutboundResponse(correlationId=%d, data=%s)", + correlationId(), + data() + ); } } } diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftUtil.java b/raft/src/main/java/org/apache/kafka/raft/RaftUtil.java index 9ff03617e63..86a47eff1c8 100644 --- a/raft/src/main/java/org/apache/kafka/raft/RaftUtil.java +++ b/raft/src/main/java/org/apache/kafka/raft/RaftUtil.java @@ -25,6 +25,7 @@ import org.apache.kafka.common.message.EndQuorumEpochRequestData; import org.apache.kafka.common.message.EndQuorumEpochResponseData; import org.apache.kafka.common.message.FetchRequestData; import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.FetchSnapshotResponseData; import org.apache.kafka.common.message.VoteRequestData; import org.apache.kafka.common.message.VoteResponseData; import org.apache.kafka.common.protocol.ApiKeys; @@ -48,6 +49,8 @@ public class RaftUtil { return new EndQuorumEpochResponseData().setErrorCode(error.code()); case FETCH: return new FetchResponseData().setErrorCode(error.code()); + case FETCH_SNAPSHOT: + return new FetchSnapshotResponseData().setErrorCode(error.code()); default: throw new IllegalArgumentException("Received response for unexpected request type: " + apiKey); } diff --git a/raft/src/main/java/org/apache/kafka/raft/RequestManager.java b/raft/src/main/java/org/apache/kafka/raft/RequestManager.java index 5a5cb003c25..dfdaf9d1935 100644 --- a/raft/src/main/java/org/apache/kafka/raft/RequestManager.java +++ b/raft/src/main/java/org/apache/kafka/raft/RequestManager.java @@ -17,96 +17,288 @@ package org.apache.kafka.raft; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; -import java.util.List; +import java.util.Iterator; import java.util.Map; -import java.util.OptionalInt; +import java.util.Optional; import java.util.OptionalLong; import java.util.Random; -import java.util.Set; +import org.apache.kafka.common.Node; +/** + * The request manager keeps tracks of the connection with remote replicas. + * + * When sending a request update this type by calling {@code onRequestSent(Node, long, long)}. When + * the RPC returns a response, update this manager with {@code onResponseResult(Node, long, boolean, long)}. + * + * Connections start in the ready state ({@code isReady(Node, long)} returns true). + * + * When a request times out or completes successfully the collection will transition back to the + * ready state. + * + * When a request completes with an error it still transition to the backoff state until + * {@code retryBackoffMs}. 
+ */ public class RequestManager { - private final Map connections = new HashMap<>(); - private final List voters = new ArrayList<>(); + private final Map connections = new HashMap<>(); + private final ArrayList bootstrapServers; private final int retryBackoffMs; private final int requestTimeoutMs; private final Random random; - public RequestManager(Set voterIds, - int retryBackoffMs, - int requestTimeoutMs, - Random random) { - + public RequestManager( + Collection bootstrapServers, + int retryBackoffMs, + int requestTimeoutMs, + Random random + ) { + this.bootstrapServers = new ArrayList<>(bootstrapServers); this.retryBackoffMs = retryBackoffMs; this.requestTimeoutMs = requestTimeoutMs; - this.voters.addAll(voterIds); this.random = random; - - for (Integer voterId: voterIds) { - ConnectionState connection = new ConnectionState(voterId); - connections.put(voterId, connection); - } } - public ConnectionState getOrCreate(int id) { - return connections.computeIfAbsent(id, key -> new ConnectionState(id)); - } + /** + * Returns true if there any connection with pending requests. + * + * This is useful for satisfying the invariant that there is only one pending Fetch request. + * If there are more than one pending fetch request, it is possible for the follower to write + * the same offset twice. + * + * @param currentTimeMs the current time + * @return true if the request manager is tracking at least one request + */ + public boolean hasAnyInflightRequest(long currentTimeMs) { + boolean result = false; - public OptionalInt findReadyVoter(long currentTimeMs) { - int startIndex = random.nextInt(voters.size()); - OptionalInt res = OptionalInt.empty(); - for (int i = 0; i < voters.size(); i++) { - int index = (startIndex + i) % voters.size(); - Integer voterId = voters.get(index); - ConnectionState connection = connections.get(voterId); - boolean isReady = connection.isReady(currentTimeMs); - - if (isReady) { - res = OptionalInt.of(voterId); - } else if (connection.inFlightCorrelationId.isPresent()) { - res = OptionalInt.empty(); + Iterator iterator = connections.values().iterator(); + while (iterator.hasNext()) { + ConnectionState connection = iterator.next(); + if (connection.hasRequestTimedOut(currentTimeMs)) { + // Mark the node as ready after request timeout + iterator.remove(); + } else if (connection.isBackoffComplete(currentTimeMs)) { + // Mark the node as ready after completed backoff + iterator.remove(); + } else if (connection.hasInflightRequest(currentTimeMs)) { + // If there is at least one inflight request, it is enough + // to stop checking the rest of the connections + result = true; break; } } - return res; + + return result; } - public long backoffBeforeAvailableVoter(long currentTimeMs) { - long minBackoffMs = Long.MAX_VALUE; - for (Integer voterId : voters) { - ConnectionState connection = connections.get(voterId); - if (connection.isReady(currentTimeMs)) { - return 0L; - } else if (connection.isBackingOff(currentTimeMs)) { - minBackoffMs = Math.min(minBackoffMs, connection.remainingBackoffMs(currentTimeMs)); - } else { - minBackoffMs = Math.min(minBackoffMs, connection.remainingRequestTimeMs(currentTimeMs)); + /** + * Returns a random bootstrap node that is ready to receive a request. + * + * This method doesn't return a node if there is at least one request pending. In general this + * method is used to send Fetch requests. Fetch requests have the invariant that there can + * only be one pending Fetch request for the LEO. 
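The rewritten RequestManager is driven entirely through onRequestSent and onResponseResult, with per-node state created lazily and discarded once a connection becomes ready again. Below is a usage sketch under the signatures added in this patch; the surrounding wiring (clock, correlation ids) is simplified, and the constructor is assumed to take a Collection<Node> of bootstrap servers, as the stripped generics above suggest.

```java
import java.util.Arrays;
import java.util.Optional;
import java.util.Random;

import org.apache.kafka.common.Node;
import org.apache.kafka.raft.RequestManager;

public class RequestManagerSketch {
    public static void main(String[] args) {
        Node bootstrap = new Node(-2, "localhost", 9093);
        RequestManager manager = new RequestManager(
            Arrays.asList(bootstrap),
            50,      // retryBackoffMs
            2_000,   // requestTimeoutMs
            new Random()
        );

        long now = 0L;
        Optional<Node> ready = manager.findReadyBootstrapServer(now);
        ready.ifPresent(node -> {
            long correlationId = 1L;
            manager.onRequestSent(node, correlationId, now);              // READY -> AWAITING_RESPONSE
            // ... later, a failed response starts the backoff for this node
            manager.onResponseResult(node, correlationId, false, now + 10);
            System.out.println("backing off: " + manager.isBackingOff(node, now + 10));
        });
    }
}
```

Because isReady and the internal iteration drop entries once a timeout or backoff has elapsed, the connections map only ever holds nodes that are currently awaiting a response or backing off.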
+ * + * @param currentTimeMs the current time + * @return a random ready bootstrap node + */ + public Optional findReadyBootstrapServer(long currentTimeMs) { + // Check that there are no infilght requests accross any of the known nodes not just + // the bootstrap servers + if (hasAnyInflightRequest(currentTimeMs)) { + return Optional.empty(); + } + + int startIndex = random.nextInt(bootstrapServers.size()); + Optional result = Optional.empty(); + for (int i = 0; i < bootstrapServers.size(); i++) { + int index = (startIndex + i) % bootstrapServers.size(); + Node node = bootstrapServers.get(index); + + if (isReady(node, currentTimeMs)) { + result = Optional.of(node); + break; } } + + return result; + } + + /** + * Computes the amount of time needed to wait before a bootstrap server is ready for a Fetch + * request. + * + * If there is a connection with a pending request it returns the amount of time to wait until + * the request times out. + * + * Returns zero, if there are no pending request and at least one of the boorstrap servers is + * ready. + * + * If all of the bootstrap servers are backing off and there are no pending requests, return + * the minimum amount of time until a bootstrap server becomes ready. + * + * @param currentTimeMs the current time + * @return the amount of time to wait until bootstrap server can accept a Fetch request + */ + public long backoffBeforeAvailableBootstrapServer(long currentTimeMs) { + long minBackoffMs = retryBackoffMs; + + Iterator iterator = connections.values().iterator(); + while (iterator.hasNext()) { + ConnectionState connection = iterator.next(); + if (connection.hasRequestTimedOut(currentTimeMs)) { + // Mark the node as ready after request timeout + iterator.remove(); + } else if (connection.isBackoffComplete(currentTimeMs)) { + // Mark the node as ready after completed backoff + iterator.remove(); + } else if (connection.hasInflightRequest(currentTimeMs)) { + // There can be at most one inflight fetch request + return connection.remainingRequestTimeMs(currentTimeMs); + } else if (connection.isBackingOff(currentTimeMs)) { + minBackoffMs = Math.min(minBackoffMs, connection.remainingBackoffMs(currentTimeMs)); + } + } + + // There are no inflight fetch requests so check if there is a ready bootstrap server + for (Node node : bootstrapServers) { + if (isReady(node, currentTimeMs)) { + return 0L; + } + } + + // There are no ready bootstrap servers and inflight fetch requests, return the backoff return minBackoffMs; } + public boolean hasRequestTimedOut(Node node, long timeMs) { + ConnectionState state = connections.get(node.idString()); + if (state == null) { + return false; + } + + return state.hasRequestTimedOut(timeMs); + } + + public boolean isReady(Node node, long timeMs) { + ConnectionState state = connections.get(node.idString()); + if (state == null) { + return true; + } + + boolean ready = state.isReady(timeMs); + if (ready) { + reset(node); + } + + return ready; + } + + public boolean isBackingOff(Node node, long timeMs) { + ConnectionState state = connections.get(node.idString()); + if (state == null) { + return false; + } + + return state.isBackingOff(timeMs); + } + + public long remainingRequestTimeMs(Node node, long timeMs) { + ConnectionState state = connections.get(node.idString()); + if (state == null) { + return 0; + } + + return state.remainingRequestTimeMs(timeMs); + } + + public long remainingBackoffMs(Node node, long timeMs) { + ConnectionState state = connections.get(node.idString()); + if (state == null) { + return 0; + } + 
+ return state.remainingBackoffMs(timeMs); + } + + public boolean isResponseExpected(Node node, long correlationId) { + ConnectionState state = connections.get(node.idString()); + if (state == null) { + return false; + } + + return state.isResponseExpected(correlationId); + } + + /** + * Updates the manager when a response is received. + * + * @param node the source of the response + * @param correlationId the correlation id of the response + * @param success true if the request was successful, false otherwise + * @param timeMs the current time + */ + public void onResponseResult(Node node, long correlationId, boolean success, long timeMs) { + if (isResponseExpected(node, correlationId)) { + if (success) { + // Mark the connection as ready by reseting it + reset(node); + } else { + // Backoff the connection + connections.get(node.idString()).onResponseError(correlationId, timeMs); + } + } + } + + /** + * Updates the manager when a request is sent. + * + * @param node the destination of the request + * @param correlationId the correlation id of the request + * @param timeMs the current time + */ + public void onRequestSent(Node node, long correlationId, long timeMs) { + ConnectionState state = connections.computeIfAbsent( + node.idString(), + key -> new ConnectionState(node, retryBackoffMs, requestTimeoutMs) + ); + + state.onRequestSent(correlationId, timeMs); + } + + public void reset(Node node) { + connections.remove(node.idString()); + } + public void resetAll() { - for (ConnectionState connectionState : connections.values()) - connectionState.reset(); + connections.clear(); } private enum State { - AWAITING_REQUEST, + AWAITING_RESPONSE, BACKING_OFF, READY } - public class ConnectionState { - private final long id; + private final static class ConnectionState { + private final Node node; + private final int retryBackoffMs; + private final int requestTimeoutMs; + private State state = State.READY; private long lastSendTimeMs = 0L; private long lastFailTimeMs = 0L; private OptionalLong inFlightCorrelationId = OptionalLong.empty(); - public ConnectionState(long id) { - this.id = id; + private ConnectionState( + Node node, + int retryBackoffMs, + int requestTimeoutMs + ) { + this.node = node; + this.retryBackoffMs = retryBackoffMs; + this.requestTimeoutMs = requestTimeoutMs; } private boolean isBackoffComplete(long timeMs) { @@ -114,11 +306,7 @@ public class RequestManager { } boolean hasRequestTimedOut(long timeMs) { - return state == State.AWAITING_REQUEST && timeMs >= lastSendTimeMs + requestTimeoutMs; - } - - public long id() { - return id; + return state == State.AWAITING_RESPONSE && timeMs >= lastSendTimeMs + requestTimeoutMs; } boolean isReady(long timeMs) { @@ -136,8 +324,8 @@ public class RequestManager { } } - boolean hasInflightRequest(long timeMs) { - if (state != State.AWAITING_REQUEST) { + private boolean hasInflightRequest(long timeMs) { + if (state != State.AWAITING_RESPONSE) { return false; } else { return !hasRequestTimedOut(timeMs); @@ -174,41 +362,22 @@ public class RequestManager { }); } - void onResponseReceived(long correlationId) { - inFlightCorrelationId.ifPresent(inflightRequestId -> { - if (inflightRequestId == correlationId) { - state = State.READY; - inFlightCorrelationId = OptionalLong.empty(); - } - }); - } - void onRequestSent(long correlationId, long timeMs) { lastSendTimeMs = timeMs; inFlightCorrelationId = OptionalLong.of(correlationId); - state = State.AWAITING_REQUEST; - } - - /** - * Ignore in-flight requests or backoff and become available immediately. 
This is used - * when there is a state change which usually means in-flight requests are obsolete - * and we need to send new requests. - */ - void reset() { - state = State.READY; - inFlightCorrelationId = OptionalLong.empty(); + state = State.AWAITING_RESPONSE; } @Override public String toString() { - return "ConnectionState(" + - "id=" + id + - ", state=" + state + - ", lastSendTimeMs=" + lastSendTimeMs + - ", lastFailTimeMs=" + lastFailTimeMs + - ", inFlightCorrelationId=" + inFlightCorrelationId + - ')'; + return String.format( + "ConnectionState(node=%s, state=%s, lastSendTimeMs=%d, lastFailTimeMs=%d, inFlightCorrelationId=%d)", + node, + state, + lastSendTimeMs, + lastFailTimeMs, + inFlightCorrelationId + ); } } - } diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/VoterSet.java b/raft/src/main/java/org/apache/kafka/raft/internals/VoterSet.java index 16662e06ee3..3ab41f5788c 100644 --- a/raft/src/main/java/org/apache/kafka/raft/internals/VoterSet.java +++ b/raft/src/main/java/org/apache/kafka/raft/internals/VoterSet.java @@ -28,11 +28,13 @@ import java.util.Optional; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.kafka.common.Node; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.feature.SupportedVersionRange; import org.apache.kafka.common.message.VotersRecord; +import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.utils.Utils; /** @@ -55,15 +57,41 @@ final public class VoterSet { } /** - * Returns the socket address for a given voter at a given listener. + * Returns the node information for all the given voter ids and listener. * - * @param voter the id of the voter - * @param listener the name of the listener - * @return the socket address if it exists, otherwise {@code Optional.empty()} + * @param voterIds the ids of the voters + * @param listenerName the name of the listener + * @return the node information for all of the voter ids + * @throws IllegalArgumentException if there are missing endpoints */ - public Optional voterAddress(int voter, String listener) { - return Optional.ofNullable(voters.get(voter)) - .flatMap(voterNode -> voterNode.address(listener)); + public Set voterNodes(Stream voterIds, ListenerName listenerName) { + return voterIds + .map(voterId -> + voterNode(voterId, listenerName).orElseThrow(() -> + new IllegalArgumentException( + String.format( + "Unable to find endpoint for voter %d and listener %s in %s", + voterId, + listenerName, + voters + ) + ) + ) + ) + .collect(Collectors.toSet()); + } + + /** + * Returns the node information for a given voter id and listener. 
+ * + * @param voterId the id of the voter + * @param listenerName the name of the listener + * @return the node information if it exists, otherwise {@code Optional.empty()} + */ + public Optional voterNode(int voterId, ListenerName listenerName) { + return Optional.ofNullable(voters.get(voterId)) + .flatMap(voterNode -> voterNode.address(listenerName)) + .map(address -> new Node(voterId, address.getHostString(), address.getPort())); } /** @@ -166,7 +194,7 @@ final public class VoterSet { .stream() .map(entry -> new VotersRecord.Endpoint() - .setName(entry.getKey()) + .setName(entry.getKey().value()) .setHost(entry.getValue().getHostString()) .setPort(entry.getValue().getPort()) ) @@ -247,12 +275,12 @@ final public class VoterSet { public final static class VoterNode { private final ReplicaKey voterKey; - private final Map listeners; + private final Map listeners; private final SupportedVersionRange supportedKRaftVersion; VoterNode( ReplicaKey voterKey, - Map listeners, + Map listeners, SupportedVersionRange supportedKRaftVersion ) { this.voterKey = voterKey; @@ -264,7 +292,7 @@ final public class VoterSet { return voterKey; } - Map listeners() { + Map listeners() { return listeners; } @@ -273,7 +301,7 @@ final public class VoterSet { } - Optional address(String listener) { + Optional address(ListenerName listener) { return Optional.ofNullable(listeners.get(listener)); } @@ -323,9 +351,12 @@ final public class VoterSet { directoryId = Optional.empty(); } - Map listeners = new HashMap<>(voter.endpoints().size()); + Map listeners = new HashMap<>(voter.endpoints().size()); for (VotersRecord.Endpoint endpoint : voter.endpoints()) { - listeners.put(endpoint.name(), InetSocketAddress.createUnresolved(endpoint.host(), endpoint.port())); + listeners.put( + ListenerName.normalised(endpoint.name()), + InetSocketAddress.createUnresolved(endpoint.host(), endpoint.port()) + ); } voterNodes.put( @@ -351,7 +382,7 @@ final public class VoterSet { * @param voters the socket addresses by voter id * @return the voter set */ - public static VoterSet fromInetSocketAddresses(String listener, Map voters) { + public static VoterSet fromInetSocketAddresses(ListenerName listener, Map voters) { Map voterNodes = voters .entrySet() .stream() @@ -368,16 +399,4 @@ final public class VoterSet { return new VoterSet(voterNodes); } - - public Optional voterNode(int id, String listener) { - VoterNode voterNode = voters.get(id); - if (voterNode == null) { - return Optional.empty(); - } - InetSocketAddress address = voterNode.listeners.get(listener); - if (address == null) { - return Optional.empty(); - } - return Optional.of(new Node(id, address.getHostString(), address.getPort())); - } } diff --git a/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java b/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java index 524a93fa1d7..9aa5eeab499 100644 --- a/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java @@ -26,11 +26,10 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.Map; import java.util.Optional; +import java.util.stream.IntStream; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -60,7 +59,7 @@ public class CandidateStateTest { @Test public void 
testSingleNodeQuorum() { - CandidateState state = newCandidateState(voterSetWithLocal(Collections.emptyList())); + CandidateState state = newCandidateState(voterSetWithLocal(IntStream.empty())); assertTrue(state.isVoteGranted()); assertFalse(state.isVoteRejected()); assertEquals(Collections.emptySet(), state.unrecordedVoters()); @@ -70,7 +69,7 @@ public class CandidateStateTest { public void testTwoNodeQuorumVoteRejected() { int otherNodeId = 1; CandidateState state = newCandidateState( - voterSetWithLocal(Collections.singletonList(otherNodeId)) + voterSetWithLocal(IntStream.of(otherNodeId)) ); assertFalse(state.isVoteGranted()); assertFalse(state.isVoteRejected()); @@ -84,7 +83,7 @@ public class CandidateStateTest { public void testTwoNodeQuorumVoteGranted() { int otherNodeId = 1; CandidateState state = newCandidateState( - voterSetWithLocal(Collections.singletonList(otherNodeId)) + voterSetWithLocal(IntStream.of(otherNodeId)) ); assertFalse(state.isVoteGranted()); assertFalse(state.isVoteRejected()); @@ -100,7 +99,7 @@ public class CandidateStateTest { int node1 = 1; int node2 = 2; CandidateState state = newCandidateState( - voterSetWithLocal(Arrays.asList(node1, node2)) + voterSetWithLocal(IntStream.of(node1, node2)) ); assertFalse(state.isVoteGranted()); assertFalse(state.isVoteRejected()); @@ -120,7 +119,7 @@ public class CandidateStateTest { int node1 = 1; int node2 = 2; CandidateState state = newCandidateState( - voterSetWithLocal(Arrays.asList(node1, node2)) + voterSetWithLocal(IntStream.of(node1, node2)) ); assertFalse(state.isVoteGranted()); assertFalse(state.isVoteRejected()); @@ -139,7 +138,7 @@ public class CandidateStateTest { public void testCannotRejectVoteFromLocalId() { int otherNodeId = 1; CandidateState state = newCandidateState( - voterSetWithLocal(Collections.singletonList(otherNodeId)) + voterSetWithLocal(IntStream.of(otherNodeId)) ); assertThrows( IllegalArgumentException.class, @@ -151,7 +150,7 @@ public class CandidateStateTest { public void testCannotChangeVoteGrantedToRejected() { int otherNodeId = 1; CandidateState state = newCandidateState( - voterSetWithLocal(Collections.singletonList(otherNodeId)) + voterSetWithLocal(IntStream.of(otherNodeId)) ); assertTrue(state.recordGrantedVote(otherNodeId)); assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(otherNodeId)); @@ -162,7 +161,7 @@ public class CandidateStateTest { public void testCannotChangeVoteRejectedToGranted() { int otherNodeId = 1; CandidateState state = newCandidateState( - voterSetWithLocal(Collections.singletonList(otherNodeId)) + voterSetWithLocal(IntStream.of(otherNodeId)) ); assertTrue(state.recordRejectedVote(otherNodeId)); assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(otherNodeId)); @@ -172,7 +171,7 @@ public class CandidateStateTest { @Test public void testCannotGrantOrRejectNonVoters() { int nonVoterId = 1; - CandidateState state = newCandidateState(voterSetWithLocal(Collections.emptyList())); + CandidateState state = newCandidateState(voterSetWithLocal(IntStream.empty())); assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(nonVoterId)); assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(nonVoterId)); } @@ -181,7 +180,7 @@ public class CandidateStateTest { public void testIdempotentGrant() { int otherNodeId = 1; CandidateState state = newCandidateState( - voterSetWithLocal(Collections.singletonList(otherNodeId)) + voterSetWithLocal(IntStream.of(otherNodeId)) ); 
assertTrue(state.recordGrantedVote(otherNodeId)); assertFalse(state.recordGrantedVote(otherNodeId)); @@ -191,7 +190,7 @@ public class CandidateStateTest { public void testIdempotentReject() { int otherNodeId = 1; CandidateState state = newCandidateState( - voterSetWithLocal(Collections.singletonList(otherNodeId)) + voterSetWithLocal(IntStream.of(otherNodeId)) ); assertTrue(state.recordRejectedVote(otherNodeId)); assertFalse(state.recordRejectedVote(otherNodeId)); @@ -201,7 +200,7 @@ public class CandidateStateTest { @ValueSource(booleans = {true, false}) public void testGrantVote(boolean isLogUpToDate) { CandidateState state = newCandidateState( - voterSetWithLocal(Arrays.asList(1, 2, 3)) + voterSetWithLocal(IntStream.of(1, 2, 3)) ); assertFalse(state.canGrantVote(ReplicaKey.of(0, Optional.empty()), isLogUpToDate)); @@ -212,7 +211,7 @@ public class CandidateStateTest { @Test public void testElectionState() { - VoterSet voters = voterSetWithLocal(Arrays.asList(1, 2, 3)); + VoterSet voters = voterSetWithLocal(IntStream.of(1, 2, 3)); CandidateState state = newCandidateState(voters); assertEquals( ElectionState.withVotedCandidate( @@ -228,11 +227,11 @@ public class CandidateStateTest { public void testInvalidVoterSet() { assertThrows( IllegalArgumentException.class, - () -> newCandidateState(VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))) + () -> newCandidateState(VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true))) ); } - private VoterSet voterSetWithLocal(Collection remoteVoters) { + private VoterSet voterSetWithLocal(IntStream remoteVoters) { Map voterMap = VoterSetTest.voterMap(remoteVoters, true); voterMap.put(localNode.voterKey().id(), localNode); diff --git a/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java b/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java index 1894472fa34..ab699159d0d 100644 --- a/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java @@ -16,6 +16,7 @@ */ package org.apache.kafka.raft; +import org.apache.kafka.common.Node; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.Utils; @@ -38,7 +39,7 @@ public class FollowerStateTest { private final LogContext logContext = new LogContext(); private final int epoch = 5; private final int fetchTimeoutMs = 15000; - int leaderId = 3; + private final Node leader = new Node(3, "mock-host-3", 1234); private FollowerState newFollowerState( Set voters, @@ -47,7 +48,7 @@ public class FollowerStateTest { return new FollowerState( time, epoch, - leaderId, + leader, voters, highWatermark, fetchTimeoutMs, @@ -96,4 +97,10 @@ public class FollowerStateTest { assertFalse(state.canGrantVote(ReplicaKey.of(3, Optional.empty()), isLogUpToDate)); } + @Test + public void testLeaderNode() { + FollowerState state = newFollowerState(Utils.mkSet(0, 1, 2), Optional.empty()); + + assertEquals(leader, state.leader()); + } } diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java index 4a580a124bd..2455990e770 100644 --- a/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/KafkaNetworkChannelTest.java @@ -26,7 +26,10 @@ import org.apache.kafka.common.message.BeginQuorumEpochResponseData; import org.apache.kafka.common.message.EndQuorumEpochResponseData; 
import org.apache.kafka.common.message.FetchRequestData; import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.FetchSnapshotRequestData; +import org.apache.kafka.common.message.FetchSnapshotResponseData; import org.apache.kafka.common.message.VoteResponseData; +import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.Errors; @@ -39,6 +42,8 @@ import org.apache.kafka.common.requests.EndQuorumEpochRequest; import org.apache.kafka.common.requests.EndQuorumEpochResponse; import org.apache.kafka.common.requests.FetchRequest; import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.requests.FetchSnapshotRequest; +import org.apache.kafka.common.requests.FetchSnapshotResponse; import org.apache.kafka.common.requests.VoteRequest; import org.apache.kafka.common.requests.VoteResponse; import org.apache.kafka.common.utils.MockTime; @@ -47,8 +52,8 @@ import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; -import java.net.InetSocketAddress; import java.util.Collections; import java.util.List; import java.util.concurrent.ExecutionException; @@ -80,7 +85,8 @@ public class KafkaNetworkChannelTest { ApiKeys.VOTE, ApiKeys.BEGIN_QUORUM_EPOCH, ApiKeys.END_QUORUM_EPOCH, - ApiKeys.FETCH + ApiKeys.FETCH, + ApiKeys.FETCH_SNAPSHOT ); private final int requestTimeoutMs = 30000; @@ -88,35 +94,40 @@ public class KafkaNetworkChannelTest { private final MockClient client = new MockClient(time, new StubMetadataUpdater()); private final TopicPartition topicPartition = new TopicPartition("topic", 0); private final Uuid topicId = Uuid.randomUuid(); - private final KafkaNetworkChannel channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, "test-raft"); + private final KafkaNetworkChannel channel = new KafkaNetworkChannel( + time, + ListenerName.normalised("NAME"), + client, + requestTimeoutMs, + "test-raft" + ); + + private Node nodeWithId(boolean withId) { + int id = withId ? 
2 : -2; + return new Node(id, "127.0.0.1", 9092); + } @BeforeEach public void setupSupportedApis() { - List supportedApis = RAFT_APIS.stream().map( - ApiVersionsResponse::toApiVersion).collect(Collectors.toList()); + List supportedApis = RAFT_APIS + .stream() + .map(ApiVersionsResponse::toApiVersion) + .collect(Collectors.toList()); client.setNodeApiVersions(NodeApiVersions.create(supportedApis)); } - @Test - public void testSendToUnknownDestination() throws ExecutionException, InterruptedException { - int destinationId = 2; - assertBrokerNotAvailable(destinationId); - } - - @Test - public void testSendToBlackedOutDestination() throws ExecutionException, InterruptedException { - int destinationId = 2; - Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); - channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port())); - client.backoff(destinationNode, 500); - assertBrokerNotAvailable(destinationId); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testSendToBlackedOutDestination(boolean withDestinationId) throws ExecutionException, InterruptedException { + Node destination = nodeWithId(withDestinationId); + client.backoff(destination, 500); + assertBrokerNotAvailable(destination); } @Test public void testWakeupClientOnSend() throws InterruptedException, ExecutionException { int destinationId = 2; Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); - channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port())); client.enableBlockingUntilWakeup(1); @@ -132,7 +143,7 @@ public class KafkaNetworkChannelTest { client.prepareResponseFrom(response, destinationNode, false); ioThread.start(); - RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationId); + RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationNode); ioThread.join(); assertResponseCompleted(request, Errors.INVALID_REQUEST); @@ -142,12 +153,11 @@ public class KafkaNetworkChannelTest { public void testSendAndDisconnect() throws ExecutionException, InterruptedException { int destinationId = 2; Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); - channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port())); for (ApiKeys apiKey : RAFT_APIS) { AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)); client.prepareResponseFrom(response, destinationNode, true); - sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE); + sendAndAssertErrorResponse(apiKey, destinationNode, Errors.BROKER_NOT_AVAILABLE); } } @@ -155,35 +165,33 @@ public class KafkaNetworkChannelTest { public void testSendAndFailAuthentication() throws ExecutionException, InterruptedException { int destinationId = 2; Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); - channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port())); for (ApiKeys apiKey : RAFT_APIS) { client.createPendingAuthenticationError(destinationNode, 100); - sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION); + sendAndAssertErrorResponse(apiKey, destinationNode, Errors.NETWORK_EXCEPTION); // reset to clear backoff time client.reset(); } } - private void assertBrokerNotAvailable(int destinationId) throws ExecutionException, InterruptedException { + private void assertBrokerNotAvailable(Node destination) throws 
ExecutionException, InterruptedException { for (ApiKeys apiKey : RAFT_APIS) { - sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE); + sendAndAssertErrorResponse(apiKey, destination, Errors.BROKER_NOT_AVAILABLE); } } - @Test - public void testSendAndReceiveOutboundRequest() throws ExecutionException, InterruptedException { - int destinationId = 2; - Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); - channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port())); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testSendAndReceiveOutboundRequest(boolean withDestinationId) throws ExecutionException, InterruptedException { + Node destination = nodeWithId(withDestinationId); for (ApiKeys apiKey : RAFT_APIS) { Errors expectedError = Errors.INVALID_REQUEST; AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, expectedError)); - client.prepareResponseFrom(response, destinationNode); + client.prepareResponseFrom(response, destination); System.out.println("api key " + apiKey + ", response " + response); - sendAndAssertErrorResponse(apiKey, destinationId, expectedError); + sendAndAssertErrorResponse(apiKey, destination, expectedError); } } @@ -191,11 +199,10 @@ public class KafkaNetworkChannelTest { public void testUnsupportedVersionError() throws ExecutionException, InterruptedException { int destinationId = 2; Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); - channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port())); for (ApiKeys apiKey : RAFT_APIS) { client.prepareUnsupportedVersionResponse(request -> request.apiKey() == apiKey); - sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION); + sendAndAssertErrorResponse(apiKey, destinationNode, Errors.UNSUPPORTED_VERSION); } } @@ -204,8 +211,7 @@ public class KafkaNetworkChannelTest { public void testFetchRequestDowngrade(short version) { int destinationId = 2; Node destinationNode = new Node(destinationId, "127.0.0.1", 9092); - channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port())); - sendTestRequest(ApiKeys.FETCH, destinationId); + sendTestRequest(ApiKeys.FETCH, destinationNode); channel.pollOnce(); assertEquals(1, client.requests().size()); @@ -220,27 +226,39 @@ public class KafkaNetworkChannelTest { } } - private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, int destinationId) { + private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, Node destination) { int correlationId = channel.newCorrelationId(); long createdTimeMs = time.milliseconds(); ApiMessage apiRequest = buildTestRequest(apiKey); - RaftRequest.Outbound request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs); + RaftRequest.Outbound request = new RaftRequest.Outbound( + correlationId, + apiRequest, + destination, + createdTimeMs + ); channel.send(request); return request; } - private void assertResponseCompleted(RaftRequest.Outbound request, Errors expectedError) throws ExecutionException, InterruptedException { + private void assertResponseCompleted( + RaftRequest.Outbound request, + Errors expectedError + ) throws ExecutionException, InterruptedException { assertTrue(request.completion.isDone()); RaftResponse.Inbound response = request.completion.get(); - assertEquals(request.destinationId(), response.sourceId()); - assertEquals(request.correlationId, 
response.correlationId); - assertEquals(request.data.apiKey(), response.data.apiKey()); - assertEquals(expectedError, extractError(response.data)); + assertEquals(request.destination(), response.source()); + assertEquals(request.correlationId(), response.correlationId()); + assertEquals(request.data().apiKey(), response.data().apiKey()); + assertEquals(expectedError, extractError(response.data())); } - private void sendAndAssertErrorResponse(ApiKeys apiKey, int destinationId, Errors error) throws ExecutionException, InterruptedException { - RaftRequest.Outbound request = sendTestRequest(apiKey, destinationId); + private void sendAndAssertErrorResponse( + ApiKeys apiKey, + Node destination, + Errors error + ) throws ExecutionException, InterruptedException { + RaftRequest.Outbound request = sendTestRequest(apiKey, destination); channel.pollOnce(); assertResponseCompleted(request, error); } @@ -252,12 +270,20 @@ public class KafkaNetworkChannelTest { switch (key) { case BEGIN_QUORUM_EPOCH: return BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId); + case END_QUORUM_EPOCH: - return EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, leaderEpoch, - Collections.singletonList(2)); + return EndQuorumEpochRequest.singletonRequest( + topicPartition, + clusterId, + leaderId, + leaderEpoch, + Collections.singletonList(2) + ); + case VOTE: int lastEpoch = 4; return VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329); + case FETCH: FetchRequestData request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition -> { fetchPartition @@ -267,6 +293,21 @@ public class KafkaNetworkChannelTest { }); request.setReplicaState(new FetchRequestData.ReplicaState().setReplicaId(1)); return request; + + case FETCH_SNAPSHOT: + return FetchSnapshotRequest.singleton( + clusterId, + 1, + topicPartition, + snapshotPartition -> snapshotPartition + .setCurrentLeaderEpoch(5) + .setSnapshotId(new FetchSnapshotRequestData.SnapshotId() + .setEpoch(4) + .setEndOffset(323) + ) + .setPosition(10) + ); + default: throw new AssertionError("Unexpected api " + key); } @@ -282,6 +323,8 @@ public class KafkaNetworkChannelTest { return VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false); case FETCH: return new FetchResponseData().setErrorCode(error.code()); + case FETCH_SNAPSHOT: + return new FetchSnapshotResponseData().setErrorCode(error.code()); default: throw new AssertionError("Unexpected api " + key); } @@ -289,28 +332,36 @@ public class KafkaNetworkChannelTest { private Errors extractError(ApiMessage response) { short code; - if (response instanceof BeginQuorumEpochResponseData) + if (response instanceof BeginQuorumEpochResponseData) { code = ((BeginQuorumEpochResponseData) response).errorCode(); - else if (response instanceof EndQuorumEpochResponseData) + } else if (response instanceof EndQuorumEpochResponseData) { code = ((EndQuorumEpochResponseData) response).errorCode(); - else if (response instanceof FetchResponseData) + } else if (response instanceof FetchResponseData) { code = ((FetchResponseData) response).errorCode(); - else if (response instanceof VoteResponseData) + } else if (response instanceof VoteResponseData) { code = ((VoteResponseData) response).errorCode(); - else + } else if (response instanceof FetchSnapshotResponseData) { + code = ((FetchSnapshotResponseData) response).errorCode(); + } else { throw new IllegalArgumentException("Unexpected type for 
responseData: " + response); + } + return Errors.forCode(code); } private AbstractResponse buildResponse(ApiMessage responseData) { - if (responseData instanceof VoteResponseData) + if (responseData instanceof VoteResponseData) { return new VoteResponse((VoteResponseData) responseData); - if (responseData instanceof BeginQuorumEpochResponseData) + } else if (responseData instanceof BeginQuorumEpochResponseData) { return new BeginQuorumEpochResponse((BeginQuorumEpochResponseData) responseData); - if (responseData instanceof EndQuorumEpochResponseData) + } else if (responseData instanceof EndQuorumEpochResponseData) { return new EndQuorumEpochResponse((EndQuorumEpochResponseData) responseData); - if (responseData instanceof FetchResponseData) + } else if (responseData instanceof FetchResponseData) { return new FetchResponse((FetchResponseData) responseData); - throw new IllegalArgumentException("Unexpected type for responseData: " + responseData); + } else if (responseData instanceof FetchSnapshotResponseData) { + return new FetchSnapshotResponse((FetchSnapshotResponseData) responseData); + } else { + throw new IllegalArgumentException("Unexpected type for responseData: " + responseData); + } } } diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java index 3fcbec4229e..299fa819d58 100644 --- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java @@ -153,8 +153,8 @@ final public class KafkaRaftClientSnapshotTest { RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); context.assertFetchRequestData(fetchRequest, epoch, localLogEndOffset, snapshotId.epoch()); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, localLogEndOffset, Errors.NONE) ); @@ -195,8 +195,8 @@ final public class KafkaRaftClientSnapshotTest { RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); context.assertFetchRequestData(fetchRequest, epoch, localLogEndOffset, snapshotId.epoch()); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, localLogEndOffset, Errors.NONE) ); @@ -1032,8 +1032,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, invalidEpoch, 200L) ); @@ -1049,8 +1049,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, invalidEndOffset, 200L) ); @@ -1091,8 +1091,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + 
fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) ); @@ -1116,8 +1116,8 @@ final public class KafkaRaftClientSnapshotTest { } context.deliverResponse( - snapshotRequest.correlationId, - snapshotRequest.destinationId(), + snapshotRequest.correlationId(), + snapshotRequest.destination(), fetchSnapshotResponse( context.metadataPartition, epoch, @@ -1162,8 +1162,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) ); @@ -1190,8 +1190,8 @@ final public class KafkaRaftClientSnapshotTest { sendingBuffer.limit(sendingBuffer.limit() / 2); context.deliverResponse( - snapshotRequest.correlationId, - snapshotRequest.destinationId(), + snapshotRequest.correlationId(), + snapshotRequest.destination(), fetchSnapshotResponse( context.metadataPartition, epoch, @@ -1219,8 +1219,8 @@ final public class KafkaRaftClientSnapshotTest { sendingBuffer.position(Math.toIntExact(request.position())); context.deliverResponse( - snapshotRequest.correlationId, - snapshotRequest.destinationId(), + snapshotRequest.correlationId(), + snapshotRequest.destination(), fetchSnapshotResponse( context.metadataPartition, epoch, @@ -1265,8 +1265,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) ); @@ -1284,8 +1284,8 @@ final public class KafkaRaftClientSnapshotTest { // Reply with a snapshot not found error context.deliverResponse( - snapshotRequest.correlationId, - snapshotRequest.destinationId(), + snapshotRequest.correlationId(), + snapshotRequest.destination(), FetchSnapshotResponse.singleton( context.metadataPartition, responsePartitionSnapshot -> { @@ -1323,8 +1323,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, firstLeaderId, snapshotId, 200L) ); @@ -1342,8 +1342,8 @@ final public class KafkaRaftClientSnapshotTest { // Reply with new leader response context.deliverResponse( - snapshotRequest.correlationId, - snapshotRequest.destinationId(), + snapshotRequest.correlationId(), + snapshotRequest.destination(), FetchSnapshotResponse.singleton( context.metadataPartition, responsePartitionSnapshot -> { @@ -1380,8 +1380,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) ); @@ -1399,8 +1399,8 @@ final public class KafkaRaftClientSnapshotTest { // Reply with new leader epoch 
context.deliverResponse( - snapshotRequest.correlationId, - snapshotRequest.destinationId(), + snapshotRequest.correlationId(), + snapshotRequest.destination(), FetchSnapshotResponse.singleton( context.metadataPartition, responsePartitionSnapshot -> { @@ -1437,8 +1437,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) ); @@ -1456,8 +1456,8 @@ final public class KafkaRaftClientSnapshotTest { // Reply with unknown leader epoch context.deliverResponse( - snapshotRequest.correlationId, - snapshotRequest.destinationId(), + snapshotRequest.correlationId(), + snapshotRequest.destination(), FetchSnapshotResponse.singleton( context.metadataPartition, responsePartitionSnapshot -> { @@ -1504,8 +1504,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) ); @@ -1523,8 +1523,8 @@ final public class KafkaRaftClientSnapshotTest { // Reply with an invalid snapshot id endOffset context.deliverResponse( - snapshotRequest.correlationId, - snapshotRequest.destinationId(), + snapshotRequest.correlationId(), + snapshotRequest.destination(), FetchSnapshotResponse.singleton( context.metadataPartition, responsePartitionSnapshot -> { @@ -1550,8 +1550,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) ); @@ -1570,8 +1570,8 @@ final public class KafkaRaftClientSnapshotTest { // Reply with an invalid snapshot id epoch context.deliverResponse( - snapshotRequest.correlationId, - snapshotRequest.destinationId(), + snapshotRequest.correlationId(), + snapshotRequest.destination(), FetchSnapshotResponse.singleton( context.metadataPartition, responsePartitionSnapshot -> { @@ -1614,8 +1614,8 @@ final public class KafkaRaftClientSnapshotTest { context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); context.deliverResponse( - fetchRequest.correlationId, - fetchRequest.destinationId(), + fetchRequest.correlationId(), + fetchRequest.destination(), snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) ); @@ -1642,8 +1642,8 @@ final public class KafkaRaftClientSnapshotTest { // Send the response late context.deliverResponse( - snapshotRequest.correlationId, - snapshotRequest.destinationId(), + snapshotRequest.correlationId(), + snapshotRequest.destination(), FetchSnapshotResponse.singleton( context.metadataPartition, responsePartitionSnapshot -> { @@ -1805,14 +1805,17 @@ final public class KafkaRaftClientSnapshotTest { // Poll for our first fetch request context.pollUntilRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + 
assertTrue(voters.contains(fetchRequest.destination().id())); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); // The response does not advance the high watermark List records1 = Arrays.asList("a", "b", "c"); MemoryRecords batch1 = context.buildBatch(0L, 3, records1); - context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), - context.fetchResponse(epoch, leaderId, batch1, 0L, Errors.NONE)); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, leaderId, batch1, 0L, Errors.NONE) + ); context.client.poll(); // 2) The high watermark must be larger than or equal to the snapshotId's endOffset @@ -1827,13 +1830,16 @@ final public class KafkaRaftClientSnapshotTest { // The high watermark advances to be larger than log.endOffsetForEpoch(3), to test the case 3 context.pollUntilRequest(); fetchRequest = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + assertTrue(voters.contains(fetchRequest.destination().id())); context.assertFetchRequestData(fetchRequest, epoch, 3L, 3); List records2 = Arrays.asList("d", "e", "f"); MemoryRecords batch2 = context.buildBatch(3L, 4, records2); - context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), - context.fetchResponse(epoch, leaderId, batch2, 6L, Errors.NONE)); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, leaderId, batch2, 6L, Errors.NONE) + ); context.client.poll(); assertEquals(6L, context.client.highWatermark().getAsLong()); diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java index c531e5860ac..049b648d881 100644 --- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java @@ -51,6 +51,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mockito; import java.io.IOException; +import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; @@ -62,6 +63,7 @@ import java.util.OptionalLong; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeoutException; +import java.util.stream.Collectors; import static java.util.Collections.singletonList; import static org.apache.kafka.raft.RaftClientTestContext.Builder.DEFAULT_ELECTION_TIMEOUT_MS; @@ -274,8 +276,12 @@ public class KafkaRaftClientTest { assertThrows(NotLeaderException.class, () -> context.client.scheduleAppend(epoch, Arrays.asList("a", "b"))); context.pollUntilRequest(); - int correlationId = context.assertSentEndQuorumEpochRequest(epoch, 1); - context.deliverResponse(correlationId, 1, context.endEpochResponse(epoch, OptionalInt.of(localId))); + RaftRequest.Outbound request = context.assertSentEndQuorumEpochRequest(epoch, 1); + context.deliverResponse( + request.correlationId(), + request.destination(), + context.endEpochResponse(epoch, OptionalInt.of(localId)) + ); context.client.poll(); context.time.sleep(context.electionTimeoutMs()); @@ -389,14 +395,17 @@ public class KafkaRaftClientTest { // Respond to one of the requests so that we can verify that no additional // request to this node is sent. 
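Throughout these test hunks the substitution is the same: the outbound request's public correlationId field and destinationId() accessor give way to correlationId() and destination(), with the node id read via destination().id(). A minimal sketch of the accessor surface the updated call sites assume follows; OutboundRequestSketch is illustrative only (it is not a class in this patch), and the ApiMessage payload type is an assumption rather than something shown in the hunks.

    import org.apache.kafka.common.Node;
    import org.apache.kafka.common.protocol.ApiMessage;

    // Illustrative stand-in: the accessor shape that the updated tests rely on.
    final class OutboundRequestSketch {
        private final int correlationId;
        private final Node destination;
        private final ApiMessage data;

        OutboundRequestSketch(int correlationId, Node destination, ApiMessage data) {
            this.correlationId = correlationId;
            this.destination = destination;
            this.data = data;
        }

        int correlationId() { return correlationId; }  // previously a public int field
        Node destination() { return destination; }      // previously destinationId() returning a bare int
        ApiMessage data() { return data; }
    }

Call sites such as context.deliverResponse(request.correlationId(), request.destination(), response) then forward the full Node instead of a bare id.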
RaftRequest.Outbound endEpochOutbound = requests.get(0); - context.deliverResponse(endEpochOutbound.correlationId, endEpochOutbound.destinationId(), - context.endEpochResponse(epoch, OptionalInt.of(localId))); + context.deliverResponse( + endEpochOutbound.correlationId(), + endEpochOutbound.destination(), + context.endEpochResponse(epoch, OptionalInt.of(localId)) + ); context.client.poll(); assertEquals(Collections.emptyList(), context.channel.drainSendQueue()); // Now sleep for the request timeout and verify that we get only one // retried request from the voter that hasn't responded yet. - int nonRespondedId = requests.get(1).destinationId(); + int nonRespondedId = requests.get(1).destination().id(); context.time.sleep(6000); context.pollUntilRequest(); List retries = context.collectEndQuorumRequests( @@ -573,7 +582,7 @@ public class KafkaRaftClientTest { context.pollUntil(context.client.quorum()::isResigned); context.pollUntilRequest(); - int correlationId = context.assertSentEndQuorumEpochRequest(resignedEpoch, otherNodeId); + RaftRequest.Outbound request = context.assertSentEndQuorumEpochRequest(resignedEpoch, otherNodeId); EndQuorumEpochResponseData response = EndQuorumEpochResponse.singletonResponse( Errors.NONE, @@ -583,7 +592,7 @@ public class KafkaRaftClientTest { localId ); - context.deliverResponse(correlationId, otherNodeId, response); + context.deliverResponse(request.correlationId(), request.destination(), response); context.client.poll(); // We do not resend `EndQuorumRequest` once the other voter has acknowledged it. @@ -644,11 +653,14 @@ public class KafkaRaftClientTest { context.pollUntilRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + assertTrue(voters.contains(fetchRequest.destination().id())); context.assertFetchRequestData(fetchRequest, 0, 0L, 0); - context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), - context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH) + ); context.client.poll(); context.assertElectedLeader(epoch, leaderId); @@ -686,8 +698,12 @@ public class KafkaRaftClientTest { context.pollUntilRequest(); context.assertVotedCandidate(1, localId); - int correlationId = context.assertSentVoteRequest(1, 0, 0L, 1); - context.deliverResponse(correlationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); + RaftRequest.Outbound request = context.assertSentVoteRequest(1, 0, 0L, 1); + context.deliverResponse( + request.correlationId(), + request.destination(), + context.voteResponse(true, Optional.empty(), 1) + ); // Become leader after receiving the vote context.pollUntil(() -> context.log.endOffset().offset == 1L); @@ -726,8 +742,12 @@ public class KafkaRaftClientTest { context.pollUntilRequest(); context.assertVotedCandidate(1, localId); - int correlationId = context.assertSentVoteRequest(1, 0, 0L, 2); - context.deliverResponse(correlationId, firstNodeId, context.voteResponse(true, Optional.empty(), 1)); + RaftRequest.Outbound request = context.assertSentVoteRequest(1, 0, 0L, 2); + context.deliverResponse( + request.correlationId(), + request.destination(), + context.voteResponse(true, Optional.empty(), 1) + ); // Become leader after receiving the vote context.pollUntil(() -> context.log.endOffset().offset == 
1L); @@ -1102,19 +1122,27 @@ public class KafkaRaftClientTest { context.pollUntilRequest(); context.assertVotedCandidate(epoch, localId); - int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); + RaftRequest.Outbound request = context.assertSentVoteRequest(epoch, 0, 0L, 1); context.time.sleep(context.requestTimeoutMs()); context.client.poll(); - int retryCorrelationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); + RaftRequest.Outbound retryRequest = context.assertSentVoteRequest(epoch, 0, 0L, 1); // We will ignore the timed out response if it arrives late - context.deliverResponse(correlationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); + context.deliverResponse( + request.correlationId(), + request.destination(), + context.voteResponse(true, Optional.empty(), 1) + ); context.client.poll(); context.assertVotedCandidate(epoch, localId); // Become leader after receiving the retry response - context.deliverResponse(retryCorrelationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); + context.deliverResponse( + retryRequest.correlationId(), + retryRequest.destination(), + context.voteResponse(true, Optional.empty(), 1) + ); context.client.poll(); context.assertElectedLeader(epoch, localId); } @@ -1338,8 +1366,12 @@ public class KafkaRaftClientTest { context.assertVotedCandidate(epoch, localId); // Quorum size is two. If the other member rejects, then we need to schedule a revote. - int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); - context.deliverResponse(correlationId, otherNodeId, context.voteResponse(false, Optional.empty(), 1)); + RaftRequest.Outbound request = context.assertSentVoteRequest(epoch, 0, 0L, 1); + context.deliverResponse( + request.correlationId(), + request.destination(), + context.voteResponse(false, Optional.empty(), 1) + ); context.client.poll(); @@ -1434,11 +1466,14 @@ public class KafkaRaftClientTest { context.pollUntilRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + assertTrue(voters.contains(fetchRequest.destination().id())); context.assertFetchRequestData(fetchRequest, 0, 0L, 0); - context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), - context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH) + ); context.client.poll(); context.assertElectedLeader(epoch, leaderId); @@ -1450,27 +1485,39 @@ public class KafkaRaftClientTest { int leaderId = 1; int epoch = 5; Set voters = Utils.mkSet(leaderId); + List bootstrapServers = voters + .stream() + .map(RaftClientTestContext::mockAddress) + .collect(Collectors.toList()); - RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withBootstrapServers(bootstrapServers) + .build(); context.pollUntilRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id())); context.assertFetchRequestData(fetchRequest, 0, 0L, 0); - context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), - context.fetchResponse(-1, -1, MemoryRecords.EMPTY, 
-1, Errors.UNKNOWN_SERVER_ERROR)); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(-1, -1, MemoryRecords.EMPTY, -1, Errors.UNKNOWN_SERVER_ERROR) + ); context.client.poll(); context.time.sleep(context.retryBackoffMs); context.pollUntilRequest(); fetchRequest = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id())); context.assertFetchRequestData(fetchRequest, 0, 0L, 0); - context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), - context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH) + ); context.client.poll(); context.assertElectedLeader(epoch, leaderId); @@ -1483,27 +1530,169 @@ public class KafkaRaftClientTest { int otherNodeId = 2; int epoch = 5; Set voters = Utils.mkSet(leaderId, otherNodeId); + List bootstrapServers = voters + .stream() + .map(RaftClientTestContext::mockAddress) + .collect(Collectors.toList()); - RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withBootstrapServers(bootstrapServers) + .build(); context.pollUntilRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id())); context.assertFetchRequestData(fetchRequest, 0, 0L, 0); - context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), - context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); - context.client.poll(); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH) + ); + context.client.poll(); context.assertElectedLeader(epoch, leaderId); + context.time.sleep(context.fetchTimeoutMs); context.pollUntilRequest(); fetchRequest = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + assertNotEquals(leaderId, fetchRequest.destination().id()); + assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id())); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); } + @Test + public void testObserverHandleRetryFetchtToBootstrapServer() throws Exception { + // This test tries to check that KRaft is able to handle a retrying Fetch request to + // a bootstrap server after a Fetch request to the leader.
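Both of the new observer tests build their bootstrap endpoints the same way, so the setup is shown here once, condensed. BootstrapContextSketch and buildObserverContext are illustrative names that do not exist in the patch; mockAddress(int) and withBootstrapServers(...) are taken from the hunks, and the sketch assumes mockAddress simply fabricates an InetSocketAddress for a node id and that the sketch lives alongside these tests in the same package.

    import java.net.InetSocketAddress;
    import java.util.List;
    import java.util.Set;
    import java.util.stream.Collectors;

    // Illustrative condensation of the repeated test setup; not a helper added by this patch.
    final class BootstrapContextSketch {
        static RaftClientTestContext buildObserverContext(int localId, Set<Integer> voters) throws Exception {
            // One mock address per voter id, handed to the builder as the bootstrap server list.
            List<InetSocketAddress> bootstrapServers = voters
                .stream()
                .map(RaftClientTestContext::mockAddress)
                .collect(Collectors.toList());

            return new RaftClientTestContext.Builder(localId, voters)
                .withBootstrapServers(bootstrapServers)
                .build();
        }
    }

The assertions in the test bodies then use context.bootstrapIds, presumably populated by the builder from this list, to tell bootstrap destinations apart from voter ids.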
+ + int localId = 0; + int leaderId = 1; + int otherNodeId = 2; + int epoch = 5; + Set voters = Utils.mkSet(leaderId, otherNodeId); + List bootstrapServers = voters + .stream() + .map(RaftClientTestContext::mockAddress) + .collect(Collectors.toList()); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withBootstrapServers(bootstrapServers) + .build(); + + // Expect a fetch request to one of the bootstrap servers + context.pollUntilRequest(); + RaftRequest.Outbound discoveryFetchRequest = context.assertSentFetchRequest(); + assertFalse(voters.contains(discoveryFetchRequest.destination().id())); + assertTrue(context.bootstrapIds.contains(discoveryFetchRequest.destination().id())); + context.assertFetchRequestData(discoveryFetchRequest, 0, 0L, 0); + + // Send a response with the leader and epoch + context.deliverResponse( + discoveryFetchRequest.correlationId(), + discoveryFetchRequest.destination(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH) + ); + + context.client.poll(); + context.assertElectedLeader(epoch, leaderId); + + // Expect a fetch request to the leader + context.pollUntilRequest(); + RaftRequest.Outbound toLeaderFetchRequest = context.assertSentFetchRequest(); + assertEquals(leaderId, toLeaderFetchRequest.destination().id()); + context.assertFetchRequestData(toLeaderFetchRequest, epoch, 0L, 0); + + context.time.sleep(context.requestTimeoutMs()); + + // After the fetch timeout expect a request to a bootstrap server + context.pollUntilRequest(); + RaftRequest.Outbound retryToBootstrapServerFetchRequest = context.assertSentFetchRequest(); + assertFalse(voters.contains(retryToBootstrapServerFetchRequest.destination().id())); + assertTrue(context.bootstrapIds.contains(retryToBootstrapServerFetchRequest.destination().id())); + context.assertFetchRequestData(retryToBootstrapServerFetchRequest, epoch, 0L, 0); + + // Deliver the delayed responses from the leader + Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); + context.deliverResponse( + toLeaderFetchRequest.correlationId(), + toLeaderFetchRequest.destination(), + context.fetchResponse(epoch, leaderId, records, 0L, Errors.NONE) + ); + + context.client.poll(); + + // Deliver the same delayed responses from the bootstrap server and assume that it is the leader + records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); + context.deliverResponse( + retryToBootstrapServerFetchRequest.correlationId(), + retryToBootstrapServerFetchRequest.destination(), + context.fetchResponse(epoch, leaderId, records, 0L, Errors.NONE) + ); + + // This poll should not fail when handling the duplicate response from the bootstrap server + context.client.poll(); + } + + @Test + public void testObserverHandleRetryFetchToLeader() throws Exception { + // This test tries to check that KRaft is able to handle a retrying Fetch request to + // the leader after a Fetch request to the bootstrap server. 
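As in the previous test, the destination checks repeat a fixed pattern. A sketch of that pattern as standalone helpers follows; FetchDestinationAsserts and both method names are illustrative only, and context.bootstrapIds is assumed to be the set of node ids assigned to the bootstrap endpoints.

    import java.util.Set;

    import static org.junit.jupiter.api.Assertions.assertEquals;
    import static org.junit.jupiter.api.Assertions.assertFalse;
    import static org.junit.jupiter.api.Assertions.assertTrue;

    // Illustrative helpers only; the tests inline these assertions.
    final class FetchDestinationAsserts {
        // A discovery or retry fetch must target a bootstrap endpoint, never a voter id.
        static void assertSentToBootstrap(RaftClientTestContext context,
                                          Set<Integer> voters,
                                          RaftRequest.Outbound request) {
            assertFalse(voters.contains(request.destination().id()));
            assertTrue(context.bootstrapIds.contains(request.destination().id()));
        }

        // Once the leader is known, the follow-up fetch goes straight to it.
        static void assertSentToLeader(int leaderId, RaftRequest.Outbound request) {
            assertEquals(leaderId, request.destination().id());
        }
    }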
+ + int localId = 0; + int leaderId = 1; + int otherNodeId = 2; + int epoch = 5; + Set voters = Utils.mkSet(leaderId, otherNodeId); + List bootstrapServers = voters + .stream() + .map(RaftClientTestContext::mockAddress) + .collect(Collectors.toList()); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withBootstrapServers(bootstrapServers) + .build(); + + // Expect a fetch request to one of the bootstrap servers + context.pollUntilRequest(); + RaftRequest.Outbound discoveryFetchRequest = context.assertSentFetchRequest(); + assertFalse(voters.contains(discoveryFetchRequest.destination().id())); + assertTrue(context.bootstrapIds.contains(discoveryFetchRequest.destination().id())); + context.assertFetchRequestData(discoveryFetchRequest, 0, 0L, 0); + + // Send a response with the leader and epoch + context.deliverResponse( + discoveryFetchRequest.correlationId(), + discoveryFetchRequest.destination(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH) + ); + + context.client.poll(); + context.assertElectedLeader(epoch, leaderId); + + // Expect a fetch request to the leader + context.pollUntilRequest(); + RaftRequest.Outbound toLeaderFetchRequest = context.assertSentFetchRequest(); + assertEquals(leaderId, toLeaderFetchRequest.destination().id()); + context.assertFetchRequestData(toLeaderFetchRequest, epoch, 0L, 0); + + context.time.sleep(context.requestTimeoutMs()); + + // After the fetch timeout expect a request to a bootstrap server + context.pollUntilRequest(); + RaftRequest.Outbound retryToBootstrapServerFetchRequest = context.assertSentFetchRequest(); + assertFalse(voters.contains(retryToBootstrapServerFetchRequest.destination().id())); + assertTrue(context.bootstrapIds.contains(retryToBootstrapServerFetchRequest.destination().id())); + context.assertFetchRequestData(retryToBootstrapServerFetchRequest, epoch, 0L, 0); + + // At this point toLeaderFetchRequest has timed out but retryToBootstrapServerFetchRequest + // is still waiting for a response. 
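The final assertion of this test leans on MockNetworkChannel.hasSentRequests(), which is simply a check that nothing is left in the channel's send queue. A rough, self-contained rendering of that slice of the mock follows; SendQueueSketch is an illustrative name, and the real class (whose diff appears further down) keeps additional state.

    import java.util.ArrayList;
    import java.util.List;

    // Rough stand-in for the relevant slice of MockNetworkChannel.
    final class SendQueueSketch {
        private final List<RaftRequest.Outbound> sendQueue = new ArrayList<>();

        void send(RaftRequest.Outbound request) {
            sendQueue.add(request);   // the simplified mock just enqueues; no address-cache lookup
        }

        boolean hasSentRequests() {
            return !sendQueue.isEmpty();
        }
    }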
+ // Confirm that no new fetch request has been sent + context.client.poll(); + assertFalse(context.channel.hasSentRequests()); + } + @Test public void testInvalidFetchRequest() throws Exception { int localId = 0; @@ -1828,7 +2017,7 @@ public class KafkaRaftClientTest { // Wait until we have a Fetch inflight to the leader context.pollUntilRequest(); - int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(epoch, 0L, 0); // Now await the fetch timeout and become a candidate context.time.sleep(context.fetchTimeoutMs); @@ -1837,8 +2026,11 @@ public class KafkaRaftClientTest { // The fetch response from the old leader returns, but it should be ignored Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); - context.deliverResponse(fetchCorrelationId, otherNodeId, - context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE)); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE) + ); context.client.poll(); assertEquals(0, context.log.endOffset().offset); @@ -1862,7 +2054,7 @@ public class KafkaRaftClientTest { // Wait until we have a Fetch inflight to the leader context.pollUntilRequest(); - int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(epoch, 0L, 0); // Now receive a BeginEpoch from `voter3` context.deliverRequest(context.beginEpochRequest(epoch + 1, voter3)); @@ -1872,7 +2064,11 @@ public class KafkaRaftClientTest { // The fetch response from the old leader returns, but it should be ignored Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); FetchResponseData response = context.fetchResponse(epoch, voter2, records, 0L, Errors.NONE); - context.deliverResponse(fetchCorrelationId, voter2, response); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + response + ); context.client.poll(); assertEquals(0, context.log.endOffset().offset); @@ -1909,10 +2105,18 @@ public class KafkaRaftClientTest { // The vote requests now return and should be ignored VoteResponseData voteResponse1 = context.voteResponse(false, Optional.empty(), epoch); - context.deliverResponse(voteRequests.get(0).correlationId, voter2, voteResponse1); + context.deliverResponse( + voteRequests.get(0).correlationId(), + voteRequests.get(0).destination(), + voteResponse1 + ); VoteResponseData voteResponse2 = context.voteResponse(false, Optional.of(voter3), epoch); - context.deliverResponse(voteRequests.get(1).correlationId, voter3, voteResponse2); + context.deliverResponse( + voteRequests.get(1).correlationId(), + voteRequests.get(1).destination(), + voteResponse2 + ); context.client.poll(); context.assertElectedLeader(epoch, voter3); @@ -1925,31 +2129,43 @@ public class KafkaRaftClientTest { int otherNodeId = 2; int epoch = 5; Set voters = Utils.mkSet(leaderId, otherNodeId); + List bootstrapServers = voters + .stream() + .map(RaftClientTestContext::mockAddress) + .collect(Collectors.toList()); - RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withBootstrapServers(bootstrapServers) + .build(); context.discoverLeaderAsObserver(leaderId, epoch); context.pollUntilRequest(); RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest(); - 
assertEquals(leaderId, fetchRequest1.destinationId()); + assertEquals(leaderId, fetchRequest1.destination().id()); context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0); - context.deliverResponse(fetchRequest1.correlationId, fetchRequest1.destinationId(), - context.fetchResponse(epoch, -1, MemoryRecords.EMPTY, -1, Errors.BROKER_NOT_AVAILABLE)); + context.deliverResponse( + fetchRequest1.correlationId(), + fetchRequest1.destination(), + context.fetchResponse(epoch, -1, MemoryRecords.EMPTY, -1, Errors.BROKER_NOT_AVAILABLE) + ); context.pollUntilRequest(); // We should retry the Fetch against the other voter since the original // voter connection will be backing off. RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest(); - assertNotEquals(leaderId, fetchRequest2.destinationId()); - assertTrue(voters.contains(fetchRequest2.destinationId())); + assertNotEquals(leaderId, fetchRequest2.destination().id()); + assertTrue(context.bootstrapIds.contains(fetchRequest2.destination().id())); context.assertFetchRequestData(fetchRequest2, epoch, 0L, 0); - Errors error = fetchRequest2.destinationId() == leaderId ? + Errors error = fetchRequest2.destination().id() == leaderId ? Errors.NONE : Errors.NOT_LEADER_OR_FOLLOWER; - context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(), - context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, error)); + context.deliverResponse( + fetchRequest2.correlationId(), + fetchRequest2.destination(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, error) + ); context.client.poll(); context.assertElectedLeader(epoch, leaderId); @@ -1962,14 +2178,20 @@ public class KafkaRaftClientTest { int otherNodeId = 2; int epoch = 5; Set voters = Utils.mkSet(leaderId, otherNodeId); + List bootstrapServers = voters + .stream() + .map(RaftClientTestContext::mockAddress) + .collect(Collectors.toList()); - RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withBootstrapServers(bootstrapServers) + .build(); context.discoverLeaderAsObserver(leaderId, epoch); context.pollUntilRequest(); RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest(); - assertEquals(leaderId, fetchRequest1.destinationId()); + assertEquals(leaderId, fetchRequest1.destination().id()); context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0); context.time.sleep(context.requestTimeoutMs()); @@ -1978,12 +2200,15 @@ public class KafkaRaftClientTest { // We should retry the Fetch against the other voter since the original // voter connection will be backing off. 
RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest(); - assertNotEquals(leaderId, fetchRequest2.destinationId()); - assertTrue(voters.contains(fetchRequest2.destinationId())); + assertNotEquals(leaderId, fetchRequest2.destination().id()); + assertTrue(context.bootstrapIds.contains(fetchRequest2.destination().id())); context.assertFetchRequestData(fetchRequest2, epoch, 0L, 0); - context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(), - context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + context.deliverResponse( + fetchRequest2.correlationId(), + fetchRequest2.destination(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH) + ); context.client.poll(); context.assertElectedLeader(epoch, leaderId); @@ -2273,10 +2498,14 @@ public class KafkaRaftClientTest { context.pollUntilRequest(); - int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); + RaftRequest.Outbound fetchQuorumRequest = context.assertSentFetchRequest(epoch, 0L, 0); Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); FetchResponseData response = context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE); - context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, response); + context.deliverResponse( + fetchQuorumRequest.correlationId(), + fetchQuorumRequest.destination(), + response + ); context.client.poll(); assertEquals(2L, context.log.endOffset().offset); @@ -2297,10 +2526,19 @@ public class KafkaRaftClientTest { // Receive an empty fetch response context.pollUntilRequest(); - int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); - FetchResponseData fetchResponse = context.fetchResponse(epoch, otherNodeId, - MemoryRecords.EMPTY, 0L, Errors.NONE); - context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse); + RaftRequest.Outbound fetchQuorumRequest = context.assertSentFetchRequest(epoch, 0L, 0); + FetchResponseData fetchResponse = context.fetchResponse( + epoch, + otherNodeId, + MemoryRecords.EMPTY, + 0L, + Errors.NONE + ); + context.deliverResponse( + fetchQuorumRequest.correlationId(), + fetchQuorumRequest.destination(), + fetchResponse + ); context.client.poll(); assertEquals(0L, context.log.endOffset().offset); assertEquals(OptionalLong.of(0L), context.client.highWatermark()); @@ -2308,20 +2546,32 @@ public class KafkaRaftClientTest { // Receive some records in the next poll, but do not advance high watermark context.pollUntilRequest(); Records records = context.buildBatch(0L, epoch, Arrays.asList("a", "b")); - fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); - fetchResponse = context.fetchResponse(epoch, otherNodeId, - records, 0L, Errors.NONE); - context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse); + fetchQuorumRequest = context.assertSentFetchRequest(epoch, 0L, 0); + fetchResponse = context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE); + context.deliverResponse( + fetchQuorumRequest.correlationId(), + fetchQuorumRequest.destination(), + fetchResponse + ); context.client.poll(); assertEquals(2L, context.log.endOffset().offset); assertEquals(OptionalLong.of(0L), context.client.highWatermark()); // The next fetch response is empty, but should still advance the high watermark context.pollUntilRequest(); - fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 2L, epoch); - fetchResponse = context.fetchResponse(epoch, 
otherNodeId, - MemoryRecords.EMPTY, 2L, Errors.NONE); - context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse); + fetchQuorumRequest = context.assertSentFetchRequest(epoch, 2L, epoch); + fetchResponse = context.fetchResponse( + epoch, + otherNodeId, + MemoryRecords.EMPTY, + 2L, + Errors.NONE + ); + context.deliverResponse( + fetchQuorumRequest.correlationId(), + fetchQuorumRequest.destination(), + fetchResponse + ); context.client.poll(); assertEquals(2L, context.log.endOffset().offset); assertEquals(OptionalLong.of(2L), context.client.highWatermark()); @@ -2454,11 +2704,11 @@ public class KafkaRaftClientTest { context.pollUntilRequest(); - int correlationId = context.assertSentFetchRequest(epoch, 3L, lastEpoch); + RaftRequest.Outbound request = context.assertSentFetchRequest(epoch, 3L, lastEpoch); FetchResponseData response = context.divergingFetchResponse(epoch, otherNodeId, 2L, lastEpoch, 1L); - context.deliverResponse(correlationId, otherNodeId, response); + context.deliverResponse(request.correlationId(), request.destination(), response); // Poll again to complete truncation context.client.poll(); @@ -2530,10 +2780,14 @@ public class KafkaRaftClientTest { context.pollUntilRequest(); - int correlationId = context.assertSentFetchRequest(epoch, 0, 0); + RaftRequest.Outbound request = context.assertSentFetchRequest(epoch, 0, 0); FetchResponseData response = new FetchResponseData() .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); - context.deliverResponse(correlationId, otherNodeId, response); + context.deliverResponse( + request.correlationId(), + request.destination(), + response + ); assertThrows(ClusterAuthorizationException.class, context.client::poll); } @@ -2553,11 +2807,11 @@ public class KafkaRaftClientTest { context.expectAndGrantVotes(epoch); context.pollUntilRequest(); - int correlationId = context.assertSentBeginQuorumEpochRequest(epoch, 1); + RaftRequest.Outbound request = context.assertSentBeginQuorumEpochRequest(epoch, 1); BeginQuorumEpochResponseData response = new BeginQuorumEpochResponseData() .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); - context.deliverResponse(correlationId, otherNodeId, response); + context.deliverResponse(request.correlationId(), request.destination(), response); assertThrows(ClusterAuthorizationException.class, context.client::poll); } @@ -2577,11 +2831,11 @@ public class KafkaRaftClientTest { context.pollUntilRequest(); context.assertVotedCandidate(epoch, localId); - int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); + RaftRequest.Outbound request = context.assertSentVoteRequest(epoch, 0, 0L, 1); VoteResponseData response = new VoteResponseData() .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); - context.deliverResponse(correlationId, otherNodeId, response); + context.deliverResponse(request.correlationId(), request.destination(), response); assertThrows(ClusterAuthorizationException.class, context.client::poll); } @@ -2597,11 +2851,11 @@ public class KafkaRaftClientTest { context.client.shutdown(5000); context.pollUntilRequest(); - int correlationId = context.assertSentEndQuorumEpochRequest(epoch, otherNodeId); + RaftRequest.Outbound request = context.assertSentEndQuorumEpochRequest(epoch, otherNodeId); EndQuorumEpochResponseData response = new EndQuorumEpochResponseData() .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); - context.deliverResponse(correlationId, otherNodeId, response); + context.deliverResponse(request.correlationId(), request.destination(), 
response); assertThrows(ClusterAuthorizationException.class, context.client::poll); } @@ -2810,14 +3064,17 @@ public class KafkaRaftClientTest { // Poll for our first fetch request context.pollUntilRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + assertTrue(voters.contains(fetchRequest.destination().id())); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); // The response does not advance the high watermark List records1 = Arrays.asList("a", "b", "c"); MemoryRecords batch1 = context.buildBatch(0L, 3, records1); - context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), - context.fetchResponse(epoch, otherNodeId, batch1, 0L, Errors.NONE)); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, otherNodeId, batch1, 0L, Errors.NONE) + ); context.client.poll(); // The listener should not have seen any data @@ -2828,14 +3085,17 @@ public class KafkaRaftClientTest { // Now look for the next fetch request context.pollUntilRequest(); fetchRequest = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + assertTrue(voters.contains(fetchRequest.destination().id())); context.assertFetchRequestData(fetchRequest, epoch, 3L, 3); // The high watermark advances to include the first batch we fetched List records2 = Arrays.asList("d", "e", "f"); MemoryRecords batch2 = context.buildBatch(3L, 3, records2); - context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), - context.fetchResponse(epoch, otherNodeId, batch2, 3L, Errors.NONE)); + context.deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + context.fetchResponse(epoch, otherNodeId, batch2, 3L, Errors.NONE) + ); context.client.poll(); // The listener should have seen only the data from the first batch @@ -3012,21 +3272,30 @@ public class KafkaRaftClientTest { // This is designed for tooling/debugging use cases. 
Set voters = Utils.mkSet(1, 2); + List bootstrapServers = voters + .stream() + .map(RaftClientTestContext::mockAddress) + .collect(Collectors.toList()); + RaftClientTestContext context = new RaftClientTestContext.Builder(OptionalInt.empty(), voters) + .withBootstrapServers(bootstrapServers) .build(); // First fetch discovers the current leader and epoch context.pollUntilRequest(); RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest1.destinationId())); + assertTrue(context.bootstrapIds.contains(fetchRequest1.destination().id())); context.assertFetchRequestData(fetchRequest1, 0, 0L, 0); int leaderEpoch = 5; int leaderId = 1; - context.deliverResponse(fetchRequest1.correlationId, fetchRequest1.destinationId(), - context.fetchResponse(5, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + context.deliverResponse( + fetchRequest1.correlationId(), + fetchRequest1.destination(), + context.fetchResponse(5, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH) + ); context.client.poll(); context.assertElectedLeader(leaderEpoch, leaderId); @@ -3034,13 +3303,16 @@ public class KafkaRaftClientTest { context.pollUntilRequest(); RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest(); - assertEquals(leaderId, fetchRequest2.destinationId()); + assertEquals(leaderId, fetchRequest2.destination().id()); context.assertFetchRequestData(fetchRequest2, leaderEpoch, 0L, 0); List records = Arrays.asList("a", "b", "c"); MemoryRecords batch1 = context.buildBatch(0L, 3, records); - context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(), - context.fetchResponse(leaderEpoch, leaderId, batch1, 0L, Errors.NONE)); + context.deliverResponse( + fetchRequest2.correlationId(), + fetchRequest2.destination(), + context.fetchResponse(leaderEpoch, leaderId, batch1, 0L, Errors.NONE) + ); context.client.poll(); assertEquals(3L, context.log.endOffset().offset); assertEquals(3, context.log.lastFetchedEpoch()); diff --git a/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java b/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java index 379290240e0..f9c3efee02a 100644 --- a/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java +++ b/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java @@ -16,31 +16,29 @@ */ package org.apache.kafka.raft; +import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.protocol.ApiKeys; -import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; public class MockNetworkChannel implements NetworkChannel { private final AtomicInteger correlationIdCounter; - private final Set nodeCache; private final List sendQueue = new ArrayList<>(); private final Map awaitingResponse = new HashMap<>(); + private final ListenerName listenerName = ListenerName.normalised("CONTROLLER"); - public MockNetworkChannel(AtomicInteger correlationIdCounter, Set destinationIds) { + public MockNetworkChannel(AtomicInteger correlationIdCounter) { this.correlationIdCounter = correlationIdCounter; - this.nodeCache = destinationIds; } - public MockNetworkChannel(Set destinationIds) { - this(new AtomicInteger(0), destinationIds); + public MockNetworkChannel() { + this(new AtomicInteger(0)); } @Override @@ -50,16 +48,12 @@ public class 
MockNetworkChannel implements NetworkChannel { @Override public void send(RaftRequest.Outbound request) { - if (!nodeCache.contains(request.destinationId())) { - throw new IllegalArgumentException("Attempted to send to destination " + - request.destinationId() + ", but its address is not yet known"); - } sendQueue.add(request); } @Override - public void updateEndpoint(int id, InetSocketAddress address) { - // empty + public ListenerName listenerName() { + return listenerName; } public List drainSendQueue() { @@ -72,7 +66,7 @@ public class MockNetworkChannel implements NetworkChannel { while (iterator.hasNext()) { RaftRequest.Outbound request = iterator.next(); if (!apiKeyFilter.isPresent() || request.data().apiKey() == apiKeyFilter.get().id) { - awaitingResponse.put(request.correlationId, request); + awaitingResponse.put(request.correlationId(), request); requests.add(request); iterator.remove(); } @@ -80,17 +74,15 @@ public class MockNetworkChannel implements NetworkChannel { return requests; } - public boolean hasSentRequests() { return !sendQueue.isEmpty(); } public void mockReceive(RaftResponse.Inbound response) { - RaftRequest.Outbound request = awaitingResponse.get(response.correlationId); + RaftRequest.Outbound request = awaitingResponse.get(response.correlationId()); if (request == null) { throw new IllegalStateException("Received response for a request which is not being awaited"); } request.completion.complete(response); } - } diff --git a/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java b/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java index 08acba10478..94567dce665 100644 --- a/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java @@ -16,12 +16,14 @@ */ package org.apache.kafka.raft; +import org.apache.kafka.common.Node; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.Utils; import org.apache.kafka.raft.internals.BatchAccumulator; import org.apache.kafka.raft.internals.ReplicaKey; +import org.apache.kafka.raft.internals.VoterSet; import org.apache.kafka.raft.internals.VoterSetTest; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -29,10 +31,13 @@ import org.mockito.Mockito; import java.io.UncheckedIOException; import java.util.Collections; +import java.util.Map; import java.util.Optional; import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; +import java.util.stream.IntStream; +import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -54,19 +59,16 @@ public class QuorumStateTest { private final MockableRandom random = new MockableRandom(1L); private final BatchAccumulator accumulator = Mockito.mock(BatchAccumulator.class); - private QuorumState buildQuorumState(Set voters, short kraftVersion) { - return buildQuorumState(OptionalInt.of(localId), voters, kraftVersion); - } - private QuorumState buildQuorumState( OptionalInt localId, - Set voters, + VoterSet voterSet, short kraftVersion ) { return new QuorumState( localId, localDirectoryId, - () -> VoterSetTest.voterSet(VoterSetTest.voterMap(voters, false)), + VoterSetTest.DEFAULT_LISTENER_NAME, + () -> voterSet, () -> kraftVersion, electionTimeoutMs, fetchTimeoutMs, @@ -77,10 +79,47 @@ public class QuorumStateTest { ); } + 
private QuorumState initializeEmptyState(VoterSet voters, short kraftVersion) { + QuorumState state = buildQuorumState(OptionalInt.of(localId), voters, kraftVersion); + store.writeElectionState(ElectionState.withUnknownLeader(0, voters.voterIds()), kraftVersion); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + return state; + } + + private Set persistedVoters(Set voters, short kraftVersion) { + if (kraftVersion == 1) { + return Collections.emptySet(); + } + + return voters; + } + + private ReplicaKey persistedVotedKey(ReplicaKey replicaKey, short kraftVersion) { + if (kraftVersion == 1) { + return replicaKey; + } + + return ReplicaKey.of(replicaKey.id(), Optional.empty()); + } + + private VoterSet localStandaloneVoterSet() { + return VoterSetTest.voterSet( + Collections.singletonMap(localId, VoterSetTest.voterNode(localVoterKey)) + ); + } + + private VoterSet localWithRemoteVoterSet(IntStream remoteIds, short kraftVersion) { + boolean withDirectoryId = kraftVersion > 0; + Map voters = VoterSetTest.voterMap(remoteIds, withDirectoryId); + voters.put(localId, VoterSetTest.voterNode(localVoterKey)); + + return VoterSetTest.voterSet(voters); + } + @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testInitializePrimordialEpoch(short kraftVersion) { - Set voters = Utils.mkSet(localId); + VoterSet voters = localStandaloneVoterSet(); assertEquals(Optional.empty(), store.readElectionState()); QuorumState state = initializeEmptyState(voters, kraftVersion); @@ -98,13 +137,13 @@ public class QuorumStateTest { int node1 = 1; int node2 = 2; int epoch = 5; - Set voters = Utils.mkSet(localId, node1, node2); - store.writeElectionState(ElectionState.withUnknownLeader(epoch, voters), kraftVersion); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); + store.writeElectionState(ElectionState.withUnknownLeader(epoch, voters.voterIds()), kraftVersion); int jitterMs = 2500; random.mockNextInt(jitterMs); - QuorumState state = buildQuorumState(voters, kraftVersion); + QuorumState state = buildQuorumState(OptionalInt.of(localId), voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, 0)); assertTrue(state.isUnattached()); @@ -120,47 +159,44 @@ public class QuorumStateTest { int node1 = 1; int node2 = 2; int epoch = 5; - Set voters = Utils.mkSet(localId, node1, node2); - store.writeElectionState(ElectionState.withElectedLeader(epoch, node1, voters), kraftVersion); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); + store.writeElectionState(ElectionState.withElectedLeader(epoch, node1, voters.voterIds()), kraftVersion); - QuorumState state = buildQuorumState(voters, kraftVersion); + QuorumState state = buildQuorumState(OptionalInt.of(localId), voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); assertTrue(state.isFollower()); assertEquals(epoch, state.epoch()); FollowerState followerState = state.followerStateOrThrow(); assertEquals(epoch, followerState.epoch()); - assertEquals(node1, followerState.leaderId()); + assertEquals(node1, followerState.leader().id()); assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds())); } @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testInitializeAsVoted(short kraftVersion) { - int node1 = 1; - Optional node1DirectoryId = Optional.of(Uuid.randomUuid()); - int node2 = 2; + ReplicaKey nodeKey1 = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + ReplicaKey nodeKey2 = ReplicaKey.of(2, Optional.of(Uuid.randomUuid())); + 
int epoch = 5; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, nodeKey1, nodeKey2)); store.writeElectionState( - ElectionState.withVotedCandidate(epoch, ReplicaKey.of(node1, node1DirectoryId), voters), + ElectionState.withVotedCandidate(epoch, nodeKey1, voters.voterIds()), kraftVersion ); int jitterMs = 2500; random.mockNextInt(jitterMs); - QuorumState state = buildQuorumState(voters, kraftVersion); + QuorumState state = buildQuorumState(OptionalInt.of(localId), voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); assertTrue(state.isVoted()); assertEquals(epoch, state.epoch()); VotedState votedState = state.votedStateOrThrow(); assertEquals(epoch, votedState.epoch()); - assertEquals( - ReplicaKey.of(node1, persistedDirectoryId(node1DirectoryId, kraftVersion)), - votedState.votedKey() - ); + assertEquals(persistedVotedKey(nodeKey1, kraftVersion), votedState.votedKey()); assertEquals( electionTimeoutMs + jitterMs, @@ -174,18 +210,18 @@ public class QuorumStateTest { int node1 = 1; int node2 = 2; int epoch = 5; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); ElectionState election = ElectionState.withVotedCandidate( epoch, localVoterKey, - voters + voters.voterIds() ); store.writeElectionState(election, kraftVersion); int jitterMs = 2500; random.mockNextInt(jitterMs); - QuorumState state = buildQuorumState(voters, kraftVersion); + QuorumState state = buildQuorumState(OptionalInt.of(localId), voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); assertTrue(state.isCandidate()); assertEquals(epoch, state.epoch()); @@ -193,7 +229,7 @@ public class QuorumStateTest { CandidateState candidateState = state.candidateStateOrThrow(); assertEquals(epoch, candidateState.epoch()); assertEquals( - ElectionState.withVotedCandidate(epoch, localVoterKey, voters), + ElectionState.withVotedCandidate(epoch, localVoterKey, voters.voterIds()), candidateState.election() ); assertEquals(Utils.mkSet(node1, node2), candidateState.unrecordedVoters()); @@ -211,8 +247,8 @@ public class QuorumStateTest { int node1 = 1; int node2 = 2; int epoch = 5; - Set voters = Utils.mkSet(localId, node1, node2); - ElectionState election = ElectionState.withElectedLeader(epoch, localId, voters); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); + ElectionState election = ElectionState.withElectedLeader(epoch, localId, voters.voterIds()); store.writeElectionState(election, kraftVersion); // If we were previously a leader, we will start as resigned in order to ensure @@ -223,7 +259,7 @@ public class QuorumStateTest { int jitterMs = 2500; random.mockNextInt(jitterMs); - QuorumState state = buildQuorumState(voters, kraftVersion); + QuorumState state = buildQuorumState(OptionalInt.of(localId), voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); assertFalse(state.isLeader()); assertEquals(epoch, state.epoch()); @@ -241,7 +277,7 @@ public class QuorumStateTest { public void testCandidateToCandidate(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); assertEquals(Optional.empty(), store.readElectionState()); QuorumState state = initializeEmptyState(voters, kraftVersion); @@ -285,7 +321,7 @@ public class QuorumStateTest { public void 
testCandidateToResigned(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); assertEquals(Optional.empty(), store.readElectionState()); QuorumState state = initializeEmptyState(voters, kraftVersion); @@ -301,7 +337,7 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testCandidateToLeader(short kraftVersion) { - Set voters = Utils.mkSet(localId); + VoterSet voters = localStandaloneVoterSet(); assertEquals(Optional.empty(), store.readElectionState()); QuorumState state = initializeEmptyState(voters, kraftVersion); @@ -320,7 +356,7 @@ public class QuorumStateTest { @ValueSource(shorts = {0, 1}) public void testCandidateToLeaderWithoutGrantedVote(short kraftVersion) { int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(otherNodeId), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToCandidate(); @@ -336,16 +372,23 @@ public class QuorumStateTest { @ValueSource(shorts = {0, 1}) public void testCandidateToFollower(short kraftVersion) { int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + + VoterSet voters = localWithRemoteVoterSet(IntStream.of(otherNodeId), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToCandidate(); - state.transitionToFollower(5, otherNodeId); + state.transitionToFollower(5, voters.voterNode(otherNodeId, VoterSetTest.DEFAULT_LISTENER_NAME).get()); assertEquals(5, state.epoch()); assertEquals(OptionalInt.of(otherNodeId), state.leaderId()); assertEquals( - Optional.of(ElectionState.withElectedLeader(5, otherNodeId, persistedVoters(voters, kraftVersion))), + Optional.of( + ElectionState.withElectedLeader( + 5, + otherNodeId, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -354,7 +397,7 @@ public class QuorumStateTest { @ValueSource(shorts = {0, 1}) public void testCandidateToUnattached(short kraftVersion) { int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(otherNodeId), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToCandidate(); @@ -363,7 +406,12 @@ public class QuorumStateTest { assertEquals(5, state.epoch()); assertEquals(OptionalInt.empty(), state.leaderId()); assertEquals( - Optional.of(ElectionState.withUnknownLeader(5, persistedVoters(voters, kraftVersion))), + Optional.of( + ElectionState.withUnknownLeader( + 5, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -371,10 +419,8 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testCandidateToVoted(short kraftVersion) { - int otherNodeId = 1; - Optional otherNodeDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId); - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, 
kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToCandidate(); @@ -390,11 +436,8 @@ public class QuorumStateTest { Optional.of( ElectionState.withVotedCandidate( 5, - ReplicaKey.of( - otherNodeId, - persistedDirectoryId(otherNodeDirectoryId, kraftVersion) - ), - persistedVoters(voters, kraftVersion)) + persistedVotedKey(otherNodeKey, kraftVersion), + persistedVoters(voters.voterIds(), kraftVersion)) ), store.readElectionState() ); @@ -403,27 +446,28 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testCandidateToAnyStateLowerEpoch(short kraftVersion) { - int otherNodeId = 1; - Optional otherNodeDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId); - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToUnattached(5); state.transitionToCandidate(); assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4)); assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeKey)); - assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId)); + assertThrows( + IllegalStateException.class, + () -> state.transitionToFollower( + 4, + voters.voterNode(otherNodeKey.id(), VoterSetTest.DEFAULT_LISTENER_NAME).get() + ) + ); assertEquals(6, state.epoch()); assertEquals( Optional.of( ElectionState.withVotedCandidate( 6, - ReplicaKey.of( - localId, - persistedDirectoryId(Optional.of(localDirectoryId), kraftVersion) - ), - persistedVoters(voters, kraftVersion) + persistedVotedKey(localVoterKey, kraftVersion), + persistedVoters(voters.voterIds(), kraftVersion) ) ), store.readElectionState() @@ -433,7 +477,7 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testLeaderToLeader(short kraftVersion) { - Set voters = Utils.mkSet(localId); + VoterSet voters = localStandaloneVoterSet(); assertEquals(Optional.empty(), store.readElectionState()); QuorumState state = initializeEmptyState(voters, kraftVersion); @@ -451,7 +495,7 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testLeaderToResigned(short kraftVersion) { - Set voters = Utils.mkSet(localId); + VoterSet voters = localStandaloneVoterSet(); assertEquals(Optional.empty(), store.readElectionState()); QuorumState state = initializeEmptyState(voters, kraftVersion); @@ -464,8 +508,10 @@ public class QuorumStateTest { state.transitionToResigned(Collections.singletonList(localId)); assertTrue(state.isResigned()); ResignedState resignedState = state.resignedStateOrThrow(); - assertEquals(ElectionState.withElectedLeader(1, localId, voters), - resignedState.election()); + assertEquals( + ElectionState.withElectedLeader(1, localId, voters.voterIds()), + resignedState.election() + ); assertEquals(1, resignedState.epoch()); assertEquals(Collections.emptySet(), resignedState.unackedVoters()); } @@ -473,7 +519,7 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testLeaderToCandidate(short kraftVersion) { - Set voters = Utils.mkSet(localId); + VoterSet voters = localStandaloneVoterSet(); assertEquals(Optional.empty(), 
store.readElectionState()); QuorumState state = initializeEmptyState(voters, kraftVersion); @@ -492,19 +538,25 @@ public class QuorumStateTest { @ValueSource(shorts = {0, 1}) public void testLeaderToFollower(short kraftVersion) { int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(otherNodeId), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.transitionToCandidate(); state.candidateStateOrThrow().recordGrantedVote(otherNodeId); state.transitionToLeader(0L, accumulator); - state.transitionToFollower(5, otherNodeId); + state.transitionToFollower(5, voters.voterNode(otherNodeId, VoterSetTest.DEFAULT_LISTENER_NAME).get()); assertEquals(5, state.epoch()); assertEquals(OptionalInt.of(otherNodeId), state.leaderId()); assertEquals( - Optional.of(ElectionState.withElectedLeader(5, otherNodeId, persistedVoters(voters, kraftVersion))), + Optional.of( + ElectionState.withElectedLeader( + 5, + otherNodeId, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -513,7 +565,7 @@ public class QuorumStateTest { @ValueSource(shorts = {0, 1}) public void testLeaderToUnattached(short kraftVersion) { int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(otherNodeId), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToCandidate(); @@ -523,7 +575,12 @@ public class QuorumStateTest { assertEquals(5, state.epoch()); assertEquals(OptionalInt.empty(), state.leaderId()); assertEquals( - Optional.of(ElectionState.withUnknownLeader(5, persistedVoters(voters, kraftVersion))), + Optional.of( + ElectionState.withUnknownLeader( + 5, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -531,14 +588,12 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testLeaderToVoted(short kraftVersion) { - int otherNodeId = 1; - Optional otherNodeDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId); - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToCandidate(); - state.candidateStateOrThrow().recordGrantedVote(otherNodeId); + state.candidateStateOrThrow().recordGrantedVote(otherNodeKey.id()); state.transitionToLeader(0L, accumulator); state.transitionToVoted(5, otherNodeKey); @@ -552,11 +607,8 @@ public class QuorumStateTest { Optional.of( ElectionState.withVotedCandidate( 5, - ReplicaKey.of( - otherNodeId, - persistedDirectoryId(otherNodeDirectoryId, kraftVersion) - ), - persistedVoters(voters, kraftVersion) + persistedVotedKey(otherNodeKey, kraftVersion), + persistedVoters(voters.voterIds(), kraftVersion) ) ), store.readElectionState() @@ -566,22 +618,32 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testLeaderToAnyStateLowerEpoch(short kraftVersion) { - int otherNodeId = 1; - Optional otherNodeDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId); - Set 
voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToUnattached(5); state.transitionToCandidate(); - state.candidateStateOrThrow().recordGrantedVote(otherNodeId); + state.candidateStateOrThrow().recordGrantedVote(otherNodeKey.id()); state.transitionToLeader(0L, accumulator); assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4)); assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeKey)); - assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId)); + assertThrows( + IllegalStateException.class, + () -> state.transitionToFollower( + 4, + voters.voterNode(otherNodeKey.id(), VoterSetTest.DEFAULT_LISTENER_NAME).get() + ) + ); assertEquals(6, state.epoch()); assertEquals( - Optional.of(ElectionState.withElectedLeader(6, localId, persistedVoters(voters, kraftVersion))), + Optional.of( + ElectionState.withElectedLeader( + 6, + localId, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -589,26 +651,28 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testCannotFollowOrVoteForSelf(short kraftVersion) { - Set voters = Utils.mkSet(localId); + VoterSet voters = localStandaloneVoterSet(); assertEquals(Optional.empty(), store.readElectionState()); QuorumState state = initializeEmptyState(voters, kraftVersion); - assertThrows(IllegalStateException.class, () -> state.transitionToFollower(0, localId)); + assertThrows( + IllegalStateException.class, + () -> state.transitionToFollower( + 0, + voters.voterNode(localId, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ) + ); assertThrows(IllegalStateException.class, () -> state.transitionToVoted(0, localVoterKey)); } @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testUnattachedToLeaderOrResigned(short kraftVersion) { - int leaderId = 1; + ReplicaKey leaderKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); int epoch = 5; - Set voters = Utils.mkSet(localId, leaderId); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, leaderKey)); store.writeElectionState( - ElectionState.withVotedCandidate( - epoch, - ReplicaKey.of(leaderId, Optional.empty()), - voters - ), + ElectionState.withVotedCandidate(epoch, leaderKey, voters.voterIds()), kraftVersion ); QuorumState state = initializeEmptyState(voters, kraftVersion); @@ -621,10 +685,8 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testUnattachedToVotedSameEpoch(short kraftVersion) { - int otherNodeId = 1; - Optional otherNodeDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId); - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToUnattached(5); @@ -641,11 +703,8 @@ public class QuorumStateTest { Optional.of( ElectionState.withVotedCandidate( 5, - ReplicaKey.of( - otherNodeId, - persistedDirectoryId(otherNodeDirectoryId, 
kraftVersion) - ), - persistedVoters(voters, kraftVersion) + persistedVotedKey(otherNodeKey, kraftVersion), + persistedVoters(voters.voterIds(), kraftVersion) ) ), store.readElectionState() @@ -659,10 +718,8 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testUnattachedToVotedHigherEpoch(short kraftVersion) { - int otherNodeId = 1; - Optional otherNodeDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId); - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToUnattached(5); @@ -676,11 +733,8 @@ public class QuorumStateTest { Optional.of( ElectionState.withVotedCandidate( 8, - ReplicaKey.of( - otherNodeId, - persistedDirectoryId(otherNodeDirectoryId, kraftVersion) - ), - persistedVoters(voters, kraftVersion) + persistedVotedKey(otherNodeKey, kraftVersion), + persistedVoters(voters.voterIds(), kraftVersion) ) ), store.readElectionState() @@ -690,8 +744,8 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testUnattachedToCandidate(short kraftVersion) { - int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToUnattached(5); @@ -710,8 +764,8 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testUnattachedToUnattached(short kraftVersion) { - int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToUnattached(5); @@ -731,52 +785,74 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testUnattachedToFollowerSameEpoch(short kraftVersion) { - int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToUnattached(5); - state.transitionToFollower(5, otherNodeId); + state.transitionToFollower( + 5, + voters.voterNode(otherNodeKey.id(), VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); assertTrue(state.isFollower()); FollowerState followerState = state.followerStateOrThrow(); assertEquals(5, followerState.epoch()); - assertEquals(otherNodeId, followerState.leaderId()); + assertEquals( + voters.voterNode(otherNodeKey.id(), VoterSetTest.DEFAULT_LISTENER_NAME), + Optional.of(followerState.leader()) + ); assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds())); } @ParameterizedTest @ValueSource(shorts = {0, 1}) public void 
testUnattachedToFollowerHigherEpoch(short kraftVersion) { - int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToUnattached(5); - state.transitionToFollower(8, otherNodeId); + state.transitionToFollower( + 8, + voters.voterNode(otherNodeKey.id(), VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); assertTrue(state.isFollower()); FollowerState followerState = state.followerStateOrThrow(); assertEquals(8, followerState.epoch()); - assertEquals(otherNodeId, followerState.leaderId()); + assertEquals( + voters.voterNode(otherNodeKey.id(), VoterSetTest.DEFAULT_LISTENER_NAME), + Optional.of(followerState.leader()) + ); assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds())); } @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testUnattachedToAnyStateLowerEpoch(short kraftVersion) { - int otherNodeId = 1; - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, Optional.empty()); - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToUnattached(5); assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4)); assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeKey)); - assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId)); + assertThrows( + IllegalStateException.class, + () -> state.transitionToFollower( + 4, + voters.voterNode(otherNodeKey.id(), VoterSetTest.DEFAULT_LISTENER_NAME).get() + ) + ); assertEquals(5, state.epoch()); assertEquals( - Optional.of(ElectionState.withUnknownLeader(5, persistedVoters(voters, kraftVersion))), + Optional.of( + ElectionState.withUnknownLeader( + 5, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -786,7 +862,7 @@ public class QuorumStateTest { public void testVotedToInvalidLeaderOrResigned(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToVoted(5, ReplicaKey.of(node1, Optional.empty())); @@ -799,7 +875,7 @@ public class QuorumStateTest { public void testVotedToCandidate(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToVoted(5, ReplicaKey.of(node1, Optional.empty())); @@ -819,7 +895,7 @@ public class QuorumStateTest { public void testVotedToVotedSameEpoch(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState 
state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToUnattached(5); @@ -839,17 +915,29 @@ public class QuorumStateTest { public void testVotedToFollowerSameEpoch(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToVoted(5, ReplicaKey.of(node1, Optional.empty())); - state.transitionToFollower(5, node2); + state.transitionToFollower( + 5, + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); FollowerState followerState = state.followerStateOrThrow(); assertEquals(5, followerState.epoch()); - assertEquals(node2, followerState.leaderId()); assertEquals( - Optional.of(ElectionState.withElectedLeader(5, node2, persistedVoters(voters, kraftVersion))), + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME), + Optional.of(followerState.leader()) + ); + assertEquals( + Optional.of( + ElectionState.withElectedLeader( + 5, + node2, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -859,17 +947,29 @@ public class QuorumStateTest { public void testVotedToFollowerHigherEpoch(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToVoted(5, ReplicaKey.of(node1, Optional.empty())); - state.transitionToFollower(8, node2); + state.transitionToFollower( + 8, + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); FollowerState followerState = state.followerStateOrThrow(); assertEquals(8, followerState.epoch()); - assertEquals(node2, followerState.leaderId()); assertEquals( - Optional.of(ElectionState.withElectedLeader(8, node2, persistedVoters(voters, kraftVersion))), + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME).get(), + followerState.leader() + ); + assertEquals( + Optional.of( + ElectionState.withElectedLeader( + 8, + node2, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -879,7 +979,7 @@ public class QuorumStateTest { public void testVotedToUnattachedSameEpoch(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToVoted(5, ReplicaKey.of(node1, Optional.empty())); @@ -890,7 +990,7 @@ public class QuorumStateTest { @ValueSource(shorts = {0, 1}) public void testVotedToUnattachedHigherEpoch(short kraftVersion) { int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(otherNodeId), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToVoted(5, ReplicaKey.of(otherNodeId, Optional.empty())); @@ -910,26 +1010,27 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void 
testVotedToAnyStateLowerEpoch(short kraftVersion) { - int otherNodeId = 1; - Optional otherNodeDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId); - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); state.transitionToVoted(5, otherNodeKey); assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4)); assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeKey)); - assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId)); + assertThrows( + IllegalStateException.class, + () -> state.transitionToFollower( + 4, + voters.voterNode(otherNodeKey.id(), VoterSetTest.DEFAULT_LISTENER_NAME).get() + ) + ); assertEquals(5, state.epoch()); assertEquals( Optional.of( ElectionState.withVotedCandidate( 5, - ReplicaKey.of( - otherNodeId, - persistedDirectoryId(otherNodeDirectoryId, kraftVersion) - ), - persistedVoters(voters, kraftVersion) + persistedVotedKey(otherNodeKey, kraftVersion), + persistedVoters(voters.voterIds(), kraftVersion) ) ), store.readElectionState() @@ -941,18 +1042,42 @@ public class QuorumStateTest { public void testFollowerToFollowerSameEpoch(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); - state.transitionToFollower(8, node2); - assertThrows(IllegalStateException.class, () -> state.transitionToFollower(8, node1)); - assertThrows(IllegalStateException.class, () -> state.transitionToFollower(8, node2)); + state.transitionToFollower( + 8, + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); + assertThrows( + IllegalStateException.class, + () -> state.transitionToFollower( + 8, + voters.voterNode(node1, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ) + ); + assertThrows( + IllegalStateException.class, + () -> state.transitionToFollower( + 8, + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ) + ); FollowerState followerState = state.followerStateOrThrow(); assertEquals(8, followerState.epoch()); - assertEquals(node2, followerState.leaderId()); assertEquals( - Optional.of(ElectionState.withElectedLeader(8, node2, persistedVoters(voters, kraftVersion))), + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME), + Optional.of(followerState.leader()) + ); + assertEquals( + Optional.of( + ElectionState.withElectedLeader( + 8, + node2, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -962,17 +1087,32 @@ public class QuorumStateTest { public void testFollowerToFollowerHigherEpoch(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); - state.transitionToFollower(8, node2); - state.transitionToFollower(9, node1); + state.transitionToFollower( + 8, + voters.voterNode(node2, 
VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); + state.transitionToFollower( + 9, + voters.voterNode(node1, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); FollowerState followerState = state.followerStateOrThrow(); assertEquals(9, followerState.epoch()); - assertEquals(node1, followerState.leaderId()); assertEquals( - Optional.of(ElectionState.withElectedLeader(9, node1, persistedVoters(voters, kraftVersion))), + voters.voterNode(node1, VoterSetTest.DEFAULT_LISTENER_NAME), + Optional.of(followerState.leader()) + ); + assertEquals( + Optional.of( + ElectionState.withElectedLeader( + 9, + node1, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -982,10 +1122,13 @@ public class QuorumStateTest { public void testFollowerToLeaderOrResigned(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); - state.transitionToFollower(8, node2); + state.transitionToFollower( + 8, + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0, accumulator)); assertThrows(IllegalStateException.class, () -> state.transitionToResigned(Collections.emptyList())); } @@ -995,10 +1138,13 @@ public class QuorumStateTest { public void testFollowerToCandidate(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); - state.transitionToFollower(8, node2); + state.transitionToFollower( + 8, + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); int jitterMs = 2500; random.mockNextInt(electionTimeoutMs, jitterMs); @@ -1015,10 +1161,13 @@ public class QuorumStateTest { public void testFollowerToUnattachedSameEpoch(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); - state.transitionToFollower(8, node2); + state.transitionToFollower( + 8, + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(8)); } @@ -1027,10 +1176,13 @@ public class QuorumStateTest { public void testFollowerToUnattachedHigherEpoch(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); - state.transitionToFollower(8, node2); + state.transitionToFollower( + 8, + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); int jitterMs = 2500; random.mockNextInt(electionTimeoutMs, jitterMs); @@ -1047,10 +1199,13 @@ public class QuorumStateTest { public void testFollowerToVotedSameEpoch(short kraftVersion) { int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + VoterSet voters = 
localWithRemoteVoterSet(IntStream.of(node1, node2), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); - state.transitionToFollower(8, node2); + state.transitionToFollower( + 8, + voters.voterNode(node2, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); assertThrows( IllegalStateException.class, @@ -1069,24 +1224,26 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testFollowerToVotedHigherEpoch(short kraftVersion) { - int node1 = 1; - Optional node1DirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey node1Key = ReplicaKey.of(node1, node1DirectoryId); - int node2 = 2; - Set voters = Utils.mkSet(localId, node1, node2); + ReplicaKey nodeKey1 = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + ReplicaKey nodeKey2 = ReplicaKey.of(2, Optional.of(Uuid.randomUuid())); + + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, nodeKey1, nodeKey2)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); - state.transitionToFollower(8, node2); + state.transitionToFollower( + 8, + voters.voterNode(nodeKey2.id(), VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); int jitterMs = 2500; random.mockNextInt(electionTimeoutMs, jitterMs); - state.transitionToVoted(9, node1Key); + state.transitionToVoted(9, nodeKey1); assertTrue(state.isVoted()); VotedState votedState = state.votedStateOrThrow(); assertEquals(9, votedState.epoch()); - assertEquals(node1Key, votedState.votedKey()); + assertEquals(nodeKey1, votedState.votedKey()); assertEquals(electionTimeoutMs + jitterMs, votedState.remainingElectionTimeMs(time.milliseconds())); @@ -1096,19 +1253,34 @@ public class QuorumStateTest { @ValueSource(shorts = {0, 1}) public void testFollowerToAnyStateLowerEpoch(short kraftVersion) { int otherNodeId = 1; - Set voters = Utils.mkSet(localId, otherNodeId); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(otherNodeId), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); - state.transitionToFollower(5, otherNodeId); + state.transitionToFollower( + 5, + voters.voterNode(otherNodeId, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4)); assertThrows( IllegalStateException.class, () -> state.transitionToVoted(4, ReplicaKey.of(otherNodeId, Optional.empty())) ); - assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId)); + assertThrows( + IllegalStateException.class, + () -> state.transitionToFollower( + 4, + voters.voterNode(otherNodeId, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ) + ); assertEquals(5, state.epoch()); assertEquals( - Optional.of(ElectionState.withElectedLeader(5, otherNodeId, persistedVoters(voters, kraftVersion))), + Optional.of( + ElectionState.withElectedLeader( + 5, + otherNodeId, + persistedVoters(voters.voterIds(), kraftVersion) + ) + ), store.readElectionState() ); } @@ -1117,10 +1289,8 @@ public class QuorumStateTest { @ValueSource(shorts = {0, 1}) public void testCanBecomeFollowerOfNonVoter(short kraftVersion) { int otherNodeId = 1; - int nonVoterId = 2; - Optional nonVoterDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey nonVoterKey = ReplicaKey.of(nonVoterId, nonVoterDirectoryId); - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey nonVoterKey = ReplicaKey.of(2, 
Optional.of(Uuid.randomUuid())); + VoterSet voters = localWithRemoteVoterSet(IntStream.of(otherNodeId), kraftVersion); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); @@ -1133,15 +1303,22 @@ public class QuorumStateTest { assertEquals(nonVoterKey, votedState.votedKey()); // Transition to follower - state.transitionToFollower(4, nonVoterId); - assertEquals(new LeaderAndEpoch(OptionalInt.of(nonVoterId), 4), state.leaderAndEpoch()); + Node nonVoterNode = new Node(nonVoterKey.id(), "non-voter-host", 1234); + state.transitionToFollower(4, nonVoterNode); + assertEquals( + new LeaderAndEpoch(OptionalInt.of(nonVoterKey.id()), 4), + state.leaderAndEpoch() + ); } @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testObserverCannotBecomeCandidateOrLeader(short kraftVersion) { + boolean withDirectoryId = kraftVersion > 0; int otherNodeId = 1; - Set voters = Utils.mkSet(otherNodeId); + VoterSet voters = VoterSetTest.voterSet( + VoterSetTest.voterMap(IntStream.of(otherNodeId), withDirectoryId) + ); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); assertTrue(state.isObserver()); @@ -1152,10 +1329,8 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testObserverWithIdCanVote(short kraftVersion) { - int otherNodeId = 1; - Optional otherNodeDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId); - Set voters = Utils.mkSet(otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); @@ -1172,14 +1347,20 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testObserverFollowerToUnattached(short kraftVersion) { + boolean withDirectoryId = kraftVersion > 0; int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(node1, node2); + VoterSet voters = VoterSetTest.voterSet( + VoterSetTest.voterMap(IntStream.of(node1, node2), withDirectoryId) + ); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); assertTrue(state.isObserver()); - state.transitionToFollower(2, node1); + state.transitionToFollower( + 2, + voters.voterNode(node1, VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); state.transitionToUnattached(3); assertTrue(state.isUnattached()); UnattachedState unattachedState = state.unattachedStateOrThrow(); @@ -1192,19 +1373,25 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testObserverUnattachedToFollower(short kraftVersion) { + boolean withDirectoryId = kraftVersion > 0; int node1 = 1; int node2 = 2; - Set voters = Utils.mkSet(node1, node2); + VoterSet voters = VoterSetTest.voterSet( + VoterSetTest.voterMap(IntStream.of(node1, node2), withDirectoryId) + ); QuorumState state = initializeEmptyState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); assertTrue(state.isObserver()); state.transitionToUnattached(2); - state.transitionToFollower(3, node1); + state.transitionToFollower(3, voters.voterNode(node1, VoterSetTest.DEFAULT_LISTENER_NAME).get()); assertTrue(state.isFollower()); FollowerState followerState = state.followerStateOrThrow(); assertEquals(3, 
followerState.epoch()); - assertEquals(node1, followerState.leaderId()); + assertEquals( + voters.voterNode(node1, VoterSetTest.DEFAULT_LISTENER_NAME), + Optional.of(followerState.leader()) + ); assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds())); } @@ -1214,7 +1401,11 @@ public class QuorumStateTest { QuorumStateStore stateStore = Mockito.mock(QuorumStateStore.class); Mockito.doThrow(UncheckedIOException.class).when(stateStore).readElectionState(); - QuorumState state = buildQuorumState(Utils.mkSet(localId), kraftVersion); + QuorumState state = buildQuorumState( + OptionalInt.of(localId), + localStandaloneVoterSet(), + kraftVersion + ); int epoch = 2; state.initialize(new OffsetAndEpoch(0L, epoch)); @@ -1226,10 +1417,8 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testHasRemoteLeader(short kraftVersion) { - int otherNodeId = 1; - Optional otherNodeDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId); - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); assertFalse(state.hasRemoteLeader()); @@ -1237,7 +1426,7 @@ public class QuorumStateTest { state.transitionToCandidate(); assertFalse(state.hasRemoteLeader()); - state.candidateStateOrThrow().recordGrantedVote(otherNodeId); + state.candidateStateOrThrow().recordGrantedVote(otherNodeKey.id()); state.transitionToLeader(0L, accumulator); assertFalse(state.hasRemoteLeader()); @@ -1247,20 +1436,24 @@ public class QuorumStateTest { state.transitionToVoted(state.epoch() + 1, otherNodeKey); assertFalse(state.hasRemoteLeader()); - state.transitionToFollower(state.epoch() + 1, otherNodeId); + state.transitionToFollower( + state.epoch() + 1, + voters.voterNode(otherNodeKey.id(), VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); assertTrue(state.hasRemoteLeader()); } @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testHighWatermarkRetained(short kraftVersion) { - int otherNodeId = 1; - Optional otherNodeDirectoryId = Optional.of(Uuid.randomUuid()); - ReplicaKey otherNodeKey = ReplicaKey.of(otherNodeId, otherNodeDirectoryId); - Set voters = Utils.mkSet(localId, otherNodeId); + ReplicaKey otherNodeKey = ReplicaKey.of(1, Optional.of(Uuid.randomUuid())); + VoterSet voters = VoterSetTest.voterSet(Stream.of(localVoterKey, otherNodeKey)); QuorumState state = initializeEmptyState(voters, kraftVersion); - state.transitionToFollower(5, otherNodeId); + state.transitionToFollower( + 5, + voters.voterNode(otherNodeKey.id(), VoterSetTest.DEFAULT_LISTENER_NAME).get() + ); FollowerState followerState = state.followerStateOrThrow(); followerState.updateHighWatermark(OptionalLong.of(10L)); @@ -1278,7 +1471,7 @@ public class QuorumStateTest { assertEquals(highWatermark, state.highWatermark()); CandidateState candidateState = state.candidateStateOrThrow(); - candidateState.recordGrantedVote(otherNodeId); + candidateState.recordGrantedVote(otherNodeKey.id()); assertTrue(candidateState.isVoteGranted()); state.transitionToLeader(10L, accumulator); @@ -1288,7 +1481,11 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testInitializeWithEmptyLocalId(short kraftVersion) { - QuorumState state = buildQuorumState(OptionalInt.empty(), Utils.mkSet(0, 1), 
kraftVersion); + boolean withDirectoryId = kraftVersion > 0; + VoterSet voters = VoterSetTest.voterSet( + VoterSetTest.voterMap(IntStream.of(0, 1), withDirectoryId) + ); + QuorumState state = buildQuorumState(OptionalInt.empty(), voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, 0)); assertTrue(state.isObserver()); @@ -1301,7 +1498,7 @@ public class QuorumStateTest { ); assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0L, accumulator)); - state.transitionToFollower(1, 1); + state.transitionToFollower(1, voters.voterNode(1, VoterSetTest.DEFAULT_LISTENER_NAME).get()); assertTrue(state.isFollower()); state.transitionToUnattached(2); @@ -1311,15 +1508,18 @@ public class QuorumStateTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void testNoLocalIdInitializationFailsIfElectionStateHasVotedCandidate(short kraftVersion) { + boolean withDirectoryId = kraftVersion > 0; int epoch = 5; int votedId = 1; - Set voters = Utils.mkSet(0, votedId); + VoterSet voters = VoterSetTest.voterSet( + VoterSetTest.voterMap(IntStream.of(0, votedId), withDirectoryId) + ); store.writeElectionState( ElectionState.withVotedCandidate( epoch, ReplicaKey.of(votedId, Optional.empty()), - voters + voters.voterIds() ), kraftVersion ); @@ -1327,27 +1527,4 @@ public class QuorumStateTest { QuorumState state2 = buildQuorumState(OptionalInt.empty(), voters, kraftVersion); assertThrows(IllegalStateException.class, () -> state2.initialize(new OffsetAndEpoch(0, 0))); } - - private QuorumState initializeEmptyState(Set voters, short kraftVersion) { - QuorumState state = buildQuorumState(voters, kraftVersion); - store.writeElectionState(ElectionState.withUnknownLeader(0, voters), kraftVersion); - state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); - return state; - } - - private Set persistedVoters(Set voters, short kraftVersion) { - if (kraftVersion == 1) { - return Collections.emptySet(); - } - - return voters; - } - - private Optional persistedDirectoryId(Optional directoryId, short kraftVersion) { - if (kraftVersion == 1) { - return directoryId; - } - - return Optional.empty(); - } } diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java index 03ab95ffce6..8d6b9c1cad9 100644 --- a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java +++ b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java @@ -16,6 +16,7 @@ */ package org.apache.kafka.raft; +import org.apache.kafka.common.Node; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.compress.Compression; @@ -79,6 +80,7 @@ import java.util.Set; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.apache.kafka.raft.LeaderState.CHECK_QUORUM_TIMEOUT_FACTOR; import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition; @@ -114,6 +116,7 @@ public final class RaftClientTestContext { final MockTime time; final MockListener listener; final Set voters; + final Set bootstrapIds; private final List sentResponses = new ArrayList<>(); @@ -146,6 +149,7 @@ public final class RaftClientTestContext { private int electionTimeoutMs = DEFAULT_ELECTION_TIMEOUT_MS; private int appendLingerMs = DEFAULT_APPEND_LINGER_MS; private MemoryPool memoryPool = MemoryPool.NONE; + private List bootstrapServers = Collections.emptyList(); public Builder(int localId, Set 
voters) { this(OptionalInt.of(localId), voters); @@ -240,9 +244,14 @@ public final class RaftClientTestContext { return this; } + Builder withBootstrapServers(List bootstrapServers) { + this.bootstrapServers = bootstrapServers; + return this; + } + public RaftClientTestContext build() throws IOException { Metrics metrics = new Metrics(time); - MockNetworkChannel channel = new MockNetworkChannel(voters); + MockNetworkChannel channel = new MockNetworkChannel(); MockListener listener = new MockListener(localId); Map voterAddressMap = voters .stream() @@ -269,6 +278,7 @@ public final class RaftClientTestContext { new MockExpirationService(time), FETCH_MAX_WAIT_MS, clusterId.toString(), + bootstrapServers, logContext, random, quorumConfig @@ -277,7 +287,6 @@ public final class RaftClientTestContext { client.register(listener); client.initialize( voterAddressMap, - "CONTROLLER", quorumStateStore, metrics ); @@ -292,6 +301,11 @@ public final class RaftClientTestContext { time, quorumStateStore, voters, + IntStream + .iterate(-2, id -> id - 1) + .limit(bootstrapServers.size()) + .boxed() + .collect(Collectors.toSet()), metrics, listener ); @@ -314,6 +328,7 @@ public final class RaftClientTestContext { MockTime time, QuorumStateStore quorumStateStore, Set voters, + Set bootstrapIds, Metrics metrics, MockListener listener ) { @@ -326,6 +341,7 @@ public final class RaftClientTestContext { this.time = time; this.quorumStateStore = quorumStateStore; this.voters = voters; + this.bootstrapIds = bootstrapIds; this.metrics = metrics; this.listener = listener; } @@ -417,7 +433,7 @@ public final class RaftClientTestContext { for (RaftRequest.Outbound request : voteRequests) { VoteResponseData voteResponse = voteResponse(true, Optional.empty(), epoch); - deliverResponse(request.correlationId, request.destinationId(), voteResponse); + deliverResponse(request.correlationId(), request.destination(), voteResponse); } client.poll(); @@ -432,7 +448,7 @@ public final class RaftClientTestContext { pollUntilRequest(); for (RaftRequest.Outbound request : collectBeginEpochRequests(epoch)) { BeginQuorumEpochResponseData beginEpochResponse = beginEpochResponse(epoch, localIdOrThrow()); - deliverResponse(request.correlationId, request.destinationId(), beginEpochResponse); + deliverResponse(request.correlationId(), request.destination(), beginEpochResponse); } client.poll(); } @@ -519,10 +535,10 @@ public final class RaftClientTestContext { assertEquals(expectedResponse, response); } - int assertSentVoteRequest(int epoch, int lastEpoch, long lastEpochOffset, int numVoteReceivers) { + RaftRequest.Outbound assertSentVoteRequest(int epoch, int lastEpoch, long lastEpochOffset, int numVoteReceivers) { List voteRequests = collectVoteRequests(epoch, lastEpoch, lastEpochOffset); assertEquals(numVoteReceivers, voteRequests.size()); - return voteRequests.iterator().next().correlationId(); + return voteRequests.iterator().next(); } void assertSentVoteResponse(Errors error) { @@ -590,14 +606,14 @@ public final class RaftClientTestContext { client.handle(inboundRequest); } - void deliverResponse(int correlationId, int sourceId, ApiMessage response) { - channel.mockReceive(new RaftResponse.Inbound(correlationId, response, sourceId)); + void deliverResponse(int correlationId, Node source, ApiMessage response) { + channel.mockReceive(new RaftResponse.Inbound(correlationId, response, source)); } - int assertSentBeginQuorumEpochRequest(int epoch, int numBeginEpochRequests) { + RaftRequest.Outbound assertSentBeginQuorumEpochRequest(int epoch, 
int numBeginEpochRequests) { List requests = collectBeginEpochRequests(epoch); assertEquals(numBeginEpochRequests, requests.size()); - return requests.get(0).correlationId; + return requests.get(0); } private List drainSentResponses( @@ -607,7 +623,7 @@ public final class RaftClientTestContext { Iterator iterator = sentResponses.iterator(); while (iterator.hasNext()) { RaftResponse.Outbound response = iterator.next(); - if (response.data.apiKey() == apiKey.id) { + if (response.data().apiKey() == apiKey.id) { res.add(response); iterator.remove(); } @@ -646,11 +662,14 @@ public final class RaftClientTestContext { assertEquals(partitionError, Errors.forCode(partitionResponse.errorCode())); } - int assertSentEndQuorumEpochRequest(int epoch, int destinationId) { + RaftRequest.Outbound assertSentEndQuorumEpochRequest(int epoch, int destinationId) { List endQuorumRequests = collectEndQuorumRequests( - epoch, Collections.singleton(destinationId), Optional.empty()); + epoch, + Collections.singleton(destinationId), + Optional.empty() + ); assertEquals(1, endQuorumRequests.size()); - return endQuorumRequests.get(0).correlationId(); + return endQuorumRequests.get(0); } void assertSentEndQuorumEpochResponse( @@ -690,7 +709,7 @@ public final class RaftClientTestContext { return sentRequests.get(0); } - int assertSentFetchRequest( + RaftRequest.Outbound assertSentFetchRequest( int epoch, long fetchOffset, int lastFetchedEpoch @@ -700,7 +719,7 @@ public final class RaftClientTestContext { RaftRequest.Outbound raftRequest = sentMessages.get(0); assertFetchRequestData(raftRequest, epoch, fetchOffset, lastFetchedEpoch); - return raftRequest.correlationId(); + return raftRequest; } FetchResponseData.PartitionData assertSentFetchPartitionResponse() { @@ -708,7 +727,7 @@ public final class RaftClientTestContext { assertEquals( 1, sentMessages.size(), "Found unexpected sent messages " + sentMessages); RaftResponse.Outbound raftMessage = sentMessages.get(0); - assertEquals(ApiKeys.FETCH.id, raftMessage.data.apiKey()); + assertEquals(ApiKeys.FETCH.id, raftMessage.data().apiKey()); FetchResponseData response = (FetchResponseData) raftMessage.data(); assertEquals(Errors.NONE, Errors.forCode(response.errorCode())); @@ -723,7 +742,7 @@ public final class RaftClientTestContext { assertEquals( 1, sentMessages.size(), "Found unexpected sent messages " + sentMessages); RaftResponse.Outbound raftMessage = sentMessages.get(0); - assertEquals(ApiKeys.FETCH.id, raftMessage.data.apiKey()); + assertEquals(ApiKeys.FETCH.id, raftMessage.data().apiKey()); FetchResponseData response = (FetchResponseData) raftMessage.data(); assertEquals(topLevelError, Errors.forCode(response.errorCode())); } @@ -811,7 +830,7 @@ public final class RaftClientTestContext { assertEquals(preferredSuccessors, partitionRequest.preferredSuccessors()); }); - collectedDestinationIdSet.add(raftMessage.destinationId()); + collectedDestinationIdSet.add(raftMessage.destination().id()); endQuorumRequests.add(raftMessage); } } @@ -825,11 +844,18 @@ public final class RaftClientTestContext { ) throws Exception { pollUntilRequest(); RaftRequest.Outbound fetchRequest = assertSentFetchRequest(); - assertTrue(voters.contains(fetchRequest.destinationId())); + int destinationId = fetchRequest.destination().id(); + assertTrue( + voters.contains(destinationId) || bootstrapIds.contains(destinationId), + String.format("id %d is not in sets %s or %s", destinationId, voters, bootstrapIds) + ); assertFetchRequestData(fetchRequest, 0, 0L, 0); - 
deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), - fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.NONE)); + deliverResponse( + fetchRequest.correlationId(), + fetchRequest.destination(), + fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.NONE) + ); client.poll(); assertElectedLeader(epoch, leaderId); } @@ -850,7 +876,7 @@ public final class RaftClientTestContext { return requests; } - private static InetSocketAddress mockAddress(int id) { + public static InetSocketAddress mockAddress(int id) { return new InetSocketAddress("localhost", 9990 + id); } diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java index f52ee371f48..4896571c22e 100644 --- a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java @@ -21,6 +21,7 @@ import net.jqwik.api.ForAll; import net.jqwik.api.Property; import net.jqwik.api.Tag; import net.jqwik.api.constraints.IntRange; +import org.apache.kafka.common.Node; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.memory.MemoryPool; @@ -45,6 +46,7 @@ import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -59,7 +61,6 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; -import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -189,7 +190,7 @@ public class RaftEventSimulationTest { // they are able to elect a leader and continue making progress cluster.killAll(); - Iterator nodeIdsIterator = cluster.nodes().iterator(); + Iterator nodeIdsIterator = cluster.nodeIds().iterator(); for (int i = 0; i < cluster.majoritySize(); i++) { Integer nodeId = nodeIdsIterator.next(); cluster.start(nodeId); @@ -224,7 +225,7 @@ public class RaftEventSimulationTest { ); router.filter(leaderId, new DropAllTraffic()); - Set nonPartitionedNodes = new HashSet<>(cluster.nodes()); + Set nonPartitionedNodes = new HashSet<>(cluster.nodeIds()); nonPartitionedNodes.remove(leaderId); scheduler.runUntil(() -> cluster.allReachedHighWatermark(20, nonPartitionedNodes)); @@ -252,11 +253,17 @@ public class RaftEventSimulationTest { // Partition the nodes into two sets. Nodes are reachable within each set, // but the two sets cannot communicate with each other. We should be able // to make progress even if an election is needed in the larger set. 
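The hunk below swaps the id-based DropOutboundRequestsFrom filter for the address-based DropOutboundRequestsTo. As a minimal stand-alone sketch of that idea (the class and method names here are hypothetical and are not the simulation's actual NetworkFilter API), delivery can be decided by comparing a request's destination host and port against a set of unreachable InetSocketAddress values, with the node id never entering the comparison:

    import java.net.InetSocketAddress;
    import java.util.Set;

    // Hypothetical, illustrative class; not part of the patch.
    final class DropByAddressSketch {
        private final Set<InetSocketAddress> unreachable;

        DropByAddressSketch(Set<InetSocketAddress> unreachable) {
            this.unreachable = unreachable;
        }

        // True if a request addressed to (host, port) should be delivered.
        // Only the endpoint is compared; node ids are ignored.
        boolean acceptOutbound(String host, int port) {
            return !unreachable.contains(InetSocketAddress.createUnresolved(host, port));
        }

        public static void main(String[] args) {
            DropByAddressSketch filter = new DropByAddressSketch(Set.of(
                InetSocketAddress.createUnresolved("host-node-2", 1234),
                InetSocketAddress.createUnresolved("host-node-3", 1234)
            ));
            System.out.println(filter.acceptOutbound("host-node-1", 1234)); // true: still reachable
            System.out.println(filter.acceptOutbound("host-node-2", 1234)); // false: partitioned away
        }
    }

Since unresolved InetSocketAddress instances compare by hostname and port only, building both sides of the check with InetSocketAddress.createUnresolved, as the patch does, keeps the comparison independent of DNS resolution.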
- router.filter(0, new DropOutboundRequestsFrom(Utils.mkSet(2, 3, 4))); - router.filter(1, new DropOutboundRequestsFrom(Utils.mkSet(2, 3, 4))); - router.filter(2, new DropOutboundRequestsFrom(Utils.mkSet(0, 1))); - router.filter(3, new DropOutboundRequestsFrom(Utils.mkSet(0, 1))); - router.filter(4, new DropOutboundRequestsFrom(Utils.mkSet(0, 1))); + router.filter( + 0, + new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(2, 3, 4))) + ); + router.filter( + 1, + new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(2, 3, 4))) + ); + router.filter(2, new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(0, 1)))); + router.filter(3, new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(0, 1)))); + router.filter(4, new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(0, 1)))); long partitionLogEndOffset = cluster.maxLogEndOffset(); scheduler.runUntil(() -> cluster.anyReachedHighWatermark(2 * partitionLogEndOffset)); @@ -374,7 +381,7 @@ public class RaftEventSimulationTest { int pollIntervalMs, int pollJitterMs) { int delayMs = 0; - for (int nodeId : cluster.nodes()) { + for (int nodeId : cluster.nodeIds()) { scheduler.schedule(() -> cluster.pollIfRunning(nodeId), delayMs, pollIntervalMs, pollJitterMs); delayMs++; } @@ -527,25 +534,37 @@ public class RaftEventSimulationTest { final AtomicInteger correlationIdCounter = new AtomicInteger(); final MockTime time = new MockTime(); final Uuid clusterId = Uuid.randomUuid(); - final Set voters = new HashSet<>(); + final Map voters = new HashMap<>(); final Map nodes = new HashMap<>(); final Map running = new HashMap<>(); private Cluster(int numVoters, int numObservers, Random random) { this.random = random; - int nodeId = 0; - for (; nodeId < numVoters; nodeId++) { - voters.add(nodeId); + for (int nodeId = 0; nodeId < numVoters; nodeId++) { + voters.put( + nodeId, + new Node(nodeId, String.format("host-node-%d", nodeId), 1234) + ); nodes.put(nodeId, new PersistentState(nodeId)); } - for (; nodeId < numVoters + numObservers; nodeId++) { + for (int nodeIdDelta = 0; nodeIdDelta < numObservers; nodeIdDelta++) { + int nodeId = numVoters + nodeIdDelta; nodes.put(nodeId, new PersistentState(nodeId)); } } - Set nodes() { + Set endpointsFromIds(Set nodeIds) { + return voters + .values() + .stream() + .filter(node -> nodeIds.contains(node.id())) + .map(Cluster::nodeAddress) + .collect(Collectors.toSet()); + } + + Set nodeIds() { return nodes.keySet(); } @@ -710,18 +729,19 @@ public class RaftEventSimulationTest { nodes.put(nodeId, new PersistentState(nodeId)); } - private static InetSocketAddress nodeAddress(int id) { - return new InetSocketAddress("localhost", 9990 + id); + private static InetSocketAddress nodeAddress(Node node) { + return InetSocketAddress.createUnresolved(node.host(), node.port()); } void start(int nodeId) { LogContext logContext = new LogContext("[Node " + nodeId + "] "); PersistentState persistentState = nodes.get(nodeId); - MockNetworkChannel channel = new MockNetworkChannel(correlationIdCounter, voters); + MockNetworkChannel channel = new MockNetworkChannel(correlationIdCounter); MockMessageQueue messageQueue = new MockMessageQueue(); Map voterAddressMap = voters + .values() .stream() - .collect(Collectors.toMap(Function.identity(), Cluster::nodeAddress)); + .collect(Collectors.toMap(Node::id, Cluster::nodeAddress)); QuorumConfig quorumConfig = new QuorumConfig( REQUEST_TIMEOUT_MS, @@ -750,6 +770,7 @@ public class RaftEventSimulationTest { new MockExpirationService(time), 
FETCH_MAX_WAIT_MS, clusterId.toString(), + Collections.emptyList(), logContext, random, quorumConfig @@ -808,7 +829,6 @@ public class RaftEventSimulationTest { client.register(counter); client.initialize( voterAddresses, - "CONTROLLER", store, metrics ); @@ -847,9 +867,11 @@ private static class InflightRequest { final int sourceId; + final Node destination; - private InflightRequest(int sourceId) { + private InflightRequest(int sourceId, Node destination) { this.sourceId = sourceId; + this.destination = destination; } } @@ -884,11 +906,15 @@ } } - private static class DropOutboundRequestsFrom implements NetworkFilter { + private static class DropOutboundRequestsTo implements NetworkFilter { + private final Set unreachable; - private final Set unreachable; - - private DropOutboundRequestsFrom(Set unreachable) { + /** + * This network filter drops any outbound message sent to the {@code unreachable} nodes. + * + * @param unreachable the set of destination addresses which are not reachable + */ + private DropOutboundRequestsTo(Set unreachable) { this.unreachable = unreachable; } @@ -897,11 +923,25 @@ return true; } + /** + * Returns whether the message should be sent to the destination. + * + * Returns false when an outbound request message contains a destination {@code Node} that + * matches the set of unreachable {@code InetSocketAddress}. Note that the {@code Node.id()} + * and {@code Node.rack()} are not compared. + * + * @param message the raft message + * @return true if the message should be delivered, otherwise false + */ @Override public boolean acceptOutbound(RaftMessage message) { if (message instanceof RaftRequest.Outbound) { RaftRequest.Outbound request = (RaftRequest.Outbound) message; - return !unreachable.contains(request.destinationId()); + InetSocketAddress destination = InetSocketAddress.createUnresolved( + request.destination().host(), + request.destination().port() + ); + return !unreachable.contains(destination); } return true; } @@ -955,7 +995,7 @@ public void verify() { cluster.leaderHighWatermark().ifPresent(highWatermark -> { long numReachedHighWatermark = cluster.nodes.entrySet().stream() - .filter(entry -> cluster.voters.contains(entry.getKey())) + .filter(entry -> cluster.voters.containsKey(entry.getKey())) .filter(entry -> entry.getValue().log.endOffset().offset >= highWatermark) .count(); assertTrue( @@ -1194,19 +1234,19 @@ return; int correlationId = outbound.correlationId(); - int destinationId = outbound.destinationId(); + Node destination = outbound.destination(); RaftRequest.Inbound inbound = new RaftRequest.Inbound(correlationId, outbound.data(), cluster.time.milliseconds()); - if (!filters.get(destinationId).acceptInbound(inbound)) + if (!filters.get(destination.id()).acceptInbound(inbound)) return; - cluster.nodeIfRunning(destinationId).ifPresent(node -> { - inflight.put(correlationId, new InflightRequest(senderId)); + cluster.nodeIfRunning(destination.id()).ifPresent(node -> { + inflight.put(correlationId, new InflightRequest(senderId, destination)); inbound.completion.whenComplete((response, exception) -> { - if (response != null && filters.get(destinationId).acceptOutbound(response)) { - deliver(destinationId, response); + if (response != null && filters.get(destination.id()).acceptOutbound(response)) { + deliver(response); } }); @@ -1214,11 +1254,17 @@ public class
RaftEventSimulationTest { }); } - void deliver(int senderId, RaftResponse.Outbound outbound) { + void deliver(RaftResponse.Outbound outbound) { int correlationId = outbound.correlationId(); - RaftResponse.Inbound inbound = new RaftResponse.Inbound(correlationId, outbound.data(), senderId); InflightRequest inflightRequest = inflight.remove(correlationId); + RaftResponse.Inbound inbound = new RaftResponse.Inbound( + correlationId, + outbound.data(), + // The source of the response is the destination of the request + inflightRequest.destination + ); + if (!filters.get(inflightRequest.sourceId).acceptInbound(inbound)) return; diff --git a/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java b/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java index e6e2f7cf0a6..326222e18c4 100644 --- a/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java @@ -16,14 +16,20 @@ */ package org.apache.kafka.raft; +import org.apache.kafka.common.Node; import org.apache.kafka.common.utils.MockTime; -import org.apache.kafka.common.utils.Utils; import org.junit.jupiter.api.Test; +import java.util.List; +import java.util.Optional; import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotEquals; public class RequestManagerTest { private final MockTime time = new MockTime(); @@ -33,105 +39,247 @@ public class RequestManagerTest { @Test public void testResetAllConnections() { + Node node1 = new Node(1, "mock-host-1", 4321); + Node node2 = new Node(2, "mock-host-2", 4321); + RequestManager cache = new RequestManager( - Utils.mkSet(1, 2, 3), + makeBootstrapList(3), retryBackoffMs, requestTimeoutMs, - random); + random + ); // One host has an inflight request - RequestManager.ConnectionState connectionState1 = cache.getOrCreate(1); - connectionState1.onRequestSent(1, time.milliseconds()); - assertFalse(connectionState1.isReady(time.milliseconds())); + cache.onRequestSent(node1, 1, time.milliseconds()); + assertFalse(cache.isReady(node1, time.milliseconds())); // Another is backing off - RequestManager.ConnectionState connectionState2 = cache.getOrCreate(2); - connectionState2.onRequestSent(2, time.milliseconds()); - connectionState2.onResponseError(2, time.milliseconds()); - assertFalse(connectionState2.isReady(time.milliseconds())); + cache.onRequestSent(node2, 2, time.milliseconds()); + cache.onResponseResult(node2, 2, false, time.milliseconds()); + assertFalse(cache.isReady(node2, time.milliseconds())); cache.resetAll(); // Now both should be ready - assertTrue(connectionState1.isReady(time.milliseconds())); - assertTrue(connectionState2.isReady(time.milliseconds())); + assertTrue(cache.isReady(node1, time.milliseconds())); + assertTrue(cache.isReady(node2, time.milliseconds())); } @Test public void testBackoffAfterFailure() { + Node node = new Node(1, "mock-host-1", 4321); + RequestManager cache = new RequestManager( - Utils.mkSet(1, 2, 3), + makeBootstrapList(3), retryBackoffMs, requestTimeoutMs, - random); + random + ); - RequestManager.ConnectionState connectionState = cache.getOrCreate(1); - assertTrue(connectionState.isReady(time.milliseconds())); + assertTrue(cache.isReady(node, time.milliseconds())); long correlationId = 1; - 
connectionState.onRequestSent(correlationId, time.milliseconds()); - assertFalse(connectionState.isReady(time.milliseconds())); + cache.onRequestSent(node, correlationId, time.milliseconds()); + assertFalse(cache.isReady(node, time.milliseconds())); - connectionState.onResponseError(correlationId, time.milliseconds()); - assertFalse(connectionState.isReady(time.milliseconds())); + cache.onResponseResult(node, correlationId, false, time.milliseconds()); + assertFalse(cache.isReady(node, time.milliseconds())); time.sleep(retryBackoffMs); - assertTrue(connectionState.isReady(time.milliseconds())); + assertTrue(cache.isReady(node, time.milliseconds())); } @Test public void testSuccessfulResponse() { + Node node = new Node(1, "mock-host-1", 4321); + RequestManager cache = new RequestManager( - Utils.mkSet(1, 2, 3), + makeBootstrapList(3), retryBackoffMs, requestTimeoutMs, - random); - - RequestManager.ConnectionState connectionState = cache.getOrCreate(1); + random + ); long correlationId = 1; - connectionState.onRequestSent(correlationId, time.milliseconds()); - assertFalse(connectionState.isReady(time.milliseconds())); - connectionState.onResponseReceived(correlationId); - assertTrue(connectionState.isReady(time.milliseconds())); + cache.onRequestSent(node, correlationId, time.milliseconds()); + assertFalse(cache.isReady(node, time.milliseconds())); + cache.onResponseResult(node, correlationId, true, time.milliseconds()); + assertTrue(cache.isReady(node, time.milliseconds())); } @Test public void testIgnoreUnexpectedResponse() { + Node node = new Node(1, "mock-host-1", 4321); + RequestManager cache = new RequestManager( - Utils.mkSet(1, 2, 3), + makeBootstrapList(3), retryBackoffMs, requestTimeoutMs, - random); - - RequestManager.ConnectionState connectionState = cache.getOrCreate(1); + random + ); long correlationId = 1; - connectionState.onRequestSent(correlationId, time.milliseconds()); - assertFalse(connectionState.isReady(time.milliseconds())); - connectionState.onResponseReceived(correlationId + 1); - assertFalse(connectionState.isReady(time.milliseconds())); + cache.onRequestSent(node, correlationId, time.milliseconds()); + assertFalse(cache.isReady(node, time.milliseconds())); + cache.onResponseResult(node, correlationId + 1, true, time.milliseconds()); + assertFalse(cache.isReady(node, time.milliseconds())); } @Test public void testRequestTimeout() { + Node node = new Node(1, "mock-host-1", 4321); + RequestManager cache = new RequestManager( - Utils.mkSet(1, 2, 3), + makeBootstrapList(3), retryBackoffMs, requestTimeoutMs, - random); - - RequestManager.ConnectionState connectionState = cache.getOrCreate(1); + random + ); long correlationId = 1; - connectionState.onRequestSent(correlationId, time.milliseconds()); - assertFalse(connectionState.isReady(time.milliseconds())); + cache.onRequestSent(node, correlationId, time.milliseconds()); + assertFalse(cache.isReady(node, time.milliseconds())); time.sleep(requestTimeoutMs - 1); - assertFalse(connectionState.isReady(time.milliseconds())); + assertFalse(cache.isReady(node, time.milliseconds())); time.sleep(1); - assertTrue(connectionState.isReady(time.milliseconds())); + assertTrue(cache.isReady(node, time.milliseconds())); } + @Test + public void testRequestToBootstrapList() { + List bootstrapList = makeBootstrapList(2); + RequestManager cache = new RequestManager( + bootstrapList, + retryBackoffMs, + requestTimeoutMs, + random + ); + + // Find a ready node with the starting state + Node bootstrapNode1 = 
cache.findReadyBootstrapServer(time.milliseconds()).get(); + assertTrue( + bootstrapList.contains(bootstrapNode1), + String.format("%s is not in %s", bootstrapNode1, bootstrapList) + ); + assertEquals(0, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds())); + + // Send a request and check the cache state + cache.onRequestSent(bootstrapNode1, 1, time.milliseconds()); + assertEquals( + Optional.empty(), + cache.findReadyBootstrapServer(time.milliseconds()) + ); + assertEquals(requestTimeoutMs, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds())); + + // Fail the request + time.sleep(100); + cache.onResponseResult(bootstrapNode1, 1, false, time.milliseconds()); + Node bootstrapNode2 = cache.findReadyBootstrapServer(time.milliseconds()).get(); + assertNotEquals(bootstrapNode1, bootstrapNode2); + assertEquals(0, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds())); + + // Send a request to the second node and check the state + cache.onRequestSent(bootstrapNode2, 2, time.milliseconds()); + assertEquals( + Optional.empty(), + cache.findReadyBootstrapServer(time.milliseconds()) + ); + assertEquals(requestTimeoutMs, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds())); + + + // Fail the second request before the request timeout + time.sleep(retryBackoffMs - 1); + cache.onResponseResult(bootstrapNode2, 2, false, time.milliseconds()); + assertEquals( + Optional.empty(), + cache.findReadyBootstrapServer(time.milliseconds()) + ); + assertEquals(1, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds())); + + // Timeout the first backoff and show that that node is ready + time.sleep(1); + Node bootstrapNode3 = cache.findReadyBootstrapServer(time.milliseconds()).get(); + assertEquals(bootstrapNode1, bootstrapNode3); + assertEquals(0, cache.backoffBeforeAvailableBootstrapServer(time.milliseconds())); + } + + @Test + public void testFindReadyWithInflightRequest() { + Node otherNode = new Node(1, "other-node", 1234); + List bootstrapList = makeBootstrapList(3); + RequestManager cache = new RequestManager( + bootstrapList, + retryBackoffMs, + requestTimeoutMs, + random + ); + + // Send request to a node that is not in the bootstrap list + cache.onRequestSent(otherNode, 1, time.milliseconds()); + assertEquals(Optional.empty(), cache.findReadyBootstrapServer(time.milliseconds())); + } + + @Test + public void testFindReadyWithRequestTimedout() { + Node otherNode = new Node(1, "other-node", 1234); + List bootstrapList = makeBootstrapList(3); + RequestManager cache = new RequestManager( + bootstrapList, + retryBackoffMs, + requestTimeoutMs, + random + ); + + // Send request to a node that is not in the bootstrap list + cache.onRequestSent(otherNode, 1, time.milliseconds()); + assertTrue(cache.isResponseExpected(otherNode, 1)); + assertEquals(Optional.empty(), cache.findReadyBootstrapServer(time.milliseconds())); + + // Timeout the request + time.sleep(requestTimeoutMs); + Node bootstrapNode = cache.findReadyBootstrapServer(time.milliseconds()).get(); + assertTrue(bootstrapList.contains(bootstrapNode)); + assertFalse(cache.isResponseExpected(otherNode, 1)); + } + + @Test + public void testAnyInflightRequestWithAnyRequest() { + Node otherNode = new Node(1, "other-node", 1234); + List bootstrapList = makeBootstrapList(3); + RequestManager cache = new RequestManager( + bootstrapList, + retryBackoffMs, + requestTimeoutMs, + random + ); + + assertFalse(cache.hasAnyInflightRequest(time.milliseconds())); + + // Send a request and check state + 
cache.onRequestSent(otherNode, 11, time.milliseconds()); + assertTrue(cache.hasAnyInflightRequest(time.milliseconds())); + + // Wait until the request times out + time.sleep(requestTimeoutMs); + assertFalse(cache.hasAnyInflightRequest(time.milliseconds())); + + // Send another request and fail it + cache.onRequestSent(otherNode, 12, time.milliseconds()); + cache.onResponseResult(otherNode, 12, false, time.milliseconds()); + assertFalse(cache.hasAnyInflightRequest(time.milliseconds())); + + // Send another request and mark it successful + cache.onRequestSent(otherNode, 12, time.milliseconds()); + cache.onResponseResult(otherNode, 12, true, time.milliseconds()); + assertFalse(cache.hasAnyInflightRequest(time.milliseconds())); + } + + private List makeBootstrapList(int numberOfNodes) { + return IntStream.iterate(-2, id -> id - 1) + .limit(numberOfNodes) + .mapToObj(id -> new Node(id, String.format("mock-boot-host%d", id), 1234)) + .collect(Collectors.toList()); + } } diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachineTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachineTest.java index 80f7df026fa..82d51625eac 100644 --- a/raft/src/test/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachineTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/internals/KRaftControlRecordStateMachineTest.java @@ -16,8 +16,8 @@ */ package org.apache.kafka.raft.internals; -import java.util.Arrays; import java.util.Optional; +import java.util.stream.IntStream; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.message.KRaftVersionRecord; @@ -52,7 +52,7 @@ final class KRaftControlRecordStateMachineTest { @Test void testEmptyPartition() { MockLog log = buildLog(); - VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); + VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)); KRaftControlRecordStateMachine partitionState = buildPartitionListener(log, Optional.of(voterSet)); @@ -65,7 +65,7 @@ final class KRaftControlRecordStateMachineTest { @Test void testUpdateWithoutSnapshot() { MockLog log = buildLog(); - VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); + VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)); BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING; int epoch = 1; @@ -85,7 +85,7 @@ final class KRaftControlRecordStateMachineTest { ); // Append the voter set control record - VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); + VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true)); log.appendAsLeader( MemoryRecords.withVotersRecord( log.endOffset().offset, @@ -108,7 +108,7 @@ final class KRaftControlRecordStateMachineTest { @Test void testUpdateWithEmptySnapshot() { MockLog log = buildLog(); - VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); + VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)); BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING; int epoch = 1; @@ -136,7 +136,7 @@ final class KRaftControlRecordStateMachineTest { ); // Append the voter set control record - VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); + 
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true)); log.appendAsLeader( MemoryRecords.withVotersRecord( log.endOffset().offset, @@ -159,14 +159,14 @@ final class KRaftControlRecordStateMachineTest { @Test void testUpdateWithSnapshot() { MockLog log = buildLog(); - VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); + VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)); int epoch = 1; KRaftControlRecordStateMachine partitionState = buildPartitionListener(log, Optional.of(staticVoterSet)); // Create a snapshot that has kraft.version and voter set control records short kraftVersion = 1; - VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); + VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true)); RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder() .setRawSnapshotWriter(log.createNewSnapshotUnchecked(new OffsetAndEpoch(10, epoch)).get()) @@ -188,7 +188,7 @@ final class KRaftControlRecordStateMachineTest { @Test void testUpdateWithSnapshotAndLogOverride() { MockLog log = buildLog(); - VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); + VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)); BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING; int epoch = 1; @@ -196,7 +196,7 @@ final class KRaftControlRecordStateMachineTest { // Create a snapshot that has kraft.version and voter set control records short kraftVersion = 1; - VoterSet snapshotVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); + VoterSet snapshotVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true)); OffsetAndEpoch snapshotId = new OffsetAndEpoch(10, epoch); RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder() @@ -235,7 +235,7 @@ final class KRaftControlRecordStateMachineTest { @Test void testTruncateTo() { MockLog log = buildLog(); - VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); + VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)); BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING; int epoch = 1; @@ -256,7 +256,7 @@ final class KRaftControlRecordStateMachineTest { // Append the voter set control record long firstVoterSetOffset = log.endOffset().offset; - VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); + VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true)); log.appendAsLeader( MemoryRecords.withVotersRecord( firstVoterSetOffset, @@ -303,7 +303,7 @@ final class KRaftControlRecordStateMachineTest { @Test void testTrimPrefixTo() { MockLog log = buildLog(); - VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); + VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)); BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING; int epoch = 1; @@ -325,7 +325,7 @@ final class KRaftControlRecordStateMachineTest { // Append the voter set control record long firstVoterSetOffset = log.endOffset().offset; - VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true)); + 
VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true)); log.appendAsLeader( MemoryRecords.withVotersRecord( firstVoterSetOffset, diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java index 1b729e36d30..240a55d4403 100644 --- a/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java @@ -22,7 +22,6 @@ import org.apache.kafka.common.metrics.Metrics; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.Time; -import org.apache.kafka.common.utils.Utils; import org.apache.kafka.raft.LogOffsetMetadata; import org.apache.kafka.raft.MockQuorumStateStore; import org.apache.kafka.raft.OffsetAndEpoch; @@ -32,13 +31,13 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mockito; -import java.util.Map; import java.util.Collections; +import java.util.Map; import java.util.Optional; import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Random; -import java.util.Set; +import java.util.stream.IntStream; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -64,19 +63,11 @@ public class KafkaRaftMetricsTest { metrics.close(); } - private QuorumState buildQuorumState(Set voters, short kraftVersion) { - boolean withDirectoryId = kraftVersion > 0; - - return buildQuorumState( - VoterSetTest.voterSet(VoterSetTest.voterMap(voters, withDirectoryId)), - kraftVersion - ); - } - private QuorumState buildQuorumState(VoterSet voterSet, short kraftVersion) { return new QuorumState( OptionalInt.of(localId), localDirectoryId, + VoterSetTest.DEFAULT_LISTENER_NAME, () -> voterSet, () -> kraftVersion, electionTimeoutMs, @@ -88,11 +79,26 @@ public class KafkaRaftMetricsTest { ); } + private VoterSet localStandaloneVoterSet(short kraftVersion) { + boolean withDirectoryId = kraftVersion > 0; + return VoterSetTest.voterSet( + Collections.singletonMap( + localId, + VoterSetTest.voterNode( + ReplicaKey.of( + localId, + withDirectoryId ? 
Optional.of(localDirectoryId) : Optional.empty() + ) + ) + ) + ); + } + @ParameterizedTest @ValueSource(shorts = {0, 1}) public void shouldRecordVoterQuorumState(short kraftVersion) { boolean withDirectoryId = kraftVersion > 0; - Map voterMap = VoterSetTest.voterMap(Utils.mkSet(1, 2), withDirectoryId); + Map voterMap = VoterSetTest.voterMap(IntStream.of(1, 2), withDirectoryId); voterMap.put( localId, VoterSetTest.voterNode( @@ -102,7 +108,8 @@ public class KafkaRaftMetricsTest { ) ) ); - QuorumState state = buildQuorumState(VoterSetTest.voterSet(voterMap), kraftVersion); + VoterSet voters = VoterSetTest.voterSet(voterMap); + QuorumState state = buildQuorumState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, 0)); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); @@ -144,7 +151,7 @@ public class KafkaRaftMetricsTest { state.leaderStateOrThrow().updateReplicaState(1, 0, new LogOffsetMetadata(5L)); assertEquals((double) 5L, getMetric(metrics, "high-watermark").metricValue()); - state.transitionToFollower(2, 1); + state.transitionToFollower(2, voters.voterNode(1, VoterSetTest.DEFAULT_LISTENER_NAME).get()); assertEquals("follower", getMetric(metrics, "current-state").metricValue()); assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue()); assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue()); @@ -184,7 +191,11 @@ public class KafkaRaftMetricsTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void shouldRecordNonVoterQuorumState(short kraftVersion) { - QuorumState state = buildQuorumState(Utils.mkSet(1, 2, 3), kraftVersion); + boolean withDirectoryId = kraftVersion > 0; + VoterSet voters = VoterSetTest.voterSet( + VoterSetTest.voterMap(IntStream.of(1, 2, 3), withDirectoryId) + ); + QuorumState state = buildQuorumState(voters, kraftVersion); state.initialize(new OffsetAndEpoch(0L, 0)); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); @@ -198,7 +209,7 @@ public class KafkaRaftMetricsTest { assertEquals((double) 0, getMetric(metrics, "current-epoch").metricValue()); assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue()); - state.transitionToFollower(2, 1); + state.transitionToFollower(2, voters.voterNode(1, VoterSetTest.DEFAULT_LISTENER_NAME).get()); assertEquals("observer", getMetric(metrics, "current-state").metricValue()); assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue()); assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue()); @@ -227,7 +238,7 @@ public class KafkaRaftMetricsTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void shouldRecordLogEnd(short kraftVersion) { - QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion); + QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion); state.initialize(new OffsetAndEpoch(0L, 0)); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); @@ -243,7 +254,7 @@ public class KafkaRaftMetricsTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void shouldRecordNumUnknownVoterConnections(short kraftVersion) { - QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion); + QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion); state.initialize(new OffsetAndEpoch(0L, 0)); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); @@ -257,7 +268,7 @@ public class KafkaRaftMetricsTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void 
shouldRecordPollIdleRatio(short kraftVersion) { - QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion); + QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion); state.initialize(new OffsetAndEpoch(0L, 0)); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); @@ -330,7 +341,7 @@ public class KafkaRaftMetricsTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void shouldRecordLatency(short kraftVersion) { - QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion); + QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion); state.initialize(new OffsetAndEpoch(0L, 0)); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); @@ -362,7 +373,7 @@ public class KafkaRaftMetricsTest { @ParameterizedTest @ValueSource(shorts = {0, 1}) public void shouldRecordRate(short kraftVersion) { - QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion); + QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion); state.initialize(new OffsetAndEpoch(0L, 0)); raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java index 3a64479fadf..e8896f3b576 100644 --- a/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java @@ -21,7 +21,6 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.IdentityHashMap; import java.util.List; import java.util.NoSuchElementException; @@ -31,6 +30,7 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import net.jqwik.api.ForAll; import net.jqwik.api.Property; @@ -204,7 +204,7 @@ public final class RecordsIteratorTest { public void testControlRecordIterationWithKraftVersion1() { AtomicReference buffer = new AtomicReference<>(null); VoterSet voterSet = new VoterSet( - new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)) + VoterSetTest.voterMap(IntStream.of(1, 2, 3), true) ); RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder() .setTime(new MockTime()) diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetHistoryTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetHistoryTest.java index 22dd52ec364..ac5c3b39c9c 100644 --- a/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetHistoryTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetHistoryTest.java @@ -16,10 +16,10 @@ */ package org.apache.kafka.raft.internals; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.stream.IntStream; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -27,7 +27,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; final public class VoterSetHistoryTest { @Test void testStaticVoterSet() { - VoterSet staticVoterSet = new VoterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)); + VoterSet 
staticVoterSet = new VoterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet)); assertEquals(Optional.empty(), votersHistory.valueAtOrBefore(0)); @@ -58,13 +58,13 @@ final public class VoterSetHistoryTest { @Test void testAddAt() { - Map voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); + Map voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true); VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet)); assertThrows( IllegalArgumentException.class, - () -> votersHistory.addAt(-1, new VoterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))) + () -> votersHistory.addAt(-1, new VoterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true))) ); assertEquals(staticVoterSet, votersHistory.lastValue()); @@ -90,7 +90,7 @@ final public class VoterSetHistoryTest { void testAddAtNonOverlapping() { VoterSetHistory votersHistory = new VoterSetHistory(Optional.empty()); - Map voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); + Map voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true); VoterSet voterSet = new VoterSet(new HashMap<>(voterMap)); // Add a starting voter to the history @@ -122,7 +122,7 @@ final public class VoterSetHistoryTest { @Test void testNonoverlappingFromStaticVoterSet() { - Map voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); + Map voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true); VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.empty()); @@ -137,7 +137,7 @@ final public class VoterSetHistoryTest { @Test void testTruncateTo() { - Map voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); + Map voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true); VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet)); @@ -163,7 +163,7 @@ final public class VoterSetHistoryTest { @Test void testTrimPrefixTo() { - Map voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); + Map voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true); VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet)); @@ -196,7 +196,7 @@ final public class VoterSetHistoryTest { @Test void testClear() { - Map voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true); + Map voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true); VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap)); VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet)); diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetTest.java index f0ed10a5428..9f879db10ee 100644 --- a/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/internals/VoterSetTest.java @@ -18,7 +18,6 @@ package org.apache.kafka.raft.internals; import java.net.InetSocketAddress; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -26,8 +25,13 @@ import java.util.Map; import java.util.Optional; import java.util.function.Function; import java.util.stream.Collectors; +import 
java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.kafka.common.Node; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.feature.SupportedVersionRange; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.utils.Utils; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -41,22 +45,45 @@ final public class VoterSetTest { } @Test - void testVoterAddress() { - VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true)); - assertEquals(Optional.of(new InetSocketAddress("replica-1", 1234)), voterSet.voterAddress(1, "LISTENER")); - assertEquals(Optional.empty(), voterSet.voterAddress(1, "MISSING")); - assertEquals(Optional.empty(), voterSet.voterAddress(4, "LISTENER")); + void testVoterNode() { + VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true)); + assertEquals( + Optional.of(new Node(1, "replica-1", 1234)), + voterSet.voterNode(1, DEFAULT_LISTENER_NAME) + ); + assertEquals(Optional.empty(), voterSet.voterNode(1, ListenerName.normalised("MISSING"))); + assertEquals(Optional.empty(), voterSet.voterNode(4, DEFAULT_LISTENER_NAME)); + } + + @Test + void testVoterNodes() { + VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true)); + + assertEquals( + Utils.mkSet(new Node(1, "replica-1", 1234), new Node(2, "replica-2", 1234)), + voterSet.voterNodes(IntStream.of(1, 2).boxed(), DEFAULT_LISTENER_NAME) + ); + + assertThrows( + IllegalArgumentException.class, + () -> voterSet.voterNodes(IntStream.of(1, 2).boxed(), ListenerName.normalised("MISSING")) + ); + + assertThrows( + IllegalArgumentException.class, + () -> voterSet.voterNodes(IntStream.of(1, 4).boxed(), DEFAULT_LISTENER_NAME) + ); } @Test void testVoterIds() { - VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true)); + VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true)); assertEquals(new HashSet<>(Arrays.asList(1, 2, 3)), voterSet.voterIds()); } @Test void testAddVoter() { - Map aVoterMap = voterMap(Arrays.asList(1, 2, 3), true); + Map aVoterMap = voterMap(IntStream.of(1, 2, 3), true); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); assertEquals(Optional.empty(), voterSet.addVoter(voterNode(1, true))); @@ -68,7 +95,7 @@ final public class VoterSetTest { @Test void testRemoveVoter() { - Map aVoterMap = voterMap(Arrays.asList(1, 2, 3), true); + Map aVoterMap = voterMap(IntStream.of(1, 2, 3), true); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); assertEquals(Optional.empty(), voterSet.removeVoter(ReplicaKey.of(4, Optional.empty()))); @@ -83,7 +110,7 @@ final public class VoterSetTest { @Test void testIsVoterWithDirectoryId() { - Map aVoterMap = voterMap(Arrays.asList(1, 2, 3), true); + Map aVoterMap = voterMap(IntStream.of(1, 2, 3), true); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); assertTrue(voterSet.isVoter(aVoterMap.get(1).voterKey())); @@ -100,7 +127,7 @@ final public class VoterSetTest { @Test void testIsVoterWithoutDirectoryId() { - Map aVoterMap = voterMap(Arrays.asList(1, 2, 3), false); + Map aVoterMap = voterMap(IntStream.of(1, 2, 3), false); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); assertTrue(voterSet.isVoter(ReplicaKey.of(1, Optional.empty()))); @@ -111,7 +138,7 @@ final public class VoterSetTest { @Test void testIsOnlyVoterInStandalone() { - Map aVoterMap = voterMap(Arrays.asList(1), true); + Map aVoterMap = 
voterMap(IntStream.of(1), true); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); assertTrue(voterSet.isOnlyVoter(aVoterMap.get(1).voterKey())); @@ -125,7 +152,7 @@ final public class VoterSetTest { @Test void testIsOnlyVoterInNotStandalone() { - Map aVoterMap = voterMap(Arrays.asList(1, 2), true); + Map aVoterMap = voterMap(IntStream.of(1, 2), true); VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap)); assertFalse(voterSet.isOnlyVoter(aVoterMap.get(1).voterKey())); @@ -142,14 +169,14 @@ final public class VoterSetTest { @Test void testRecordRoundTrip() { - VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true)); + VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true)); assertEquals(voterSet, VoterSet.fromVotersRecord(voterSet.toVotersRecord((short) 0))); } @Test void testOverlappingMajority() { - Map startingVoterMap = voterMap(Arrays.asList(1, 2, 3), true); + Map startingVoterMap = voterMap(IntStream.of(1, 2, 3), true); VoterSet startingVoterSet = voterSet(startingVoterMap); VoterSet biggerVoterSet = startingVoterSet @@ -172,7 +199,7 @@ final public class VoterSetTest { @Test void testNonoverlappingMajority() { - Map startingVoterMap = voterMap(Arrays.asList(1, 2, 3, 4, 5), true); + Map startingVoterMap = voterMap(IntStream.of(1, 2, 3, 4, 5), true); VoterSet startingVoterSet = voterSet(startingVoterMap); // Two additions don't have an overlapping majority @@ -217,20 +244,27 @@ final public class VoterSetTest { ); } + public static final ListenerName DEFAULT_LISTENER_NAME = ListenerName.normalised("LISTENER"); + public static Map voterMap( - Collection replicas, + IntStream replicas, boolean withDirectoryId ) { return replicas - .stream() + .boxed() .collect( Collectors.toMap( Function.identity(), - id -> VoterSetTest.voterNode(id, withDirectoryId) + id -> voterNode(id, withDirectoryId) ) ); } + public static Map voterMap(Stream replicas) { + return replicas + .collect(Collectors.toMap(ReplicaKey::id, VoterSetTest::voterNode)); + } + public static VoterSet.VoterNode voterNode(int id, boolean withDirectoryId) { return voterNode( ReplicaKey.of( @@ -244,7 +278,7 @@ final public class VoterSetTest { return new VoterSet.VoterNode( replicaKey, Collections.singletonMap( - "LISTENER", + DEFAULT_LISTENER_NAME, InetSocketAddress.createUnresolved( String.format("replica-%d", replicaKey.id()), 1234 @@ -257,4 +291,8 @@ final public class VoterSetTest { public static VoterSet voterSet(Map voters) { return new VoterSet(voters); } + + public static VoterSet voterSet(Stream voterKeys) { + return voterSet(voterMap(voterKeys)); + } } diff --git a/raft/src/test/java/org/apache/kafka/snapshot/RecordsSnapshotWriterTest.java b/raft/src/test/java/org/apache/kafka/snapshot/RecordsSnapshotWriterTest.java index 17b7c5d9f39..32a980296b1 100644 --- a/raft/src/test/java/org/apache/kafka/snapshot/RecordsSnapshotWriterTest.java +++ b/raft/src/test/java/org/apache/kafka/snapshot/RecordsSnapshotWriterTest.java @@ -17,12 +17,11 @@ package org.apache.kafka.snapshot; - import java.nio.ByteBuffer; -import java.util.Arrays; import java.util.HashMap; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.IntStream; import org.apache.kafka.common.message.KRaftVersionRecord; import org.apache.kafka.common.message.SnapshotFooterRecord; import org.apache.kafka.common.message.SnapshotHeaderRecord; @@ -97,7 +96,7 @@ final class RecordsSnapshotWriterTest { OffsetAndEpoch snapshotId = new OffsetAndEpoch(100, 10); int maxBatchSize = 
1024; VoterSet voterSet = VoterSetTest.voterSet( - new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)) + new HashMap<>(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)) ); AtomicReference buffer = new AtomicReference<>(null); RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder() @@ -117,7 +116,7 @@ final class RecordsSnapshotWriterTest { OffsetAndEpoch snapshotId = new OffsetAndEpoch(100, 10); int maxBatchSize = 1024; VoterSet voterSet = VoterSetTest.voterSet( - new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)) + new HashMap<>(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)) ); AtomicReference buffer = new AtomicReference<>(null); RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
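A note on the destination matching that the renamed DropOutboundRequestsTo filter above depends on: outbound requests are dropped when the destination's host and port match an unreachable address, and Node.id() / Node.rack() never take part in the comparison. The following minimal, self-contained sketch (plain JDK only; the class name and the host/port values are illustrative and not part of the patch) demonstrates that unresolved host/port comparison in isolation:

import java.net.InetSocketAddress;
import java.util.HashSet;
import java.util.Set;

// Hypothetical stand-alone example; not part of the Kafka code base.
public class DropOutboundRequestsToSketch {
    public static void main(String[] args) {
        // Unreachable destinations are keyed purely by host and port,
        // mirroring how the test filter stores InetSocketAddress values.
        Set<InetSocketAddress> unreachable = new HashSet<>();
        unreachable.add(InetSocketAddress.createUnresolved("replica-2", 1234));

        // A request whose destination is replica-2:1234 is dropped,
        // no matter which node id or rack the simulation assigned to it.
        InetSocketAddress destination = InetSocketAddress.createUnresolved("replica-2", 1234);
        System.out.println("deliver to replica-2:1234? " + !unreachable.contains(destination)); // false

        // A different port (or host) is treated as a different, reachable destination.
        InetSocketAddress other = InetSocketAddress.createUnresolved("replica-2", 4321);
        System.out.println("deliver to replica-2:4321? " + !unreachable.contains(other)); // true
    }
}

Using InetSocketAddress.createUnresolved also means no DNS lookup is performed when building the comparison key; unresolved addresses compare by hostname and port only.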