KAFKA-16525; Dynamic KRaft network manager and channel (#15986)

Allow KRaft replicas to send requests to any node (Node) not just the nodes configured in the
controller.quorum.voters property. This flexibility is needed so KRaft can implement the
controller.quorum.voters configuration, send requests to the dynamically changing set of voters and
send requests to the leader endpoint (Node) discovered through the KRaft RPCs (specifically the
BeginQuorumEpoch request and the Fetch response).

This was achieved by changing the RequestManager API to accept Node instead of just the replica ID.
Internally, the request manager tracks connection state using the Node.idString method to match the
connection management used by NetworkClient.

The API for RequestManager is also changed so that the ConnectionState class is not exposed in the
API. This allows the request manager to reclaim heap memory for any connection that is ready.

The NetworkChannel was updated to receive the endpoint information (Node) through the outbound raft
request (RaftRequest.Outbound). This makes the network channel more flexible as it doesn't need to
be configured with the list of all possible endpoints. RaftRequest.Outbound and
RaftResponse.Inbound were updated to include the remote node instead of just the remote id.

The follower state tracked by KRaft replicas was updated to include both the leader id and the
leader's endpoint (Node). In this commit the node value is computed from the set of voters. In a
future commit this will be updated so that it is sent through KRaft RPCs, for example the
BeginQuorumEpoch request and the Fetch response.

Support for configuring controller.quorum.bootstrap.servers was added. This includes changes to
KafkaConfig, QuorumConfig, etc. All of the tests using QuorumTestHarness were changed to use the
controller.quorum.bootstrap.servers instead of the controller.quorum.voters for the broker
configuration. Finally, the node ids for the bootstrap servers are assigned as decreasing negative
numbers starting with -2.

Reviewers: Jason Gustafson <jason@confluent.io>, Luke Chen <showuon@gmail.com>, Colin P. McCabe <cmccabe@apache.org>
This commit is contained in:
José Armando García Sancio 2024-06-03 17:24:48 -04:00 committed by GitHub
parent 8a882a77a4
commit 459da4795a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
45 changed files with 2163 additions and 990 deletions

View File

@ -441,6 +441,7 @@
<allow pkg="org.apache.kafka.common.message" />
<allow pkg="org.apache.kafka.common.metadata" />
<allow pkg="org.apache.kafka.common.metrics" />
<allow pkg="org.apache.kafka.common.network" />
<allow pkg="org.apache.kafka.common.protocol" />
<allow pkg="org.apache.kafka.common.record" />
<allow pkg="org.apache.kafka.common.requests" />

View File

@ -23,6 +23,7 @@ import java.nio.file.Paths
import java.util.OptionalInt
import java.util.concurrent.CompletableFuture
import java.util.{Map => JMap}
import java.util.{Collection => JCollection}
import kafka.log.LogManager
import kafka.log.UnifiedLog
import kafka.server.KafkaConfig
@ -133,7 +134,7 @@ trait RaftManager[T] {
def replicatedLog: ReplicatedLog
def voterNode(id: Int, listener: String): Option[Node]
def voterNode(id: Int, listener: ListenerName): Option[Node]
}
class KafkaRaftManager[T](
@ -147,6 +148,7 @@ class KafkaRaftManager[T](
metrics: Metrics,
threadNamePrefixOpt: Option[String],
val controllerQuorumVotersFuture: CompletableFuture[JMap[Integer, InetSocketAddress]],
bootstrapServers: JCollection[InetSocketAddress],
fatalFaultHandler: FaultHandler
) extends RaftManager[T] with Logging {
@ -185,7 +187,6 @@ class KafkaRaftManager[T](
def startup(): Unit = {
client.initialize(
controllerQuorumVotersFuture.get(),
config.controllerListenerNames.head,
new FileQuorumStateStore(new File(dataDir, FileQuorumStateStore.DEFAULT_FILE_NAME)),
metrics
)
@ -228,14 +229,15 @@ class KafkaRaftManager[T](
expirationService,
logContext,
clusterId,
bootstrapServers,
raftConfig
)
client
}
private def buildNetworkChannel(): KafkaNetworkChannel = {
val netClient = buildNetworkClient()
new KafkaNetworkChannel(time, netClient, config.quorumRequestTimeoutMs, threadNamePrefix)
val (listenerName, netClient) = buildNetworkClient()
new KafkaNetworkChannel(time, listenerName, netClient, config.quorumRequestTimeoutMs, threadNamePrefix)
}
private def createDataDir(): File = {
@ -254,7 +256,7 @@ class KafkaRaftManager[T](
)
}
private def buildNetworkClient(): NetworkClient = {
private def buildNetworkClient(): (ListenerName, NetworkClient) = {
val controllerListenerName = new ListenerName(config.controllerListenerNames.head)
val controllerSecurityProtocol = config.effectiveListenerSecurityProtocolMap.getOrElse(
controllerListenerName,
@ -292,7 +294,7 @@ class KafkaRaftManager[T](
val reconnectBackoffMsMs = 500
val discoverBrokerVersions = true
new NetworkClient(
val networkClient = new NetworkClient(
selector,
new ManualMetadataUpdater(),
clientId,
@ -309,13 +311,15 @@ class KafkaRaftManager[T](
apiVersions,
logContext
)
(controllerListenerName, networkClient)
}
override def leaderAndEpoch: LeaderAndEpoch = {
client.leaderAndEpoch
}
override def voterNode(id: Int, listener: String): Option[Node] = {
override def voterNode(id: Int, listener: ListenerName): Option[Node] = {
client.voterNode(id, listener).toScala
}
}

View File

@ -439,6 +439,7 @@ object KafkaConfig {
/** ********* Raft Quorum Configuration *********/
.define(QuorumConfig.QUORUM_VOTERS_CONFIG, LIST, QuorumConfig.DEFAULT_QUORUM_VOTERS, new QuorumConfig.ControllerQuorumVotersValidator(), HIGH, QuorumConfig.QUORUM_VOTERS_DOC)
.define(QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG, LIST, QuorumConfig.DEFAULT_QUORUM_BOOTSTRAP_SERVERS, new QuorumConfig.ControllerQuorumBootstrapServersValidator(), HIGH, QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_DOC)
.define(QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_ELECTION_TIMEOUT_MS, null, HIGH, QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_DOC)
.define(QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_FETCH_TIMEOUT_MS, null, HIGH, QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_DOC)
.define(QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG, INT, QuorumConfig.DEFAULT_QUORUM_ELECTION_BACKOFF_MAX_MS, null, HIGH, QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_DOC)
@ -1055,6 +1056,7 @@ class KafkaConfig private(doLog: Boolean, val props: java.util.Map[_, _], dynami
/** ********* Raft Quorum Configuration *********/
val quorumVoters = getList(QuorumConfig.QUORUM_VOTERS_CONFIG)
val quorumBootstrapServers = getList(QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG)
val quorumElectionTimeoutMs = getInt(QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG)
val quorumFetchTimeoutMs = getInt(QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG)
val quorumElectionBackoffMs = getInt(QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG)

View File

@ -71,6 +71,7 @@ class KafkaRaftServer(
time,
metrics,
CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)),
QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers),
new StandardFaultHandlerFactory(),
)

View File

@ -70,9 +70,9 @@ import java.net.{InetAddress, SocketTimeoutException}
import java.nio.file.{Files, Paths}
import java.time.Duration
import java.util
import java.util.{Optional, OptionalInt, OptionalLong}
import java.util.concurrent._
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
import java.util.{Optional, OptionalInt, OptionalLong}
import scala.collection.{Map, Seq}
import scala.compat.java8.OptionConverters.RichOptionForJava8
import scala.jdk.CollectionConverters._
@ -439,6 +439,7 @@ class KafkaServer(
metrics,
threadNamePrefix,
CompletableFuture.completedFuture(quorumVoters),
QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers),
fatalFaultHandler = new LoggingFaultHandler("raftManager", () => shutdown())
)
quorumControllerNodeProvider = RaftControllerNodeProvider(raftManager, config)

View File

@ -112,7 +112,7 @@ class RaftControllerNodeProvider(
val saslMechanism: String
) extends ControllerNodeProvider with Logging {
private def idToNode(id: Int): Option[Node] = raftManager.voterNode(id, listenerName.value())
private def idToNode(id: Int): Option[Node] = raftManager.voterNode(id, listenerName)
override def getControllerInfo(): ControllerInformation =
ControllerInformation(raftManager.leaderAndEpoch.leaderId.asScala.flatMap(idToNode),

View File

@ -41,6 +41,7 @@ import java.util.Arrays
import java.util.Optional
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.{CompletableFuture, TimeUnit}
import java.util.{Collection => JCollection}
import java.util.{Map => JMap}
@ -94,6 +95,7 @@ class SharedServer(
val time: Time,
private val _metrics: Metrics,
val controllerQuorumVotersFuture: CompletableFuture[JMap[Integer, InetSocketAddress]],
val bootstrapServers: JCollection[InetSocketAddress],
val faultHandlerFactory: FaultHandlerFactory
) extends Logging {
private val logContext: LogContext = new LogContext(s"[SharedServer id=${sharedServerConfig.nodeId}] ")
@ -265,6 +267,7 @@ class SharedServer(
metrics,
Some(s"kafka-${sharedServerConfig.nodeId}-raft"), // No dash expected at the end
controllerQuorumVotersFuture,
bootstrapServers,
raftManagerFaultHandler
)
raftManager = _raftManager

View File

@ -502,7 +502,7 @@ object StorageTool extends Logging {
metaPropertiesEnsemble.verify(metaProperties.clusterId(), metaProperties.nodeId(),
util.EnumSet.noneOf(classOf[VerificationFlag]))
System.out.println(s"metaPropertiesEnsemble=$metaPropertiesEnsemble")
stream.println(s"metaPropertiesEnsemble=$metaPropertiesEnsemble")
val copier = new MetaPropertiesEnsemble.Copier(metaPropertiesEnsemble)
if (!(ignoreFormatted || copier.logDirProps().isEmpty)) {
val firstLogDir = copier.logDirProps().keySet().iterator().next()

View File

@ -95,6 +95,7 @@ class TestRaftServer(
metrics,
Some(threadNamePrefix),
CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)),
QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers),
new ProcessTerminatingFaultHandler.Builder().build()
)

View File

@ -239,12 +239,15 @@ public class KafkaClusterTestKit implements AutoCloseable {
ThreadUtils.createThreadFactory("kafka-cluster-test-kit-executor-%d", false));
for (ControllerNode node : nodes.controllerNodes().values()) {
setupNodeDirectories(baseDirectory, node.metadataDirectory(), Collections.emptyList());
SharedServer sharedServer = new SharedServer(createNodeConfig(node),
SharedServer sharedServer = new SharedServer(
createNodeConfig(node),
node.initialMetaPropertiesEnsemble(),
Time.SYSTEM,
new Metrics(),
connectFutureManager.future,
faultHandlerFactory);
Collections.emptyList(),
faultHandlerFactory
);
ControllerServer controller = null;
try {
controller = new ControllerServer(
@ -267,13 +270,18 @@ public class KafkaClusterTestKit implements AutoCloseable {
jointServers.put(node.id(), sharedServer);
}
for (BrokerNode node : nodes.brokerNodes().values()) {
SharedServer sharedServer = jointServers.computeIfAbsent(node.id(),
id -> new SharedServer(createNodeConfig(node),
SharedServer sharedServer = jointServers.computeIfAbsent(
node.id(),
id -> new SharedServer(
createNodeConfig(node),
node.initialMetaPropertiesEnsemble(),
Time.SYSTEM,
new Metrics(),
connectFutureManager.future,
faultHandlerFactory));
Collections.emptyList(),
faultHandlerFactory
)
);
BrokerServer broker = null;
try {
broker = new BrokerServer(sharedServer);

View File

@ -21,6 +21,5 @@ log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n
log4j.logger.kafka=WARN
log4j.logger.org.apache.kafka=WARN
# zkclient can be verbose, during debugging it is common to adjust it separately
log4j.logger.org.apache.zookeeper=WARN

View File

@ -21,12 +21,12 @@ import kafka.utils.{TestInfoUtils, TestUtils}
import org.apache.kafka.clients.admin.{NewPartitions, NewTopic}
import org.apache.kafka.clients.consumer._
import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord}
import org.apache.kafka.common.{KafkaException, MetricName, TopicPartition}
import org.apache.kafka.common.config.TopicConfig
import org.apache.kafka.common.errors.{InvalidGroupIdException, InvalidTopicException, TimeoutException, WakeupException}
import org.apache.kafka.common.header.Headers
import org.apache.kafka.common.record.{CompressionType, TimestampType}
import org.apache.kafka.common.serialization._
import org.apache.kafka.common.{KafkaException, MetricName, TopicPartition}
import org.apache.kafka.test.{MockConsumerInterceptor, MockProducerInterceptor}
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Timeout

View File

@ -124,12 +124,15 @@ class KRaftQuorumImplementation(
metaPropertiesEnsemble.verify(Optional.of(clusterId),
OptionalInt.of(config.nodeId),
util.EnumSet.of(REQUIRE_AT_LEAST_ONE_VALID, REQUIRE_METADATA_LOG_DIR))
val sharedServer = new SharedServer(config,
val sharedServer = new SharedServer(
config,
metaPropertiesEnsemble,
time,
new Metrics(),
controllerQuorumVotersFuture,
faultHandlerFactory)
controllerQuorumVotersFuture.get().values(),
faultHandlerFactory
)
var broker: BrokerServer = null
try {
broker = new BrokerServer(sharedServer)
@ -371,12 +374,15 @@ abstract class QuorumTestHarness extends Logging {
metaPropertiesEnsemble.verify(Optional.of(metaProperties.clusterId().get()),
OptionalInt.of(nodeId),
util.EnumSet.of(REQUIRE_AT_LEAST_ONE_VALID, REQUIRE_METADATA_LOG_DIR))
val sharedServer = new SharedServer(config,
val sharedServer = new SharedServer(
config,
metaPropertiesEnsemble,
Time.SYSTEM,
new Metrics(),
controllerQuorumVotersFuture,
faultHandlerFactory)
Collections.emptyList(),
faultHandlerFactory
)
var controllerServer: ControllerServer = null
try {
controllerServer = new ControllerServer(

View File

@ -118,6 +118,7 @@ class RaftManagerTest {
new Metrics(Time.SYSTEM),
Option.empty,
CompletableFuture.completedFuture(QuorumConfig.parseVoterConnections(config.quorumVoters)),
QuorumConfig.parseBootstrapServers(config.quorumBootstrapServers),
mock(classOf[FaultHandler])
)
}

View File

@ -19,7 +19,7 @@ package kafka.server
import java.net.InetSocketAddress
import java.util
import java.util.{Collections, Properties}
import java.util.{Arrays, Collections, Properties}
import kafka.cluster.EndPoint
import kafka.security.authorizer.AclAuthorizer
import kafka.utils.TestUtils.assertBadConfigContainingMessage
@ -1032,6 +1032,7 @@ class KafkaConfigTest {
// Raft Quorum Configs
case QuorumConfig.QUORUM_VOTERS_CONFIG => // ignore string
case QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG => // ignore string
case QuorumConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number")
case QuorumConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number")
case QuorumConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number")
@ -1402,6 +1403,23 @@ class KafkaConfigTest {
assertEquals(expectedVoters, addresses)
}
// Verifies that a comma-separated controller.quorum.bootstrap.servers value round-trips through
// KafkaConfig and QuorumConfig.parseBootstrapServers into unresolved socket addresses, in the
// same order as they were configured.
@Test
def testParseQuorumBootstrapServers(): Unit = {
// Unresolved addresses are expected, so the comparison does not depend on DNS lookups.
val expected = Arrays.asList(
InetSocketAddress.createUnresolved("kafka1", 9092),
InetSocketAddress.createUnresolved("kafka2", 9092)
)
val props = TestUtils.createBrokerConfig(0, null)
props.setProperty(QuorumConfig.QUORUM_BOOTSTRAP_SERVERS_CONFIG, "kafka1:9092,kafka2:9092")
val addresses = QuorumConfig.parseBootstrapServers(
KafkaConfig.fromProps(props).quorumBootstrapServers
)
assertEquals(expected, addresses)
}
@Test
def testAcceptsLargeNodeIdForRaftBasedCase(): Unit = {
// Generation of Broker IDs is not supported when using Raft-based controller quorums,

View File

@ -22,8 +22,8 @@ import java.nio.ByteBuffer
import java.util
import java.util.Collections
import java.util.Optional
import java.util.Arrays
import java.util.Properties
import java.util.stream.IntStream
import kafka.log.{LogTestUtils, UnifiedLog}
import kafka.raft.{KafkaMetadataLog, MetadataLogConfig}
import kafka.server.{BrokerTopicStats, KafkaRaftServer}
@ -338,7 +338,7 @@ class DumpLogSegmentsTest {
.setLastContainedLogTimestamp(lastContainedLogTimestamp)
.setRawSnapshotWriter(metadataLog.createNewSnapshot(new OffsetAndEpoch(0, 0)).get)
.setKraftVersion(1)
.setVoterSet(Optional.of(VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))))
.setVoterSet(Optional.of(VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true))))
.build(MetadataRecordSerde.INSTANCE)
) { snapshotWriter =>
snapshotWriter.append(metadataRecords.asJava)

View File

@ -30,9 +30,9 @@ import org.apache.kafka.raft.internals.ReplicaKey;
* Encapsulate election state stored on disk after every state change.
*/
final public class ElectionState {
private static int unknownLeaderId = -1;
private static int notVoted = -1;
private static Uuid noVotedDirectoryId = Uuid.ZERO_UUID;
private static final int UNKNOWN_LEADER_ID = -1;
private static final int NOT_VOTED = -1;
private static final Uuid NO_VOTED_DIRECTORY_ID = Uuid.ZERO_UUID;
private final int epoch;
private final OptionalInt leaderId;
@ -95,7 +95,7 @@ final public class ElectionState {
}
public int leaderIdOrSentinel() {
return leaderId.orElse(unknownLeaderId);
return leaderId.orElse(UNKNOWN_LEADER_ID);
}
public OptionalInt optionalLeaderId() {
@ -126,7 +126,7 @@ final public class ElectionState {
QuorumStateData data = new QuorumStateData()
.setLeaderEpoch(epoch)
.setLeaderId(leaderIdOrSentinel())
.setVotedId(votedKey.map(ReplicaKey::id).orElse(notVoted));
.setVotedId(votedKey.map(ReplicaKey::id).orElse(NOT_VOTED));
if (version == 0) {
List<QuorumStateData.Voter> dataVoters = voters
@ -135,7 +135,7 @@ final public class ElectionState {
.collect(Collectors.toList());
data.setCurrentVoters(dataVoters);
} else if (version == 1) {
data.setVotedDirectoryId(votedKey.flatMap(ReplicaKey::directoryId).orElse(noVotedDirectoryId));
data.setVotedDirectoryId(votedKey.flatMap(ReplicaKey::directoryId).orElse(NO_VOTED_DIRECTORY_ID));
} else {
throw new IllegalStateException(
String.format(
@ -198,17 +198,17 @@ final public class ElectionState {
}
public static ElectionState fromQuorumStateData(QuorumStateData data) {
Optional<Uuid> votedDirectoryId = data.votedDirectoryId().equals(noVotedDirectoryId) ?
Optional<Uuid> votedDirectoryId = data.votedDirectoryId().equals(NO_VOTED_DIRECTORY_ID) ?
Optional.empty() :
Optional.of(data.votedDirectoryId());
Optional<ReplicaKey> votedKey = data.votedId() == notVoted ?
Optional<ReplicaKey> votedKey = data.votedId() == NOT_VOTED ?
Optional.empty() :
Optional.of(ReplicaKey.of(data.votedId(), votedDirectoryId));
return new ElectionState(
data.leaderEpoch(),
data.leaderId() == unknownLeaderId ? OptionalInt.empty() : OptionalInt.of(data.leaderId()),
data.leaderId() == UNKNOWN_LEADER_ID ? OptionalInt.empty() : OptionalInt.of(data.leaderId()),
votedKey,
data.currentVoters().stream().map(QuorumStateData.Voter::voterId).collect(Collectors.toSet())
);

View File

@ -19,6 +19,7 @@ package org.apache.kafka.raft;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Timer;
@ -29,7 +30,7 @@ import org.slf4j.Logger;
public class FollowerState implements EpochState {
private final int fetchTimeoutMs;
private final int epoch;
private final int leaderId;
private final Node leader;
private final Set<Integer> voters;
// Used for tracking the expiration of both the Fetch and FetchSnapshot requests
private final Timer fetchTimer;
@ -37,14 +38,14 @@ public class FollowerState implements EpochState {
/* Used to track the currently fetching snapshot. When fetching snapshot regular
* Fetch request are paused
*/
private Optional<RawSnapshotWriter> fetchingSnapshot;
private Optional<RawSnapshotWriter> fetchingSnapshot = Optional.empty();
private final Logger log;
public FollowerState(
Time time,
int epoch,
int leaderId,
Node leader,
Set<Integer> voters,
Optional<LogOffsetMetadata> highWatermark,
int fetchTimeoutMs,
@ -52,17 +53,16 @@ public class FollowerState implements EpochState {
) {
this.fetchTimeoutMs = fetchTimeoutMs;
this.epoch = epoch;
this.leaderId = leaderId;
this.leader = leader;
this.voters = voters;
this.fetchTimer = time.timer(fetchTimeoutMs);
this.highWatermark = highWatermark;
this.fetchingSnapshot = Optional.empty();
this.log = logContext.logger(FollowerState.class);
}
@Override
public ElectionState election() {
return ElectionState.withElectedLeader(epoch, leaderId, voters);
return ElectionState.withElectedLeader(epoch, leader.id(), voters);
}
@Override
@ -80,8 +80,8 @@ public class FollowerState implements EpochState {
return fetchTimer.remainingMs();
}
public int leaderId() {
return leaderId;
public Node leader() {
return leader;
}
public boolean hasFetchTimeoutExpired(long currentTimeMs) {
@ -156,7 +156,7 @@ public class FollowerState implements EpochState {
log.debug(
"Rejecting vote request from candidate ({}) since we already have a leader {} in epoch {}",
candidateKey,
leaderId(),
leader,
epoch
);
return false;
@ -164,14 +164,16 @@ public class FollowerState implements EpochState {
@Override
public String toString() {
return "FollowerState(" +
"fetchTimeoutMs=" + fetchTimeoutMs +
", epoch=" + epoch +
", leaderId=" + leaderId +
", voters=" + voters +
", highWatermark=" + highWatermark +
", fetchingSnapshot=" + fetchingSnapshot +
')';
// Render every field as "name=value" separated by ", " so the state is unambiguous in logs;
// the leader and voters fields previously ran together because the separator was missing.
return String.format(
"FollowerState(fetchTimeoutMs=%d, epoch=%d, leader=%s, voters=%s, highWatermark=%s, " +
"fetchingSnapshot=%s)",
fetchTimeoutMs,
epoch,
leader,
voters,
highWatermark,
fetchingSnapshot
);
}
@Override

View File

@ -24,6 +24,7 @@ import org.apache.kafka.common.message.EndQuorumEpochRequestData;
import org.apache.kafka.common.message.FetchRequestData;
import org.apache.kafka.common.message.FetchSnapshotRequestData;
import org.apache.kafka.common.message.VoteRequestData;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
@ -39,12 +40,9 @@ import org.apache.kafka.server.util.RequestAndCompletionHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
@ -83,9 +81,17 @@ public class KafkaNetworkChannel implements NetworkChannel {
private final SendThread requestThread;
private final AtomicInteger correlationIdCounter = new AtomicInteger(0);
private final Map<Integer, Node> endpoints = new HashMap<>();
public KafkaNetworkChannel(Time time, KafkaClient client, int requestTimeoutMs, String threadNamePrefix) {
private final ListenerName listenerName;
public KafkaNetworkChannel(
Time time,
ListenerName listenerName,
KafkaClient client,
int requestTimeoutMs,
String threadNamePrefix
) {
this.listenerName = listenerName;
this.requestThread = new SendThread(
threadNamePrefix + "-outbound-request-thread",
client,
@ -102,23 +108,23 @@ public class KafkaNetworkChannel implements NetworkChannel {
@Override
public void send(RaftRequest.Outbound request) {
Node node = endpoints.get(request.destinationId());
Node node = request.destination();
if (node != null) {
requestThread.sendRequest(new RequestAndCompletionHandler(
request.createdTimeMs,
request.createdTimeMs(),
node,
buildRequest(request.data),
buildRequest(request.data()),
response -> sendOnComplete(request, response)
));
} else
sendCompleteFuture(request, errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE));
sendCompleteFuture(request, errorResponse(request.data(), Errors.BROKER_NOT_AVAILABLE));
}
private void sendCompleteFuture(RaftRequest.Outbound request, ApiMessage message) {
RaftResponse.Inbound response = new RaftResponse.Inbound(
request.correlationId,
request.correlationId(),
message,
request.destinationId()
request.destination()
);
request.completion.complete(response);
}
@ -127,16 +133,16 @@ public class KafkaNetworkChannel implements NetworkChannel {
ApiMessage response;
if (clientResponse.versionMismatch() != null) {
log.error("Request {} failed due to unsupported version error", request, clientResponse.versionMismatch());
response = errorResponse(request.data, Errors.UNSUPPORTED_VERSION);
response = errorResponse(request.data(), Errors.UNSUPPORTED_VERSION);
} else if (clientResponse.authenticationException() != null) {
// For now we treat authentication errors as retriable. We use the
// `NETWORK_EXCEPTION` error code for lack of a good alternative.
// Note that `NodeToControllerChannelManager` will still log the
// authentication errors so that users have a chance to fix the problem.
log.error("Request {} failed due to authentication error", request, clientResponse.authenticationException());
response = errorResponse(request.data, Errors.NETWORK_EXCEPTION);
response = errorResponse(request.data(), Errors.NETWORK_EXCEPTION);
} else if (clientResponse.wasDisconnected()) {
response = errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE);
response = errorResponse(request.data(), Errors.BROKER_NOT_AVAILABLE);
} else {
response = clientResponse.responseBody().data();
}
@ -149,9 +155,8 @@ public class KafkaNetworkChannel implements NetworkChannel {
}
@Override
public void updateEndpoint(int id, InetSocketAddress address) {
Node node = new Node(id, address.getHostString(), address.getPort());
endpoints.put(id, node);
public ListenerName listenerName() {
return listenerName;
}
public void start() {

View File

@ -37,6 +37,7 @@ import org.apache.kafka.common.message.FetchSnapshotResponseData;
import org.apache.kafka.common.message.VoteRequestData;
import org.apache.kafka.common.message.VoteResponseData;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
@ -60,7 +61,6 @@ import org.apache.kafka.common.utils.BufferSupplier;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Timer;
import org.apache.kafka.raft.RequestManager.ConnectionState;
import org.apache.kafka.raft.errors.NotLeaderException;
import org.apache.kafka.raft.internals.BatchAccumulator;
import org.apache.kafka.raft.internals.BatchMemoryPool;
@ -85,6 +85,7 @@ import org.apache.kafka.snapshot.SnapshotWriter;
import org.slf4j.Logger;
import java.net.InetSocketAddress;
import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
@ -100,8 +101,10 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition;
@ -209,6 +212,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
ExpirationService expirationService,
LogContext logContext,
String clusterId,
Collection<InetSocketAddress> bootstrapServers,
QuorumConfig quorumConfig
) {
this(
@ -223,6 +227,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
expirationService,
MAX_FETCH_WAIT_MS,
clusterId,
bootstrapServers,
logContext,
new Random(),
quorumConfig
@ -241,6 +246,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
ExpirationService expirationService,
int fetchMaxWaitMs,
String clusterId,
Collection<InetSocketAddress> bootstrapServers,
LogContext logContext,
Random random,
QuorumConfig quorumConfig
@ -262,6 +268,30 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
this.random = random;
this.quorumConfig = quorumConfig;
this.snapshotCleaner = new RaftMetadataLogCleanerManager(logger, time, 60000, log::maybeClean);
if (!bootstrapServers.isEmpty()) {
// generate Node objects from network addresses by using decreasing negative ids
AtomicInteger id = new AtomicInteger(-2);
List<Node> bootstrapNodes = bootstrapServers
.stream()
.map(address ->
new Node(
id.getAndDecrement(),
address.getHostString(),
address.getPort()
)
)
.collect(Collectors.toList());
logger.info("Starting request manager with bootstrap servers: {}", bootstrapNodes);
requestManager = new RequestManager(
bootstrapNodes,
quorumConfig.retryBackoffMs(),
quorumConfig.requestTimeoutMs(),
random
);
}
}
private void updateFollowerHighWatermark(
@ -378,12 +408,11 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
public void initialize(
Map<Integer, InetSocketAddress> voterAddresses,
String listenerName,
QuorumStateStore quorumStateStore,
Metrics metrics
) {
partitionState = new KRaftControlRecordStateMachine(
Optional.of(VoterSet.fromInetSocketAddresses(listenerName, voterAddresses)),
Optional.of(VoterSet.fromInetSocketAddresses(channel.listenerName(), voterAddresses)),
log,
serde,
BufferSupplier.create(),
@ -394,17 +423,35 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
logger.info("Reading KRaft snapshot and log as part of the initialization");
partitionState.updateState();
VoterSet lastVoterSet = partitionState.lastVoterSet();
if (requestManager == null) {
// The request manager wasn't created using the bootstrap servers
// create it using the voters static configuration
List<Node> bootstrapNodes = voterAddresses
.entrySet()
.stream()
.map(entry ->
new Node(
entry.getKey(),
entry.getValue().getHostString(),
entry.getValue().getPort()
)
)
.collect(Collectors.toList());
logger.info("Starting request manager with static voters: {}", bootstrapNodes);
requestManager = new RequestManager(
lastVoterSet.voterIds(),
bootstrapNodes,
quorumConfig.retryBackoffMs(),
quorumConfig.requestTimeoutMs(),
random
);
}
quorum = new QuorumState(
nodeId,
nodeDirectoryId,
channel.listenerName(),
partitionState::lastVoterSet,
partitionState::lastKraftVersion,
quorumConfig.electionTimeoutMs(),
@ -420,10 +467,6 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
// so there are no unknown voter connections. Report this metric as 0.
kafkaRaftMetrics.updateNumUnknownVoterConnections(0);
for (Integer voterId : lastVoterSet.voterIds()) {
channel.updateEndpoint(voterId, lastVoterSet.voterAddress(voterId, listenerName).get());
}
quorum.initialize(new OffsetAndEpoch(log.endOffset().offset, log.lastFetchedEpoch()));
long currentTimeMs = time.milliseconds();
@ -569,10 +612,10 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
private void transitionToFollower(
int epoch,
int leaderId,
Node leader,
long currentTimeMs
) {
quorum.transitionToFollower(epoch, leaderId);
quorum.transitionToFollower(epoch, leader);
maybeFireLeaderChange();
onBecomeFollower(currentTimeMs);
}
@ -601,7 +644,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
private VoteResponseData handleVoteRequest(
RaftRequest.Inbound requestMetadata
) {
VoteRequestData request = (VoteRequestData) requestMetadata.data;
VoteRequestData request = (VoteRequestData) requestMetadata.data();
if (!hasValidClusterId(request.clusterId())) {
return new VoteResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code());
@ -652,8 +695,8 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftResponse.Inbound responseMetadata,
long currentTimeMs
) {
int remoteNodeId = responseMetadata.sourceId();
VoteResponseData response = (VoteResponseData) responseMetadata.data;
int remoteNodeId = responseMetadata.source().id();
VoteResponseData response = (VoteResponseData) responseMetadata.data();
Errors topLevelError = Errors.forCode(response.errorCode());
if (topLevelError != Errors.NONE) {
return handleTopLevelError(topLevelError, responseMetadata);
@ -751,7 +794,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftRequest.Inbound requestMetadata,
long currentTimeMs
) {
BeginQuorumEpochRequestData request = (BeginQuorumEpochRequestData) requestMetadata.data;
BeginQuorumEpochRequestData request = (BeginQuorumEpochRequestData) requestMetadata.data();
if (!hasValidClusterId(request.clusterId())) {
return new BeginQuorumEpochResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code());
@ -773,7 +816,11 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
return buildBeginQuorumEpochResponse(errorOpt.get());
}
maybeTransition(OptionalInt.of(requestLeaderId), requestEpoch, currentTimeMs);
maybeTransition(
partitionState.lastVoterSet().voterNode(requestLeaderId, channel.listenerName()),
requestEpoch,
currentTimeMs
);
return buildBeginQuorumEpochResponse(Errors.NONE);
}
@ -781,8 +828,8 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftResponse.Inbound responseMetadata,
long currentTimeMs
) {
int remoteNodeId = responseMetadata.sourceId();
BeginQuorumEpochResponseData response = (BeginQuorumEpochResponseData) responseMetadata.data;
int remoteNodeId = responseMetadata.source().id();
BeginQuorumEpochResponseData response = (BeginQuorumEpochResponseData) responseMetadata.data();
Errors topLevelError = Errors.forCode(response.errorCode());
if (topLevelError != Errors.NONE) {
return handleTopLevelError(topLevelError, responseMetadata);
@ -840,7 +887,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftRequest.Inbound requestMetadata,
long currentTimeMs
) {
EndQuorumEpochRequestData request = (EndQuorumEpochRequestData) requestMetadata.data;
EndQuorumEpochRequestData request = (EndQuorumEpochRequestData) requestMetadata.data();
if (!hasValidClusterId(request.clusterId())) {
return new EndQuorumEpochResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code());
@ -861,11 +908,15 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
if (errorOpt.isPresent()) {
return buildEndQuorumEpochResponse(errorOpt.get());
}
maybeTransition(OptionalInt.of(requestLeaderId), requestEpoch, currentTimeMs);
maybeTransition(
partitionState.lastVoterSet().voterNode(requestLeaderId, channel.listenerName()),
requestEpoch,
currentTimeMs
);
if (quorum.isFollower()) {
FollowerState state = quorum.followerStateOrThrow();
if (state.leaderId() == requestLeaderId) {
if (state.leader().id() == requestLeaderId) {
List<Integer> preferredSuccessors = partitionRequest.preferredSuccessors();
long electionBackoffMs = endEpochElectionBackoff(preferredSuccessors);
logger.debug("Overriding follower fetch timeout to {} after receiving " +
@ -894,7 +945,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftResponse.Inbound responseMetadata,
long currentTimeMs
) {
EndQuorumEpochResponseData response = (EndQuorumEpochResponseData) responseMetadata.data;
EndQuorumEpochResponseData response = (EndQuorumEpochResponseData) responseMetadata.data();
Errors topLevelError = Errors.forCode(response.errorCode());
if (topLevelError != Errors.NONE) {
return handleTopLevelError(topLevelError, responseMetadata);
@ -917,7 +968,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
return handled.get();
} else if (partitionError == Errors.NONE) {
ResignedState resignedState = quorum.resignedStateOrThrow();
resignedState.acknowledgeResignation(responseMetadata.sourceId());
resignedState.acknowledgeResignation(responseMetadata.source().id());
return true;
} else {
return handleUnexpectedError(partitionError, responseMetadata);
@ -998,7 +1049,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftRequest.Inbound requestMetadata,
long currentTimeMs
) {
FetchRequestData request = (FetchRequestData) requestMetadata.data;
FetchRequestData request = (FetchRequestData) requestMetadata.data();
if (!hasValidClusterId(request.clusterId())) {
return completedFuture(new FetchResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()));
@ -1147,13 +1198,13 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftResponse.Inbound responseMetadata,
long currentTimeMs
) {
FetchResponseData response = (FetchResponseData) responseMetadata.data;
FetchResponseData response = (FetchResponseData) responseMetadata.data();
Errors topLevelError = Errors.forCode(response.errorCode());
if (topLevelError != Errors.NONE) {
return handleTopLevelError(topLevelError, responseMetadata);
}
if (!RaftUtil.hasValidTopicPartition(response, log.topicPartition(), log.topicId())) {
if (!hasValidTopicPartition(response, log.topicPartition(), log.topicId())) {
return false;
}
// If the ID is valid, we can set the topic name.
@ -1286,7 +1337,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftRequest.Inbound requestMetadata,
long currentTimeMs
) {
DescribeQuorumRequestData describeQuorumRequestData = (DescribeQuorumRequestData) requestMetadata.data;
DescribeQuorumRequestData describeQuorumRequestData = (DescribeQuorumRequestData) requestMetadata.data();
if (!hasValidTopicPartition(describeQuorumRequestData, log.topicPartition())) {
return DescribeQuorumRequest.getPartitionLevelErrorResponse(
describeQuorumRequestData, Errors.UNKNOWN_TOPIC_OR_PARTITION);
@ -1325,7 +1376,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftRequest.Inbound requestMetadata,
long currentTimeMs
) {
FetchSnapshotRequestData data = (FetchSnapshotRequestData) requestMetadata.data;
FetchSnapshotRequestData data = (FetchSnapshotRequestData) requestMetadata.data();
if (!hasValidClusterId(data.clusterId())) {
return new FetchSnapshotResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code());
@ -1429,7 +1480,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
RaftResponse.Inbound responseMetadata,
long currentTimeMs
) {
FetchSnapshotResponseData data = (FetchSnapshotResponseData) responseMetadata.data;
FetchSnapshotResponseData data = (FetchSnapshotResponseData) responseMetadata.data();
Errors topLevelError = Errors.forCode(data.errorCode());
if (topLevelError != Errors.NONE) {
return handleTopLevelError(topLevelError, responseMetadata);
@ -1593,6 +1644,12 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
int epoch,
long currentTimeMs
) {
Optional<Node> leader = leaderId.isPresent() ?
partitionState
.lastVoterSet()
.voterNode(leaderId.getAsInt(), channel.listenerName()) :
Optional.empty();
if (epoch < quorum.epoch() || error == Errors.UNKNOWN_LEADER_EPOCH) {
// We have a larger epoch, so the response is no longer relevant
return Optional.of(true);
@ -1602,10 +1659,10 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
// The response indicates that the request had a stale epoch, but we need
// to validate the epoch from the response against our current state.
maybeTransition(leaderId, epoch, currentTimeMs);
maybeTransition(leader, epoch, currentTimeMs);
return Optional.of(true);
} else if (epoch == quorum.epoch()
&& leaderId.isPresent()
&& leader.isPresent()
&& !quorum.hasLeader()) {
// Since we are transitioning to Follower, we will only forward the
@ -1613,7 +1670,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
// the request be retried immediately (if needed) after the transition.
// This handling allows an observer to discover the leader and append
// to the log in the same Fetch request.
transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs);
transitionToFollower(epoch, leader.get(), currentTimeMs);
if (error == Errors.NONE) {
return Optional.empty();
} else {
@ -1635,24 +1692,28 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
}
private void maybeTransition(
OptionalInt leaderId,
Optional<Node> leader,
int epoch,
long currentTimeMs
) {
OptionalInt leaderId = leader.isPresent() ?
OptionalInt.of(leader.get().id()) :
OptionalInt.empty();
if (!hasConsistentLeader(epoch, leaderId)) {
throw new IllegalStateException("Received request or response with leader " + leaderId +
throw new IllegalStateException("Received request or response with leader " + leader +
" and epoch " + epoch + " which is inconsistent with current leader " +
quorum.leaderId() + " and epoch " + quorum.epoch());
} else if (epoch > quorum.epoch()) {
if (leaderId.isPresent()) {
transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs);
if (leader.isPresent()) {
transitionToFollower(epoch, leader.get(), currentTimeMs);
} else {
transitionToUnattached(epoch);
}
} else if (leaderId.isPresent() && !quorum.hasLeader()) {
} else if (leader.isPresent() && !quorum.hasLeader()) {
// The request or response indicates the leader of the current epoch,
// which is currently unknown
transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs);
transitionToFollower(epoch, leader.get(), currentTimeMs);
}
}
@ -1668,13 +1729,13 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
private boolean handleUnexpectedError(Errors error, RaftResponse.Inbound response) {
logger.error("Unexpected error {} in {} response: {}",
error, ApiKeys.forId(response.data.apiKey()), response);
error, ApiKeys.forId(response.data().apiKey()), response);
return false;
}
private void handleResponse(RaftResponse.Inbound response, long currentTimeMs) {
// The response epoch matches the local epoch, so we can handle the response
ApiKeys apiKey = ApiKeys.forId(response.data.apiKey());
ApiKeys apiKey = ApiKeys.forId(response.data().apiKey());
final boolean handledSuccessfully;
switch (apiKey) {
@ -1702,12 +1763,12 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
throw new IllegalArgumentException("Received unexpected response type: " + apiKey);
}
ConnectionState connection = requestManager.getOrCreate(response.sourceId());
if (handledSuccessfully) {
connection.onResponseReceived(response.correlationId);
} else {
connection.onResponseError(response.correlationId, currentTimeMs);
}
requestManager.onResponseResult(
response.source(),
response.correlationId(),
handledSuccessfully,
currentTimeMs
);
}
/**
@ -1749,7 +1810,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
}
private void handleRequest(RaftRequest.Inbound request, long currentTimeMs) {
ApiKeys apiKey = ApiKeys.forId(request.data.apiKey());
ApiKeys apiKey = ApiKeys.forId(request.data().apiKey());
final CompletableFuture<? extends ApiMessage> responseFuture;
switch (apiKey) {
@ -1803,8 +1864,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
handleRequest(request, currentTimeMs);
} else if (message instanceof RaftResponse.Inbound) {
RaftResponse.Inbound response = (RaftResponse.Inbound) message;
ConnectionState connection = requestManager.getOrCreate(response.sourceId());
if (connection.isResponseExpected(response.correlationId)) {
if (requestManager.isResponseExpected(response.source(), response.correlationId())) {
handleResponse(response, currentTimeMs);
} else {
logger.debug("Ignoring response {} since it is no longer needed", response);
@ -1819,25 +1879,23 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
*/
private long maybeSendRequest(
long currentTimeMs,
int destinationId,
Node destination,
Supplier<ApiMessage> requestSupplier
) {
ConnectionState connection = requestManager.getOrCreate(destinationId);
if (connection.isBackingOff(currentTimeMs)) {
long remainingBackoffMs = connection.remainingBackoffMs(currentTimeMs);
logger.debug("Connection for {} is backing off for {} ms", destinationId, remainingBackoffMs);
if (requestManager.isBackingOff(destination, currentTimeMs)) {
long remainingBackoffMs = requestManager.remainingBackoffMs(destination, currentTimeMs);
logger.debug("Connection for {} is backing off for {} ms", destination, remainingBackoffMs);
return remainingBackoffMs;
}
if (connection.isReady(currentTimeMs)) {
if (requestManager.isReady(destination, currentTimeMs)) {
int correlationId = channel.newCorrelationId();
ApiMessage request = requestSupplier.get();
RaftRequest.Outbound requestMessage = new RaftRequest.Outbound(
correlationId,
request,
destinationId,
destination,
currentTimeMs
);
@ -1850,20 +1908,19 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
response = new RaftResponse.Inbound(
correlationId,
errorResponse,
destinationId
destination
);
}
messageQueue.add(response);
});
requestManager.onRequestSent(destination, correlationId, currentTimeMs);
channel.send(requestMessage);
logger.trace("Sent outbound request: {}", requestMessage);
connection.onRequestSent(correlationId, currentTimeMs);
return Long.MAX_VALUE;
}
return connection.remainingRequestTimeMs(currentTimeMs);
return requestManager.remainingRequestTimeMs(destination, currentTimeMs);
}
private EndQuorumEpochRequestData buildEndQuorumEpochRequest(
@ -1880,12 +1937,12 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
private long maybeSendRequests(
long currentTimeMs,
Set<Integer> destinationIds,
Set<Node> destinations,
Supplier<ApiMessage> requestSupplier
) {
long minBackoffMs = Long.MAX_VALUE;
for (Integer destinationId : destinationIds) {
long backoffMs = maybeSendRequest(currentTimeMs, destinationId, requestSupplier);
for (Node destination : destinations) {
long backoffMs = maybeSendRequest(currentTimeMs, destination, requestSupplier);
if (backoffMs < minBackoffMs) {
minBackoffMs = backoffMs;
}
@ -1929,15 +1986,15 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
}
private long maybeSendAnyVoterFetch(long currentTimeMs) {
OptionalInt readyVoterIdOpt = requestManager.findReadyVoter(currentTimeMs);
if (readyVoterIdOpt.isPresent()) {
Optional<Node> readyNode = requestManager.findReadyBootstrapServer(currentTimeMs);
if (readyNode.isPresent()) {
return maybeSendRequest(
currentTimeMs,
readyVoterIdOpt.getAsInt(),
readyNode.get(),
this::buildFetchRequest
);
} else {
return requestManager.backoffBeforeAvailableVoter(currentTimeMs);
return requestManager.backoffBeforeAvailableBootstrapServer(currentTimeMs);
}
}
@ -2038,7 +2095,9 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
ResignedState state = quorum.resignedStateOrThrow();
long endQuorumBackoffMs = maybeSendRequests(
currentTimeMs,
state.unackedVoters(),
partitionState
.lastVoterSet()
.voterNodes(state.unackedVoters().stream(), channel.listenerName()),
() -> buildEndQuorumEpochRequest(state)
);
@ -2075,7 +2134,9 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
long timeUntilSend = maybeSendRequests(
currentTimeMs,
state.nonAcknowledgingVoters(),
partitionState
.lastVoterSet()
.voterNodes(state.nonAcknowledgingVoters().stream(), channel.listenerName()),
this::buildBeginQuorumEpochRequest
);
@ -2090,7 +2151,9 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
if (!state.isVoteRejected()) {
return maybeSendRequests(
currentTimeMs,
state.unrecordedVoters(),
partitionState
.lastVoterSet()
.voterNodes(state.unrecordedVoters().stream(), channel.listenerName()),
this::buildVoteRequest
);
}
@ -2163,14 +2226,16 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
// If the current leader is backing off due to some failure or if the
// request has timed out, then we attempt to send the Fetch to another
// voter in order to discover if there has been a leader change.
ConnectionState connection = requestManager.getOrCreate(state.leaderId());
if (connection.hasRequestTimedOut(currentTimeMs)) {
if (requestManager.hasRequestTimedOut(state.leader(), currentTimeMs)) {
// Once the request has timed out backoff the connection
requestManager.reset(state.leader());
backoffMs = maybeSendAnyVoterFetch(currentTimeMs);
connection.reset();
} else if (connection.isBackingOff(currentTimeMs)) {
} else if (requestManager.isBackingOff(state.leader(), currentTimeMs)) {
backoffMs = maybeSendAnyVoterFetch(currentTimeMs);
} else {
} else if (!requestManager.hasAnyInflightRequest(currentTimeMs)) {
backoffMs = maybeSendFetchOrFetchSnapshot(state, currentTimeMs);
} else {
backoffMs = requestManager.backoffBeforeAvailableBootstrapServer(currentTimeMs);
}
return Math.min(backoffMs, state.remainingFetchTimeMs(currentTimeMs));
@ -2189,7 +2254,7 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
requestSupplier = this::buildFetchRequest;
}
return maybeSendRequest(currentTimeMs, state.leaderId(), requestSupplier);
return maybeSendRequest(currentTimeMs, state.leader(), requestSupplier);
}
private long pollVoted(long currentTimeMs) {
@ -2549,8 +2614,8 @@ final public class KafkaRaftClient<T> implements RaftClient<T> {
}
}
public Optional<Node> voterNode(int id, String listener) {
return partitionState.lastVoterSet().voterNode(id, listener);
public Optional<Node> voterNode(int id, ListenerName listenerName) {
return partitionState.lastVoterSet().voterNode(id, listenerName);
}
// Visible only for test

View File

@ -16,7 +16,7 @@
*/
package org.apache.kafka.raft;
import java.net.InetSocketAddress;
import org.apache.kafka.common.network.ListenerName;
/**
* A simple network interface with few assumptions. We do not assume ordering
@ -37,10 +37,11 @@ public interface NetworkChannel extends AutoCloseable {
void send(RaftRequest.Outbound request);
/**
* Update connection information for the given id.
* The name of listener used when sending requests.
*
* @return the name of the listener
*/
void updateEndpoint(int id, InetSocketAddress address);
ListenerName listenerName();
default void close() throws InterruptedException {}
}

View File

@ -54,6 +54,13 @@ public class QuorumConfig {
"For example: <code>1@localhost:9092,2@localhost:9093,3@localhost:9094</code>";
public static final List<String> DEFAULT_QUORUM_VOTERS = Collections.emptyList();
// Name of the bootstrap servers configuration: QUORUM_PREFIX followed by "bootstrap.servers".
public static final String QUORUM_BOOTSTRAP_SERVERS_CONFIG = QUORUM_PREFIX + "bootstrap.servers";
// User-facing description of the configuration. Unlike the voters configuration, whose entries
// take the form {id}@{host}:{port}, each entry here is only {host}:{port}.
public static final String QUORUM_BOOTSTRAP_SERVERS_DOC = "List of endpoints to use for " +
"bootstrapping the cluster metadata. The endpoints are specified in comma-separated list " +
"of <code>{host}:{port}</code> entries. For example: " +
"<code>localhost:9092,localhost:9093,localhost:9094</code>.";
// Default value: an empty list, i.e. no bootstrap servers are configured.
public static final List<String> DEFAULT_QUORUM_BOOTSTRAP_SERVERS = Collections.emptyList();
public static final String QUORUM_ELECTION_TIMEOUT_MS_CONFIG = QUORUM_PREFIX + "election.timeout.ms";
public static final String QUORUM_ELECTION_TIMEOUT_MS_DOC = "Maximum time in milliseconds to wait " +
"without being able to fetch from the leader before triggering a new election";
@ -163,7 +170,7 @@ public class QuorumConfig {
List<String> voterEntries,
boolean requireRoutableAddresses
) {
Map<Integer, InetSocketAddress> voterMap = new HashMap<>();
Map<Integer, InetSocketAddress> voterMap = new HashMap<>(voterEntries.size());
for (String voterMapEntry : voterEntries) {
String[] idAndAddress = voterMapEntry.split("@");
if (idAndAddress.length != 2) {
@ -173,7 +180,7 @@ public class QuorumConfig {
Integer voterId = parseVoterId(idAndAddress[0]);
String host = Utils.getHost(idAndAddress[1]);
if (host == null) {
if (host == null || !Utils.validHostPattern(host)) {
throw new ConfigException("Failed to parse host name from entry " + voterMapEntry
+ " for the configuration " + QUORUM_VOTERS_CONFIG
+ ". Each entry should be in the form `{id}@{host}:{port}`.");
@ -199,6 +206,41 @@ public class QuorumConfig {
return voterMap;
}
/**
 * Converts every entry of the bootstrap servers configuration into a socket address.
 *
 * Each entry is parsed with {@code parseBootstrapServer}, so the first malformed entry
 * aborts the conversion with the exception thrown by that method.
 *
 * @param bootstrapServers the configured entries
 * @return the parsed addresses, in the same order as the entries
 */
public static List<InetSocketAddress> parseBootstrapServers(List<String> bootstrapServers) {
    List<InetSocketAddress> addresses = bootstrapServers
        .stream()
        .map(entry -> parseBootstrapServer(entry))
        .collect(Collectors.toList());

    return addresses;
}
/**
 * Parses a single bootstrap server entry into an unresolved socket address.
 *
 * The host is extracted with {@code Utils.getHost} and must also satisfy
 * {@code Utils.validHostPattern}; the port is extracted with {@code Utils.getPort}.
 * The address is created with {@code InetSocketAddress.createUnresolved}, so no
 * name lookup is attempted here.
 *
 * @param bootstrapServer the entry to parse
 * @return the unresolved address for the entry's host and port
 * @throws ConfigException if the host or the port cannot be extracted from the entry
 */
private static InetSocketAddress parseBootstrapServer(String bootstrapServer) {
    String host = Utils.getHost(bootstrapServer);
    if (host == null || !Utils.validHostPattern(host)) {
        // String.format only substitutes '%' specifiers; the literal braces in
        // "{host}:{port}" are plain text and are printed as written.
        throw new ConfigException(
            String.format(
                "Failed to parse host name from %s for the configuration %s. Each " +
                "entry should be in the form \"{host}:{port}\"",
                bootstrapServer,
                QUORUM_BOOTSTRAP_SERVERS_CONFIG
            )
        );
    }

    Integer port = Utils.getPort(bootstrapServer);
    if (port == null) {
        throw new ConfigException(
            String.format(
                "Failed to parse host port from %s for the configuration %s. Each " +
                "entry should be in the form \"{host}:{port}\"",
                bootstrapServer,
                QUORUM_BOOTSTRAP_SERVERS_CONFIG
            )
        );
    }

    return InetSocketAddress.createUnresolved(host, port);
}
/**
 * Converts the configured voter strings into nodes.
 *
 * The strings are first parsed with {@code parseVoterConnections} and the result is then
 * handed to {@code voterConnectionsToNodes}; any parsing failure surfaces from the former.
 *
 * @param voters the configured voter entries
 * @return the nodes produced by {@code voterConnectionsToNodes}
 */
public static List<Node> quorumVoterStringsToNodes(List<String> voters) {
return voterConnectionsToNodes(parseVoterConnections(voters));
}
@ -231,4 +273,26 @@ public class QuorumConfig {
return "non-empty list";
}
}
/**
 * Configuration validator for the bootstrap servers list.
 *
 * A value is accepted when it is not null and every entry can be parsed by
 * {@code QuorumConfig.parseBootstrapServer}.
 */
public static class ControllerQuorumBootstrapServersValidator implements ConfigDef.Validator {
    /**
     * Validates the configured value.
     *
     * @param name the name of the configuration being validated
     * @param value the configured value; treated as a list of strings
     * @throws ConfigException if the value is null or any entry fails to parse
     */
    @Override
    public void ensureValid(String name, Object value) {
        if (value == null) {
            throw new ConfigException(name, null);
        }

        @SuppressWarnings("unchecked")
        List<String> bootstrapServers = (List<String>) value;

        // The parsed addresses are discarded; parsing is done only so that a malformed
        // entry raises its ConfigException during validation.
        bootstrapServers.forEach(QuorumConfig::parseBootstrapServer);
    }

    // NOTE(review): ensureValid accepts an empty list (there is nothing to parse), so the
    // "non-empty" wording below is stricter than what is enforced - confirm which is intended.
    @Override
    public String toString() {
        return "non-empty list";
    }
}
}

View File

@ -25,7 +25,9 @@ import java.util.OptionalInt;
import java.util.Random;
import java.util.function.Supplier;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.raft.internals.BatchAccumulator;
@ -81,6 +83,7 @@ public class QuorumState {
private final Time time;
private final Logger log;
private final QuorumStateStore store;
private final ListenerName listenerName;
private final Supplier<VoterSet> latestVoterSet;
private final Supplier<Short> latestKraftVersion;
private final Random random;
@ -93,6 +96,7 @@ public class QuorumState {
public QuorumState(
OptionalInt localId,
Uuid localDirectoryId,
ListenerName listenerName,
Supplier<VoterSet> latestVoterSet,
Supplier<Short> latestKraftVersion,
int electionTimeoutMs,
@ -104,6 +108,7 @@ public class QuorumState {
) {
this.localId = localId;
this.localDirectoryId = localDirectoryId;
this.listenerName = listenerName;
this.latestVoterSet = latestVoterSet;
this.latestKraftVersion = latestKraftVersion;
this.electionTimeoutMs = electionTimeoutMs;
@ -115,16 +120,21 @@ public class QuorumState {
this.logContext = logContext;
}
public void initialize(OffsetAndEpoch logEndOffsetAndEpoch) throws IllegalStateException {
// We initialize in whatever state we were in on shutdown. If we were a leader
// or candidate, probably an election was held, but we will find out about it
// when we send Vote or BeginEpoch requests.
/**
 * Loads the election state from the quorum state store.
 *
 * When the store has no election state, the result of
 * {@code ElectionState.withUnknownLeader(0, voterIds)} is returned instead, where
 * {@code voterIds} are the voter ids of the latest voter set.
 *
 * @return the stored election state, or the fallback described above
 */
private ElectionState readElectionState() {
    return store
        .readElectionState()
        .orElseGet(() -> ElectionState.withUnknownLeader(0, latestVoterSet.get().voterIds()));
}
public void initialize(OffsetAndEpoch logEndOffsetAndEpoch) throws IllegalStateException {
// We initialize in whatever state we were in on shutdown. If we were a leader
// or candidate, probably an election was held, but we will find out about it
// when we send Vote or BeginEpoch requests.
ElectionState election = readElectionState();
final EpochState initialState;
if (election.hasVoted() && !localId.isPresent()) {
throw new IllegalStateException(
@ -191,10 +201,26 @@ public class QuorumState {
logContext
);
} else if (election.hasLeader()) {
/* KAFKA-16529 is going to change this so that the leader is not required to be in the set
* of voters. In other words, don't throw an IllegalStateException if the leader is not in
* the set of voters.
*/
Node leader = latestVoterSet
.get()
.voterNode(election.leaderId(), listenerName)
.orElseThrow(() ->
new IllegalStateException(
String.format(
"Leader %s must be in the voter set %s",
election.leaderId(),
latestVoterSet.get()
)
)
);
initialState = new FollowerState(
time,
election.epoch(),
election.leaderId(),
leader,
latestVoterSet.get().voterIds(),
Optional.empty(),
fetchTimeoutMs,
@ -400,28 +426,24 @@ public class QuorumState {
/**
* Become a follower of an elected leader so that we can begin fetching.
*/
public void transitionToFollower(
int epoch,
int leaderId
) {
public void transitionToFollower(int epoch, Node leader) {
int currentEpoch = state.epoch();
if (localId.isPresent() && leaderId == localId.getAsInt()) {
throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId +
" and epoch=" + epoch + " since it matches the local broker.id=" + localId);
if (localId.isPresent() && leader.id() == localId.getAsInt()) {
throw new IllegalStateException("Cannot transition to Follower with leader " + leader +
" and epoch " + epoch + " since it matches the local broker.id " + localId);
} else if (epoch < currentEpoch) {
throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId +
" and epoch=" + epoch + " since the current epoch " + currentEpoch + " is larger");
} else if (epoch == currentEpoch
&& (isFollower() || isLeader())) {
throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId +
" and epoch=" + epoch + " from state " + state);
throw new IllegalStateException("Cannot transition to Follower with leader " + leader +
" and epoch " + epoch + " since the current epoch " + currentEpoch + " is larger");
} else if (epoch == currentEpoch && (isFollower() || isLeader())) {
throw new IllegalStateException("Cannot transition to Follower with leader " + leader +
" and epoch " + epoch + " from state " + state);
}
durableTransitionTo(
new FollowerState(
time,
epoch,
leaderId,
leader,
latestVoterSet.get().voterIds(),
state.highWatermark(),
fetchTimeoutMs,

View File

@ -17,13 +17,14 @@
package org.apache.kafka.raft;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.Node;
import java.util.concurrent.CompletableFuture;
public abstract class RaftRequest implements RaftMessage {
protected final int correlationId;
protected final ApiMessage data;
protected final long createdTimeMs;
private final int correlationId;
private final ApiMessage data;
private final long createdTimeMs;
public RaftRequest(int correlationId, ApiMessage data, long createdTimeMs) {
this.correlationId = correlationId;
@ -45,7 +46,7 @@ public abstract class RaftRequest implements RaftMessage {
return createdTimeMs;
}
public static class Inbound extends RaftRequest {
public final static class Inbound extends RaftRequest {
public final CompletableFuture<RaftResponse.Outbound> completion = new CompletableFuture<>();
public Inbound(int correlationId, ApiMessage data, long createdTimeMs) {
@ -54,35 +55,37 @@ public abstract class RaftRequest implements RaftMessage {
@Override
public String toString() {
return "InboundRequest(" +
"correlationId=" + correlationId +
", data=" + data +
", createdTimeMs=" + createdTimeMs +
')';
return String.format(
"InboundRequest(correlationId=%d, data=%s, createdTimeMs=%d)",
correlationId(),
data(),
createdTimeMs()
);
}
}
public static class Outbound extends RaftRequest {
private final int destinationId;
public final static class Outbound extends RaftRequest {
private final Node destination;
public final CompletableFuture<RaftResponse.Inbound> completion = new CompletableFuture<>();
public Outbound(int correlationId, ApiMessage data, int destinationId, long createdTimeMs) {
public Outbound(int correlationId, ApiMessage data, Node destination, long createdTimeMs) {
super(correlationId, data, createdTimeMs);
this.destinationId = destinationId;
this.destination = destination;
}
public int destinationId() {
return destinationId;
public Node destination() {
return destination;
}
@Override
public String toString() {
return "OutboundRequest(" +
"correlationId=" + correlationId +
", data=" + data +
", createdTimeMs=" + createdTimeMs +
", destinationId=" + destinationId +
')';
return String.format(
"OutboundRequest(correlationId=%d, data=%s, createdTimeMs=%d, destination=%s)",
correlationId(),
data(),
createdTimeMs(),
destination
);
}
}
}

View File

@ -16,11 +16,12 @@
*/
package org.apache.kafka.raft;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.protocol.ApiMessage;
public abstract class RaftResponse implements RaftMessage {
protected final int correlationId;
protected final ApiMessage data;
private final int correlationId;
private final ApiMessage data;
protected RaftResponse(int correlationId, ApiMessage data) {
this.correlationId = correlationId;
@ -37,39 +38,41 @@ public abstract class RaftResponse implements RaftMessage {
return data;
}
public static class Inbound extends RaftResponse {
private final int sourceId;
public final static class Inbound extends RaftResponse {
private final Node source;
public Inbound(int correlationId, ApiMessage data, int sourceId) {
public Inbound(int correlationId, ApiMessage data, Node source) {
super(correlationId, data);
this.sourceId = sourceId;
this.source = source;
}
public int sourceId() {
return sourceId;
public Node source() {
return source;
}
@Override
public String toString() {
return "InboundResponse(" +
"correlationId=" + correlationId +
", data=" + data +
", sourceId=" + sourceId +
')';
return String.format(
"InboundResponse(correlationId=%d, data=%s, source=%s)",
correlationId(),
data(),
source
);
}
}
public static class Outbound extends RaftResponse {
public final static class Outbound extends RaftResponse {
public Outbound(int requestId, ApiMessage data) {
super(requestId, data);
}
@Override
public String toString() {
return "OutboundResponse(" +
"correlationId=" + correlationId +
", data=" + data +
')';
return String.format(
"OutboundResponse(correlationId=%d, data=%s)",
correlationId(),
data()
);
}
}
}

View File

@ -25,6 +25,7 @@ import org.apache.kafka.common.message.EndQuorumEpochRequestData;
import org.apache.kafka.common.message.EndQuorumEpochResponseData;
import org.apache.kafka.common.message.FetchRequestData;
import org.apache.kafka.common.message.FetchResponseData;
import org.apache.kafka.common.message.FetchSnapshotResponseData;
import org.apache.kafka.common.message.VoteRequestData;
import org.apache.kafka.common.message.VoteResponseData;
import org.apache.kafka.common.protocol.ApiKeys;
@ -48,6 +49,8 @@ public class RaftUtil {
return new EndQuorumEpochResponseData().setErrorCode(error.code());
case FETCH:
return new FetchResponseData().setErrorCode(error.code());
case FETCH_SNAPSHOT:
return new FetchSnapshotResponseData().setErrorCode(error.code());
default:
throw new IllegalArgumentException("Received response for unexpected request type: " + apiKey);
}

View File

@ -17,96 +17,288 @@
package org.apache.kafka.raft;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Iterator;
import java.util.Map;
import java.util.OptionalInt;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Random;
import java.util.Set;
import org.apache.kafka.common.Node;
/**
* The request manager keeps track of the connection with remote replicas.
*
* When sending a request, update this manager by calling {@code onRequestSent(Node, long, long)}. When
* the RPC returns a response, update this manager with {@code onResponseResult(Node, long, boolean, long)}.
*
* Connections start in the ready state ({@code isReady(Node, long)} returns true).
*
* When a request times out or completes successfully, the connection will transition back to the
* ready state.
*
* When a request completes with an error, the connection transitions to the backoff state until
* {@code retryBackoffMs} has elapsed.
*/
public class RequestManager {
private final Map<Integer, ConnectionState> connections = new HashMap<>();
private final List<Integer> voters = new ArrayList<>();
private final Map<String, ConnectionState> connections = new HashMap<>();
private final ArrayList<Node> bootstrapServers;
private final int retryBackoffMs;
private final int requestTimeoutMs;
private final Random random;
public RequestManager(Set<Integer> voterIds,
public RequestManager(
Collection<Node> bootstrapServers,
int retryBackoffMs,
int requestTimeoutMs,
Random random) {
Random random
) {
this.bootstrapServers = new ArrayList<>(bootstrapServers);
this.retryBackoffMs = retryBackoffMs;
this.requestTimeoutMs = requestTimeoutMs;
this.voters.addAll(voterIds);
this.random = random;
for (Integer voterId: voterIds) {
ConnectionState connection = new ConnectionState(voterId);
connections.put(voterId, connection);
}
}
public ConnectionState getOrCreate(int id) {
return connections.computeIfAbsent(id, key -> new ConnectionState(id));
}
/**
* Returns true if there is any connection with pending requests.
*
* This is useful for satisfying the invariant that there is only one pending Fetch request.
* If there are more than one pending fetch request, it is possible for the follower to write
* the same offset twice.
*
* @param currentTimeMs the current time
* @return true if the request manager is tracking at least one request
*/
public boolean hasAnyInflightRequest(long currentTimeMs) {
boolean result = false;
public OptionalInt findReadyVoter(long currentTimeMs) {
int startIndex = random.nextInt(voters.size());
OptionalInt res = OptionalInt.empty();
for (int i = 0; i < voters.size(); i++) {
int index = (startIndex + i) % voters.size();
Integer voterId = voters.get(index);
ConnectionState connection = connections.get(voterId);
boolean isReady = connection.isReady(currentTimeMs);
if (isReady) {
res = OptionalInt.of(voterId);
} else if (connection.inFlightCorrelationId.isPresent()) {
res = OptionalInt.empty();
Iterator<ConnectionState> iterator = connections.values().iterator();
while (iterator.hasNext()) {
ConnectionState connection = iterator.next();
if (connection.hasRequestTimedOut(currentTimeMs)) {
// Mark the node as ready after request timeout
iterator.remove();
} else if (connection.isBackoffComplete(currentTimeMs)) {
// Mark the node as ready after completed backoff
iterator.remove();
} else if (connection.hasInflightRequest(currentTimeMs)) {
// If there is at least one inflight request, it is enough
// to stop checking the rest of the connections
result = true;
break;
}
}
return res;
return result;
}
public long backoffBeforeAvailableVoter(long currentTimeMs) {
long minBackoffMs = Long.MAX_VALUE;
for (Integer voterId : voters) {
ConnectionState connection = connections.get(voterId);
if (connection.isReady(currentTimeMs)) {
return 0L;
/**
 * Returns a random bootstrap node that is ready to receive a request.
 *
 * This method doesn't return a node if there is at least one request pending. In general this
 * method is used to send Fetch requests. Fetch requests have the invariant that there can
 * only be one pending Fetch request for the LEO.
 *
 * If the manager was created without any bootstrap servers, no node is returned.
 *
 * @param currentTimeMs the current time
 * @return a random ready bootstrap node, or {@code Optional.empty()} if a request is pending,
 *         no bootstrap server is ready, or there are no bootstrap servers
 */
public Optional<Node> findReadyBootstrapServer(long currentTimeMs) {
    // Check that there are no inflight requests across any of the known nodes, not just
    // the bootstrap servers
    if (hasAnyInflightRequest(currentTimeMs)) {
        return Optional.empty();
    }

    // Random.nextInt(bound) rejects a non-positive bound, so there is nothing to pick from
    // when the bootstrap list is empty
    if (bootstrapServers.isEmpty()) {
        return Optional.empty();
    }

    // Start at a random position and visit every bootstrap server once, wrapping around
    int startIndex = random.nextInt(bootstrapServers.size());
    Optional<Node> result = Optional.empty();
    for (int i = 0; i < bootstrapServers.size(); i++) {
        int index = (startIndex + i) % bootstrapServers.size();
        Node node = bootstrapServers.get(index);

        if (isReady(node, currentTimeMs)) {
            result = Optional.of(node);
            break;
        }
    }

    return result;
}
/**
* Computes the amount of time needed to wait before a bootstrap server is ready for a Fetch
* request.
*
* If there is a connection with a pending request it returns the amount of time to wait until
* the request times out.
*
* Returns zero, if there are no pending requests and at least one of the bootstrap servers is
* ready.
*
* If all of the bootstrap servers are backing off and there are no pending requests, return
* the minimum amount of time until a bootstrap server becomes ready. The minimum is computed
* over every tracked connection, not only the bootstrap servers, and it starts at
* {@code retryBackoffMs} so the returned value never exceeds it.
*
* As a side effect, tracked connections whose request has timed out or whose backoff has
* completed are removed from the manager.
*
* @param currentTimeMs the current time
* @return the amount of time to wait until bootstrap server can accept a Fetch request
*/
public long backoffBeforeAvailableBootstrapServer(long currentTimeMs) {
// Starting value (and therefore upper bound) for the wait returned at the end
long minBackoffMs = retryBackoffMs;
// An explicit iterator is used so entries can be removed while scanning the map
Iterator<ConnectionState> iterator = connections.values().iterator();
while (iterator.hasNext()) {
ConnectionState connection = iterator.next();
if (connection.hasRequestTimedOut(currentTimeMs)) {
// Mark the node as ready after request timeout; a node with no tracked
// connection is reported as ready by isReady(Node, long)
iterator.remove();
} else if (connection.isBackoffComplete(currentTimeMs)) {
// Mark the node as ready after completed backoff
iterator.remove();
} else if (connection.hasInflightRequest(currentTimeMs)) {
// There can be at most one inflight fetch request, so wait for it to time out
return connection.remainingRequestTimeMs(currentTimeMs);
} else if (connection.isBackingOff(currentTimeMs)) {
minBackoffMs = Math.min(minBackoffMs, connection.remainingBackoffMs(currentTimeMs));
} else {
// NOTE(review): remaining case is a connection that is neither inflight nor backing
// off; confirm which ConnectionState state reaches this branch
minBackoffMs = Math.min(minBackoffMs, connection.remainingRequestTimeMs(currentTimeMs));
}
}
// There are no inflight fetch requests so check if there is a ready bootstrap server
for (Node node : bootstrapServers) {
if (isReady(node, currentTimeMs)) {
return 0L;
}
}
// There are no ready bootstrap servers and no inflight fetch requests, return the backoff
return minBackoffMs;
}
/**
 * Checks whether the request tracked for the given node has timed out.
 *
 * @param node the remote node
 * @param timeMs the current time
 * @return false when the node has no tracked connection, otherwise the tracked connection's answer
 */
public boolean hasRequestTimedOut(Node node, long timeMs) {
    ConnectionState tracked = connections.get(node.idString());
    return tracked != null && tracked.hasRequestTimedOut(timeMs);
}
/**
 * Checks whether a request can be sent to the given node.
 *
 * A node with no tracked connection is always ready. When the tracked connection reports
 * ready, its entry is dropped from the manager before returning.
 *
 * @param node the remote node
 * @param timeMs the current time
 * @return true if the node is ready, false otherwise
 */
public boolean isReady(Node node, long timeMs) {
    ConnectionState tracked = connections.get(node.idString());
    if (tracked == null) {
        // Nothing is tracked for this node
        return true;
    }

    if (!tracked.isReady(timeMs)) {
        return false;
    }

    // A ready connection no longer needs to be tracked
    reset(node);
    return true;
}
/**
 * Checks whether the connection to the given node is currently backing off.
 *
 * @param node the remote node
 * @param timeMs the current time
 * @return false when the node has no tracked connection, otherwise the tracked connection's answer
 */
public boolean isBackingOff(Node node, long timeMs) {
    ConnectionState tracked = connections.get(node.idString());
    return tracked != null && tracked.isBackingOff(timeMs);
}
/**
 * Returns the remaining request time reported by the connection tracked for the given node.
 *
 * @param node the remote node
 * @param timeMs the current time
 * @return zero when the node has no tracked connection, otherwise the tracked connection's value
 */
public long remainingRequestTimeMs(Node node, long timeMs) {
    ConnectionState tracked = connections.get(node.idString());
    return tracked == null ? 0 : tracked.remainingRequestTimeMs(timeMs);
}
/**
 * Returns the remaining backoff time reported by the connection tracked for the given node.
 *
 * @param node the remote node
 * @param timeMs the current time
 * @return zero when the node has no tracked connection, otherwise the tracked connection's value
 */
public long remainingBackoffMs(Node node, long timeMs) {
    ConnectionState tracked = connections.get(node.idString());
    return tracked == null ? 0 : tracked.remainingBackoffMs(timeMs);
}
/**
 * Checks whether a response with the given correlation id is expected from the given node.
 *
 * @param node the source of the response
 * @param correlationId the correlation id of the response
 * @return false when the node has no tracked connection, otherwise the tracked connection's answer
 */
public boolean isResponseExpected(Node node, long correlationId) {
    ConnectionState tracked = connections.get(node.idString());
    return tracked != null && tracked.isResponseExpected(correlationId);
}
/**
 * Records the outcome of a response in the manager.
 *
 * A response that is not expected for the node and correlation id leaves the tracked state
 * untouched. A successful response drops the tracked connection, which marks the node as
 * ready; a failed response is handed to the tracked connection as a response error.
 *
 * @param node the source of the response
 * @param correlationId the correlation id of the response
 * @param success true if the request was successful, false otherwise
 * @param timeMs the current time
 */
public void onResponseResult(Node node, long correlationId, boolean success, long timeMs) {
    if (!isResponseExpected(node, correlationId)) {
        return;
    }

    if (success) {
        // Resetting the connection marks it as ready
        reset(node);
        return;
    }

    // Back off the connection; the entry exists because the response was expected
    connections.get(node.idString()).onResponseError(correlationId, timeMs);
}
/**
 * Records in the manager that a request was sent.
 *
 * The connection for the destination is looked up by {@code node.idString()} and created
 * on first use before the send is recorded on it.
 *
 * @param node the destination of the request
 * @param correlationId the correlation id of the request
 * @param timeMs the current time
 */
public void onRequestSent(Node node, long correlationId, long timeMs) {
    String connectionKey = node.idString();
    ConnectionState tracked = connections.get(connectionKey);
    if (tracked == null) {
        tracked = new ConnectionState(node, retryBackoffMs, requestTimeoutMs);
        connections.put(connectionKey, tracked);
    }

    tracked.onRequestSent(correlationId, timeMs);
}
/**
* Discards any connection state tracked for the given node.
*
* The entry keyed by {@code node.idString()} is removed from the connection map; calling this
* for a node that is not tracked has no effect.
*
* @param node the node whose tracked connection state is removed
*/
public void reset(Node node) {
connections.remove(node.idString());
}
public void resetAll() {
for (ConnectionState connectionState : connections.values())
connectionState.reset();
connections.clear();
}
private enum State {
AWAITING_REQUEST,
AWAITING_RESPONSE,
BACKING_OFF,
READY
}
public class ConnectionState {
private final long id;
private final static class ConnectionState {
private final Node node;
private final int retryBackoffMs;
private final int requestTimeoutMs;
private State state = State.READY;
private long lastSendTimeMs = 0L;
private long lastFailTimeMs = 0L;
private OptionalLong inFlightCorrelationId = OptionalLong.empty();
public ConnectionState(long id) {
this.id = id;
private ConnectionState(
Node node,
int retryBackoffMs,
int requestTimeoutMs
) {
this.node = node;
this.retryBackoffMs = retryBackoffMs;
this.requestTimeoutMs = requestTimeoutMs;
}
private boolean isBackoffComplete(long timeMs) {
@ -114,11 +306,7 @@ public class RequestManager {
}
boolean hasRequestTimedOut(long timeMs) {
return state == State.AWAITING_REQUEST && timeMs >= lastSendTimeMs + requestTimeoutMs;
}
public long id() {
return id;
return state == State.AWAITING_RESPONSE && timeMs >= lastSendTimeMs + requestTimeoutMs;
}
boolean isReady(long timeMs) {
@ -136,8 +324,8 @@ public class RequestManager {
}
}
boolean hasInflightRequest(long timeMs) {
if (state != State.AWAITING_REQUEST) {
private boolean hasInflightRequest(long timeMs) {
if (state != State.AWAITING_RESPONSE) {
return false;
} else {
return !hasRequestTimedOut(timeMs);
@ -174,41 +362,22 @@ public class RequestManager {
});
}
void onResponseReceived(long correlationId) {
inFlightCorrelationId.ifPresent(inflightRequestId -> {
if (inflightRequestId == correlationId) {
state = State.READY;
inFlightCorrelationId = OptionalLong.empty();
}
});
}
void onRequestSent(long correlationId, long timeMs) {
lastSendTimeMs = timeMs;
inFlightCorrelationId = OptionalLong.of(correlationId);
state = State.AWAITING_REQUEST;
}
/**
* Ignore in-flight requests or backoff and become available immediately. This is used
* when there is a state change which usually means in-flight requests are obsolete
* and we need to send new requests.
*/
void reset() {
state = State.READY;
inFlightCorrelationId = OptionalLong.empty();
state = State.AWAITING_RESPONSE;
}
/**
 * Renders the connection state for logging and debugging.
 *
 * The in-flight correlation id is an {@code OptionalLong}, so it is formatted with
 * {@code %s}; formatting it with {@code %d} would throw
 * {@code IllegalFormatConversionException} because it is not an integral type.
 *
 * @return a string with the node, state, last send/fail times and in-flight correlation id
 */
@Override
public String toString() {
    return String.format(
        "ConnectionState(node=%s, state=%s, lastSendTimeMs=%d, lastFailTimeMs=%d, inFlightCorrelationId=%s)",
        node,
        state,
        lastSendTimeMs,
        lastFailTimeMs,
        inFlightCorrelationId
    );
}
}
}

View File

@ -28,11 +28,13 @@ import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.feature.SupportedVersionRange;
import org.apache.kafka.common.message.VotersRecord;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.utils.Utils;
/**
@ -55,15 +57,41 @@ final public class VoterSet {
}
/**
* Returns the socket address for a given voter at a given listener.
* Returns the node information for all the given voter ids and listener.
*
* @param voter the id of the voter
* @param listener the name of the listener
* @return the socket address if it exists, otherwise {@code Optional.empty()}
* @param voterIds the ids of the voters
* @param listenerName the name of the listener
* @return the node information for all of the voter ids
* @throws IllegalArgumentException if there are missing endpoints
*/
public Optional<InetSocketAddress> voterAddress(int voter, String listener) {
return Optional.ofNullable(voters.get(voter))
.flatMap(voterNode -> voterNode.address(listener));
public Set<Node> voterNodes(Stream<Integer> voterIds, ListenerName listenerName) {
    return voterIds
        .map(voterId -> {
            Optional<Node> endpoint = voterNode(voterId, listenerName);
            if (!endpoint.isPresent()) {
                // Every requested voter must have an endpoint for the listener
                throw new IllegalArgumentException(
                    String.format(
                        "Unable to find endpoint for voter %d and listener %s in %s",
                        voterId,
                        listenerName,
                        voters
                    )
                );
            }
            return endpoint.get();
        })
        .collect(Collectors.toSet());
}
/**
 * Looks up the endpoint of a voter on a specific listener.
 *
 * @param voterId the id of the voter
 * @param listenerName the name of the listener
 * @return a node built from the voter id and the listener's host and port, or
 *         {@code Optional.empty()} when the voter or its address for the listener is unknown
 */
public Optional<Node> voterNode(int voterId, ListenerName listenerName) {
    VoterNode voter = voters.get(voterId);
    if (voter == null) {
        return Optional.empty();
    }

    Optional<InetSocketAddress> endpoint = voter.address(listenerName);
    if (!endpoint.isPresent()) {
        return Optional.empty();
    }

    InetSocketAddress address = endpoint.get();
    return Optional.of(new Node(voterId, address.getHostString(), address.getPort()));
}
/**
@ -166,7 +194,7 @@ final public class VoterSet {
.stream()
.map(entry ->
new VotersRecord.Endpoint()
.setName(entry.getKey())
.setName(entry.getKey().value())
.setHost(entry.getValue().getHostString())
.setPort(entry.getValue().getPort())
)
@ -247,12 +275,12 @@ final public class VoterSet {
public final static class VoterNode {
private final ReplicaKey voterKey;
private final Map<String, InetSocketAddress> listeners;
private final Map<ListenerName, InetSocketAddress> listeners;
private final SupportedVersionRange supportedKRaftVersion;
VoterNode(
ReplicaKey voterKey,
Map<String, InetSocketAddress> listeners,
Map<ListenerName, InetSocketAddress> listeners,
SupportedVersionRange supportedKRaftVersion
) {
this.voterKey = voterKey;
@ -264,7 +292,7 @@ final public class VoterSet {
return voterKey;
}
Map<String, InetSocketAddress> listeners() {
Map<ListenerName, InetSocketAddress> listeners() {
return listeners;
}
@ -273,7 +301,7 @@ final public class VoterSet {
}
Optional<InetSocketAddress> address(String listener) {
Optional<InetSocketAddress> address(ListenerName listener) {
return Optional.ofNullable(listeners.get(listener));
}
@ -323,9 +351,12 @@ final public class VoterSet {
directoryId = Optional.empty();
}
Map<String, InetSocketAddress> listeners = new HashMap<>(voter.endpoints().size());
Map<ListenerName, InetSocketAddress> listeners = new HashMap<>(voter.endpoints().size());
for (VotersRecord.Endpoint endpoint : voter.endpoints()) {
listeners.put(endpoint.name(), InetSocketAddress.createUnresolved(endpoint.host(), endpoint.port()));
listeners.put(
ListenerName.normalised(endpoint.name()),
InetSocketAddress.createUnresolved(endpoint.host(), endpoint.port())
);
}
voterNodes.put(
@ -351,7 +382,7 @@ final public class VoterSet {
* @param voters the socket addresses by voter id
* @return the voter set
*/
public static VoterSet fromInetSocketAddresses(String listener, Map<Integer, InetSocketAddress> voters) {
public static VoterSet fromInetSocketAddresses(ListenerName listener, Map<Integer, InetSocketAddress> voters) {
Map<Integer, VoterNode> voterNodes = voters
.entrySet()
.stream()
@ -368,16 +399,4 @@ final public class VoterSet {
return new VoterSet(voterNodes);
}
public Optional<Node> voterNode(int id, String listener) {
VoterNode voterNode = voters.get(id);
if (voterNode == null) {
return Optional.empty();
}
InetSocketAddress address = voterNode.listeners.get(listener);
if (address == null) {
return Optional.empty();
}
return Optional.of(new Node(id, address.getHostString(), address.getPort()));
}
}

View File

@ -26,11 +26,10 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@ -60,7 +59,7 @@ public class CandidateStateTest {
@Test
public void testSingleNodeQuorum() {
CandidateState state = newCandidateState(voterSetWithLocal(Collections.emptyList()));
CandidateState state = newCandidateState(voterSetWithLocal(IntStream.empty()));
assertTrue(state.isVoteGranted());
assertFalse(state.isVoteRejected());
assertEquals(Collections.emptySet(), state.unrecordedVoters());
@ -70,7 +69,7 @@ public class CandidateStateTest {
public void testTwoNodeQuorumVoteRejected() {
int otherNodeId = 1;
CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId))
voterSetWithLocal(IntStream.of(otherNodeId))
);
assertFalse(state.isVoteGranted());
assertFalse(state.isVoteRejected());
@ -84,7 +83,7 @@ public class CandidateStateTest {
public void testTwoNodeQuorumVoteGranted() {
int otherNodeId = 1;
CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId))
voterSetWithLocal(IntStream.of(otherNodeId))
);
assertFalse(state.isVoteGranted());
assertFalse(state.isVoteRejected());
@ -100,7 +99,7 @@ public class CandidateStateTest {
int node1 = 1;
int node2 = 2;
CandidateState state = newCandidateState(
voterSetWithLocal(Arrays.asList(node1, node2))
voterSetWithLocal(IntStream.of(node1, node2))
);
assertFalse(state.isVoteGranted());
assertFalse(state.isVoteRejected());
@ -120,7 +119,7 @@ public class CandidateStateTest {
int node1 = 1;
int node2 = 2;
CandidateState state = newCandidateState(
voterSetWithLocal(Arrays.asList(node1, node2))
voterSetWithLocal(IntStream.of(node1, node2))
);
assertFalse(state.isVoteGranted());
assertFalse(state.isVoteRejected());
@ -139,7 +138,7 @@ public class CandidateStateTest {
public void testCannotRejectVoteFromLocalId() {
int otherNodeId = 1;
CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId))
voterSetWithLocal(IntStream.of(otherNodeId))
);
assertThrows(
IllegalArgumentException.class,
@ -151,7 +150,7 @@ public class CandidateStateTest {
public void testCannotChangeVoteGrantedToRejected() {
int otherNodeId = 1;
CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId))
voterSetWithLocal(IntStream.of(otherNodeId))
);
assertTrue(state.recordGrantedVote(otherNodeId));
assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(otherNodeId));
@ -162,7 +161,7 @@ public class CandidateStateTest {
public void testCannotChangeVoteRejectedToGranted() {
int otherNodeId = 1;
CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId))
voterSetWithLocal(IntStream.of(otherNodeId))
);
assertTrue(state.recordRejectedVote(otherNodeId));
assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(otherNodeId));
@ -172,7 +171,7 @@ public class CandidateStateTest {
@Test
public void testCannotGrantOrRejectNonVoters() {
int nonVoterId = 1;
CandidateState state = newCandidateState(voterSetWithLocal(Collections.emptyList()));
CandidateState state = newCandidateState(voterSetWithLocal(IntStream.empty()));
assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(nonVoterId));
assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(nonVoterId));
}
@ -181,7 +180,7 @@ public class CandidateStateTest {
public void testIdempotentGrant() {
int otherNodeId = 1;
CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId))
voterSetWithLocal(IntStream.of(otherNodeId))
);
assertTrue(state.recordGrantedVote(otherNodeId));
assertFalse(state.recordGrantedVote(otherNodeId));
@ -191,7 +190,7 @@ public class CandidateStateTest {
public void testIdempotentReject() {
int otherNodeId = 1;
CandidateState state = newCandidateState(
voterSetWithLocal(Collections.singletonList(otherNodeId))
voterSetWithLocal(IntStream.of(otherNodeId))
);
assertTrue(state.recordRejectedVote(otherNodeId));
assertFalse(state.recordRejectedVote(otherNodeId));
@ -201,7 +200,7 @@ public class CandidateStateTest {
@ValueSource(booleans = {true, false})
public void testGrantVote(boolean isLogUpToDate) {
CandidateState state = newCandidateState(
voterSetWithLocal(Arrays.asList(1, 2, 3))
voterSetWithLocal(IntStream.of(1, 2, 3))
);
assertFalse(state.canGrantVote(ReplicaKey.of(0, Optional.empty()), isLogUpToDate));
@ -212,7 +211,7 @@ public class CandidateStateTest {
@Test
public void testElectionState() {
VoterSet voters = voterSetWithLocal(Arrays.asList(1, 2, 3));
VoterSet voters = voterSetWithLocal(IntStream.of(1, 2, 3));
CandidateState state = newCandidateState(voters);
assertEquals(
ElectionState.withVotedCandidate(
@ -228,11 +227,11 @@ public class CandidateStateTest {
public void testInvalidVoterSet() {
assertThrows(
IllegalArgumentException.class,
() -> newCandidateState(VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)))
() -> newCandidateState(VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)))
);
}
private VoterSet voterSetWithLocal(Collection<Integer> remoteVoters) {
private VoterSet voterSetWithLocal(IntStream remoteVoters) {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(remoteVoters, true);
voterMap.put(localNode.voterKey().id(), localNode);

View File

@ -16,6 +16,7 @@
*/
package org.apache.kafka.raft;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Utils;
@ -38,7 +39,7 @@ public class FollowerStateTest {
private final LogContext logContext = new LogContext();
private final int epoch = 5;
private final int fetchTimeoutMs = 15000;
int leaderId = 3;
private final Node leader = new Node(3, "mock-host-3", 1234);
private FollowerState newFollowerState(
Set<Integer> voters,
@ -47,7 +48,7 @@ public class FollowerStateTest {
return new FollowerState(
time,
epoch,
leaderId,
leader,
voters,
highWatermark,
fetchTimeoutMs,
@ -96,4 +97,10 @@ public class FollowerStateTest {
assertFalse(state.canGrantVote(ReplicaKey.of(3, Optional.empty()), isLogUpToDate));
}
// The follower state must expose the leader node it was constructed with
@Test
public void testLeaderNode() {
    assertEquals(
        leader,
        newFollowerState(Utils.mkSet(0, 1, 2), Optional.empty()).leader()
    );
}
}

View File

@ -26,7 +26,10 @@ import org.apache.kafka.common.message.BeginQuorumEpochResponseData;
import org.apache.kafka.common.message.EndQuorumEpochResponseData;
import org.apache.kafka.common.message.FetchRequestData;
import org.apache.kafka.common.message.FetchResponseData;
import org.apache.kafka.common.message.FetchSnapshotRequestData;
import org.apache.kafka.common.message.FetchSnapshotResponseData;
import org.apache.kafka.common.message.VoteResponseData;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
@ -39,6 +42,8 @@ import org.apache.kafka.common.requests.EndQuorumEpochRequest;
import org.apache.kafka.common.requests.EndQuorumEpochResponse;
import org.apache.kafka.common.requests.FetchRequest;
import org.apache.kafka.common.requests.FetchResponse;
import org.apache.kafka.common.requests.FetchSnapshotRequest;
import org.apache.kafka.common.requests.FetchSnapshotResponse;
import org.apache.kafka.common.requests.VoteRequest;
import org.apache.kafka.common.requests.VoteResponse;
import org.apache.kafka.common.utils.MockTime;
@ -47,8 +52,8 @@ import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
@ -80,7 +85,8 @@ public class KafkaNetworkChannelTest {
ApiKeys.VOTE,
ApiKeys.BEGIN_QUORUM_EPOCH,
ApiKeys.END_QUORUM_EPOCH,
ApiKeys.FETCH
ApiKeys.FETCH,
ApiKeys.FETCH_SNAPSHOT
);
private final int requestTimeoutMs = 30000;
@ -88,35 +94,40 @@ public class KafkaNetworkChannelTest {
private final MockClient client = new MockClient(time, new StubMetadataUpdater());
private final TopicPartition topicPartition = new TopicPartition("topic", 0);
private final Uuid topicId = Uuid.randomUuid();
private final KafkaNetworkChannel channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, "test-raft");
private final KafkaNetworkChannel channel = new KafkaNetworkChannel(
time,
ListenerName.normalised("NAME"),
client,
requestTimeoutMs,
"test-raft"
);
// Builds a local test destination whose id is 2 when withId is true and -2 otherwise
private Node nodeWithId(boolean withId) {
    if (withId) {
        return new Node(2, "127.0.0.1", 9092);
    }
    return new Node(-2, "127.0.0.1", 9092);
}
@BeforeEach
public void setupSupportedApis() {
List<ApiVersionsResponseData.ApiVersion> supportedApis = RAFT_APIS.stream().map(
ApiVersionsResponse::toApiVersion).collect(Collectors.toList());
List<ApiVersionsResponseData.ApiVersion> supportedApis = RAFT_APIS
.stream()
.map(ApiVersionsResponse::toApiVersion)
.collect(Collectors.toList());
client.setNodeApiVersions(NodeApiVersions.create(supportedApis));
}
@Test
public void testSendToUnknownDestination() throws ExecutionException, InterruptedException {
int destinationId = 2;
assertBrokerNotAvailable(destinationId);
}
@Test
public void testSendToBlackedOutDestination() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
client.backoff(destinationNode, 500);
assertBrokerNotAvailable(destinationId);
@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testSendToBlackedOutDestination(boolean withDestinationId) throws ExecutionException, InterruptedException {
Node destination = nodeWithId(withDestinationId);
client.backoff(destination, 500);
assertBrokerNotAvailable(destination);
}
@Test
public void testWakeupClientOnSend() throws InterruptedException, ExecutionException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
client.enableBlockingUntilWakeup(1);
@ -132,7 +143,7 @@ public class KafkaNetworkChannelTest {
client.prepareResponseFrom(response, destinationNode, false);
ioThread.start();
RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationId);
RaftRequest.Outbound request = sendTestRequest(ApiKeys.FETCH, destinationNode);
ioThread.join();
assertResponseCompleted(request, Errors.INVALID_REQUEST);
@ -142,12 +153,11 @@ public class KafkaNetworkChannelTest {
public void testSendAndDisconnect() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
for (ApiKeys apiKey : RAFT_APIS) {
AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST));
client.prepareResponseFrom(response, destinationNode, true);
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE);
sendAndAssertErrorResponse(apiKey, destinationNode, Errors.BROKER_NOT_AVAILABLE);
}
}
@ -155,35 +165,33 @@ public class KafkaNetworkChannelTest {
public void testSendAndFailAuthentication() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
for (ApiKeys apiKey : RAFT_APIS) {
client.createPendingAuthenticationError(destinationNode, 100);
sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION);
sendAndAssertErrorResponse(apiKey, destinationNode, Errors.NETWORK_EXCEPTION);
// reset to clear backoff time
client.reset();
}
}
private void assertBrokerNotAvailable(int destinationId) throws ExecutionException, InterruptedException {
private void assertBrokerNotAvailable(Node destination) throws ExecutionException, InterruptedException {
for (ApiKeys apiKey : RAFT_APIS) {
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE);
sendAndAssertErrorResponse(apiKey, destination, Errors.BROKER_NOT_AVAILABLE);
}
}
@Test
public void testSendAndReceiveOutboundRequest() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testSendAndReceiveOutboundRequest(boolean withDestinationId) throws ExecutionException, InterruptedException {
Node destination = nodeWithId(withDestinationId);
for (ApiKeys apiKey : RAFT_APIS) {
Errors expectedError = Errors.INVALID_REQUEST;
AbstractResponse response = buildResponse(buildTestErrorResponse(apiKey, expectedError));
client.prepareResponseFrom(response, destinationNode);
client.prepareResponseFrom(response, destination);
System.out.println("api key " + apiKey + ", response " + response);
sendAndAssertErrorResponse(apiKey, destinationId, expectedError);
sendAndAssertErrorResponse(apiKey, destination, expectedError);
}
}
@ -191,11 +199,10 @@ public class KafkaNetworkChannelTest {
public void testUnsupportedVersionError() throws ExecutionException, InterruptedException {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
for (ApiKeys apiKey : RAFT_APIS) {
client.prepareUnsupportedVersionResponse(request -> request.apiKey() == apiKey);
sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION);
sendAndAssertErrorResponse(apiKey, destinationNode, Errors.UNSUPPORTED_VERSION);
}
}
@ -204,8 +211,7 @@ public class KafkaNetworkChannelTest {
public void testFetchRequestDowngrade(short version) {
int destinationId = 2;
Node destinationNode = new Node(destinationId, "127.0.0.1", 9092);
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host(), destinationNode.port()));
sendTestRequest(ApiKeys.FETCH, destinationId);
sendTestRequest(ApiKeys.FETCH, destinationNode);
channel.pollOnce();
assertEquals(1, client.requests().size());
@ -220,27 +226,39 @@ public class KafkaNetworkChannelTest {
}
}
private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, int destinationId) {
private RaftRequest.Outbound sendTestRequest(ApiKeys apiKey, Node destination) {
int correlationId = channel.newCorrelationId();
long createdTimeMs = time.milliseconds();
ApiMessage apiRequest = buildTestRequest(apiKey);
RaftRequest.Outbound request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs);
RaftRequest.Outbound request = new RaftRequest.Outbound(
correlationId,
apiRequest,
destination,
createdTimeMs
);
channel.send(request);
return request;
}
private void assertResponseCompleted(RaftRequest.Outbound request, Errors expectedError) throws ExecutionException, InterruptedException {
private void assertResponseCompleted(
RaftRequest.Outbound request,
Errors expectedError
) throws ExecutionException, InterruptedException {
assertTrue(request.completion.isDone());
RaftResponse.Inbound response = request.completion.get();
assertEquals(request.destinationId(), response.sourceId());
assertEquals(request.correlationId, response.correlationId);
assertEquals(request.data.apiKey(), response.data.apiKey());
assertEquals(expectedError, extractError(response.data));
assertEquals(request.destination(), response.source());
assertEquals(request.correlationId(), response.correlationId());
assertEquals(request.data().apiKey(), response.data().apiKey());
assertEquals(expectedError, extractError(response.data()));
}
private void sendAndAssertErrorResponse(ApiKeys apiKey, int destinationId, Errors error) throws ExecutionException, InterruptedException {
RaftRequest.Outbound request = sendTestRequest(apiKey, destinationId);
private void sendAndAssertErrorResponse(
ApiKeys apiKey,
Node destination,
Errors error
) throws ExecutionException, InterruptedException {
RaftRequest.Outbound request = sendTestRequest(apiKey, destination);
channel.pollOnce();
assertResponseCompleted(request, error);
}
@ -252,12 +270,20 @@ public class KafkaNetworkChannelTest {
switch (key) {
case BEGIN_QUORUM_EPOCH:
return BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId);
case END_QUORUM_EPOCH:
return EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, leaderEpoch,
Collections.singletonList(2));
return EndQuorumEpochRequest.singletonRequest(
topicPartition,
clusterId,
leaderId,
leaderEpoch,
Collections.singletonList(2)
);
case VOTE:
int lastEpoch = 4;
return VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329);
case FETCH:
FetchRequestData request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition -> {
fetchPartition
@ -267,6 +293,21 @@ public class KafkaNetworkChannelTest {
});
request.setReplicaState(new FetchRequestData.ReplicaState().setReplicaId(1));
return request;
case FETCH_SNAPSHOT:
return FetchSnapshotRequest.singleton(
clusterId,
1,
topicPartition,
snapshotPartition -> snapshotPartition
.setCurrentLeaderEpoch(5)
.setSnapshotId(new FetchSnapshotRequestData.SnapshotId()
.setEpoch(4)
.setEndOffset(323)
)
.setPosition(10)
);
default:
throw new AssertionError("Unexpected api " + key);
}
@ -282,6 +323,8 @@ public class KafkaNetworkChannelTest {
return VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false);
case FETCH:
return new FetchResponseData().setErrorCode(error.code());
case FETCH_SNAPSHOT:
return new FetchSnapshotResponseData().setErrorCode(error.code());
default:
throw new AssertionError("Unexpected api " + key);
}
@ -289,28 +332,36 @@ public class KafkaNetworkChannelTest {
// Test helper: reads the top-level error code from a KRaft response payload and converts it
// to an Errors value. Handles BeginQuorumEpoch, EndQuorumEpoch, Fetch, Vote and FetchSnapshot
// response data; any other type is rejected with IllegalArgumentException.
// NOTE(review): this is a diff rendering - the brace-less if/else lines appear to be the
// pre-change version and the braced lines (which add the FetchSnapshot case) the replacement.
private Errors extractError(ApiMessage response) {
short code;
if (response instanceof BeginQuorumEpochResponseData)
if (response instanceof BeginQuorumEpochResponseData) {
code = ((BeginQuorumEpochResponseData) response).errorCode();
else if (response instanceof EndQuorumEpochResponseData)
} else if (response instanceof EndQuorumEpochResponseData) {
code = ((EndQuorumEpochResponseData) response).errorCode();
else if (response instanceof FetchResponseData)
} else if (response instanceof FetchResponseData) {
code = ((FetchResponseData) response).errorCode();
else if (response instanceof VoteResponseData)
} else if (response instanceof VoteResponseData) {
code = ((VoteResponseData) response).errorCode();
else
} else if (response instanceof FetchSnapshotResponseData) {
code = ((FetchSnapshotResponseData) response).errorCode();
} else {
throw new IllegalArgumentException("Unexpected type for responseData: " + response);
}
return Errors.forCode(code);
}
// Test helper: wraps a KRaft response payload in the matching AbstractResponse subclass
// (Vote, BeginQuorumEpoch, EndQuorumEpoch, Fetch or FetchSnapshot); any other payload type
// is rejected with IllegalArgumentException.
// NOTE(review): this is a diff rendering - the brace-less "if (...)" lines appear to be the
// pre-change version and the braced if/else-if chain (which adds the FetchSnapshot case)
// the replacement.
private AbstractResponse buildResponse(ApiMessage responseData) {
if (responseData instanceof VoteResponseData)
if (responseData instanceof VoteResponseData) {
return new VoteResponse((VoteResponseData) responseData);
if (responseData instanceof BeginQuorumEpochResponseData)
} else if (responseData instanceof BeginQuorumEpochResponseData) {
return new BeginQuorumEpochResponse((BeginQuorumEpochResponseData) responseData);
if (responseData instanceof EndQuorumEpochResponseData)
} else if (responseData instanceof EndQuorumEpochResponseData) {
return new EndQuorumEpochResponse((EndQuorumEpochResponseData) responseData);
if (responseData instanceof FetchResponseData)
} else if (responseData instanceof FetchResponseData) {
return new FetchResponse((FetchResponseData) responseData);
} else if (responseData instanceof FetchSnapshotResponseData) {
return new FetchSnapshotResponse((FetchSnapshotResponseData) responseData);
} else {
throw new IllegalArgumentException("Unexpected type for responseData: " + responseData);
}
}
}

View File

@ -153,8 +153,8 @@ final public class KafkaRaftClientSnapshotTest {
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
context.assertFetchRequestData(fetchRequest, epoch, localLogEndOffset, snapshotId.epoch());
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, localLogEndOffset, Errors.NONE)
);
@ -195,8 +195,8 @@ final public class KafkaRaftClientSnapshotTest {
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
context.assertFetchRequestData(fetchRequest, epoch, localLogEndOffset, snapshotId.epoch());
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, localLogEndOffset, Errors.NONE)
);
@ -1032,8 +1032,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, invalidEpoch, 200L)
);
@ -1049,8 +1049,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, invalidEndOffset, 200L)
);
@ -1091,8 +1091,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
);
@ -1116,8 +1116,8 @@ final public class KafkaRaftClientSnapshotTest {
}
context.deliverResponse(
snapshotRequest.correlationId,
snapshotRequest.destinationId(),
snapshotRequest.correlationId(),
snapshotRequest.destination(),
fetchSnapshotResponse(
context.metadataPartition,
epoch,
@ -1162,8 +1162,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
);
@ -1190,8 +1190,8 @@ final public class KafkaRaftClientSnapshotTest {
sendingBuffer.limit(sendingBuffer.limit() / 2);
context.deliverResponse(
snapshotRequest.correlationId,
snapshotRequest.destinationId(),
snapshotRequest.correlationId(),
snapshotRequest.destination(),
fetchSnapshotResponse(
context.metadataPartition,
epoch,
@ -1219,8 +1219,8 @@ final public class KafkaRaftClientSnapshotTest {
sendingBuffer.position(Math.toIntExact(request.position()));
context.deliverResponse(
snapshotRequest.correlationId,
snapshotRequest.destinationId(),
snapshotRequest.correlationId(),
snapshotRequest.destination(),
fetchSnapshotResponse(
context.metadataPartition,
epoch,
@ -1265,8 +1265,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
);
@ -1284,8 +1284,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with a snapshot not found error
context.deliverResponse(
snapshotRequest.correlationId,
snapshotRequest.destinationId(),
snapshotRequest.correlationId(),
snapshotRequest.destination(),
FetchSnapshotResponse.singleton(
context.metadataPartition,
responsePartitionSnapshot -> {
@ -1323,8 +1323,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, firstLeaderId, snapshotId, 200L)
);
@ -1342,8 +1342,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with new leader response
context.deliverResponse(
snapshotRequest.correlationId,
snapshotRequest.destinationId(),
snapshotRequest.correlationId(),
snapshotRequest.destination(),
FetchSnapshotResponse.singleton(
context.metadataPartition,
responsePartitionSnapshot -> {
@ -1380,8 +1380,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
);
@ -1399,8 +1399,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with new leader epoch
context.deliverResponse(
snapshotRequest.correlationId,
snapshotRequest.destinationId(),
snapshotRequest.correlationId(),
snapshotRequest.destination(),
FetchSnapshotResponse.singleton(
context.metadataPartition,
responsePartitionSnapshot -> {
@ -1437,8 +1437,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
);
@ -1456,8 +1456,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with unknown leader epoch
context.deliverResponse(
snapshotRequest.correlationId,
snapshotRequest.destinationId(),
snapshotRequest.correlationId(),
snapshotRequest.destination(),
FetchSnapshotResponse.singleton(
context.metadataPartition,
responsePartitionSnapshot -> {
@ -1504,8 +1504,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
);
@ -1523,8 +1523,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with an invalid snapshot id endOffset
context.deliverResponse(
snapshotRequest.correlationId,
snapshotRequest.destinationId(),
snapshotRequest.correlationId(),
snapshotRequest.destination(),
FetchSnapshotResponse.singleton(
context.metadataPartition,
responsePartitionSnapshot -> {
@ -1550,8 +1550,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
);
@ -1570,8 +1570,8 @@ final public class KafkaRaftClientSnapshotTest {
// Reply with an invalid snapshot id epoch
context.deliverResponse(
snapshotRequest.correlationId,
snapshotRequest.destinationId(),
snapshotRequest.correlationId(),
snapshotRequest.destination(),
FetchSnapshotResponse.singleton(
context.metadataPartition,
responsePartitionSnapshot -> {
@ -1614,8 +1614,8 @@ final public class KafkaRaftClientSnapshotTest {
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
context.deliverResponse(
fetchRequest.correlationId,
fetchRequest.destinationId(),
fetchRequest.correlationId(),
fetchRequest.destination(),
snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L)
);
@ -1642,8 +1642,8 @@ final public class KafkaRaftClientSnapshotTest {
// Send the response late
context.deliverResponse(
snapshotRequest.correlationId,
snapshotRequest.destinationId(),
snapshotRequest.correlationId(),
snapshotRequest.destination(),
FetchSnapshotResponse.singleton(
context.metadataPartition,
responsePartitionSnapshot -> {
@ -1805,14 +1805,17 @@ final public class KafkaRaftClientSnapshotTest {
// Poll for our first fetch request
context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
// The response does not advance the high watermark
List<String> records1 = Arrays.asList("a", "b", "c");
MemoryRecords batch1 = context.buildBatch(0L, 3, records1);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(),
context.fetchResponse(epoch, leaderId, batch1, 0L, Errors.NONE));
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, batch1, 0L, Errors.NONE)
);
context.client.poll();
// 2) The high watermark must be larger than or equal to the snapshotId's endOffset
@ -1827,13 +1830,16 @@ final public class KafkaRaftClientSnapshotTest {
// The high watermark advances to be larger than log.endOffsetForEpoch(3), to test the case 3
context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, epoch, 3L, 3);
List<String> records2 = Arrays.asList("d", "e", "f");
MemoryRecords batch2 = context.buildBatch(3L, 4, records2);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(),
context.fetchResponse(epoch, leaderId, batch2, 6L, Errors.NONE));
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, batch2, 6L, Errors.NONE)
);
context.client.poll();
assertEquals(6L, context.client.highWatermark().getAsLong());

View File

@ -51,6 +51,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mockito;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
@ -62,6 +63,7 @@ import java.util.OptionalLong;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import static java.util.Collections.singletonList;
import static org.apache.kafka.raft.RaftClientTestContext.Builder.DEFAULT_ELECTION_TIMEOUT_MS;
@ -274,8 +276,12 @@ public class KafkaRaftClientTest {
assertThrows(NotLeaderException.class, () -> context.client.scheduleAppend(epoch, Arrays.asList("a", "b")));
context.pollUntilRequest();
int correlationId = context.assertSentEndQuorumEpochRequest(epoch, 1);
context.deliverResponse(correlationId, 1, context.endEpochResponse(epoch, OptionalInt.of(localId)));
RaftRequest.Outbound request = context.assertSentEndQuorumEpochRequest(epoch, 1);
context.deliverResponse(
request.correlationId(),
request.destination(),
context.endEpochResponse(epoch, OptionalInt.of(localId))
);
context.client.poll();
context.time.sleep(context.electionTimeoutMs());
@ -389,14 +395,17 @@ public class KafkaRaftClientTest {
// Respond to one of the requests so that we can verify that no additional
// request to this node is sent.
RaftRequest.Outbound endEpochOutbound = requests.get(0);
context.deliverResponse(endEpochOutbound.correlationId, endEpochOutbound.destinationId(),
context.endEpochResponse(epoch, OptionalInt.of(localId)));
context.deliverResponse(
endEpochOutbound.correlationId(),
endEpochOutbound.destination(),
context.endEpochResponse(epoch, OptionalInt.of(localId))
);
context.client.poll();
assertEquals(Collections.emptyList(), context.channel.drainSendQueue());
// Now sleep for the request timeout and verify that we get only one
// retried request from the voter that hasn't responded yet.
int nonRespondedId = requests.get(1).destinationId();
int nonRespondedId = requests.get(1).destination().id();
context.time.sleep(6000);
context.pollUntilRequest();
List<RaftRequest.Outbound> retries = context.collectEndQuorumRequests(
@ -573,7 +582,7 @@ public class KafkaRaftClientTest {
context.pollUntil(context.client.quorum()::isResigned);
context.pollUntilRequest();
int correlationId = context.assertSentEndQuorumEpochRequest(resignedEpoch, otherNodeId);
RaftRequest.Outbound request = context.assertSentEndQuorumEpochRequest(resignedEpoch, otherNodeId);
EndQuorumEpochResponseData response = EndQuorumEpochResponse.singletonResponse(
Errors.NONE,
@ -583,7 +592,7 @@ public class KafkaRaftClientTest {
localId
);
context.deliverResponse(correlationId, otherNodeId, response);
context.deliverResponse(request.correlationId(), request.destination(), response);
context.client.poll();
// We do not resend `EndQuorumRequest` once the other voter has acknowledged it.
@ -644,11 +653,14 @@ public class KafkaRaftClientTest {
context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH));
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId);
@ -686,8 +698,12 @@ public class KafkaRaftClientTest {
context.pollUntilRequest();
context.assertVotedCandidate(1, localId);
int correlationId = context.assertSentVoteRequest(1, 0, 0L, 1);
context.deliverResponse(correlationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1));
RaftRequest.Outbound request = context.assertSentVoteRequest(1, 0, 0L, 1);
context.deliverResponse(
request.correlationId(),
request.destination(),
context.voteResponse(true, Optional.empty(), 1)
);
// Become leader after receiving the vote
context.pollUntil(() -> context.log.endOffset().offset == 1L);
@ -726,8 +742,12 @@ public class KafkaRaftClientTest {
context.pollUntilRequest();
context.assertVotedCandidate(1, localId);
int correlationId = context.assertSentVoteRequest(1, 0, 0L, 2);
context.deliverResponse(correlationId, firstNodeId, context.voteResponse(true, Optional.empty(), 1));
RaftRequest.Outbound request = context.assertSentVoteRequest(1, 0, 0L, 2);
context.deliverResponse(
request.correlationId(),
request.destination(),
context.voteResponse(true, Optional.empty(), 1)
);
// Become leader after receiving the vote
context.pollUntil(() -> context.log.endOffset().offset == 1L);
@ -1102,19 +1122,27 @@ public class KafkaRaftClientTest {
context.pollUntilRequest();
context.assertVotedCandidate(epoch, localId);
int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1);
RaftRequest.Outbound request = context.assertSentVoteRequest(epoch, 0, 0L, 1);
context.time.sleep(context.requestTimeoutMs());
context.client.poll();
int retryCorrelationId = context.assertSentVoteRequest(epoch, 0, 0L, 1);
RaftRequest.Outbound retryRequest = context.assertSentVoteRequest(epoch, 0, 0L, 1);
// We will ignore the timed out response if it arrives late
context.deliverResponse(correlationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1));
context.deliverResponse(
request.correlationId(),
request.destination(),
context.voteResponse(true, Optional.empty(), 1)
);
context.client.poll();
context.assertVotedCandidate(epoch, localId);
// Become leader after receiving the retry response
context.deliverResponse(retryCorrelationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1));
context.deliverResponse(
retryRequest.correlationId(),
retryRequest.destination(),
context.voteResponse(true, Optional.empty(), 1)
);
context.client.poll();
context.assertElectedLeader(epoch, localId);
}
@ -1338,8 +1366,12 @@ public class KafkaRaftClientTest {
context.assertVotedCandidate(epoch, localId);
// Quorum size is two. If the other member rejects, then we need to schedule a revote.
int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1);
context.deliverResponse(correlationId, otherNodeId, context.voteResponse(false, Optional.empty(), 1));
RaftRequest.Outbound request = context.assertSentVoteRequest(epoch, 0, 0L, 1);
context.deliverResponse(
request.correlationId(),
request.destination(),
context.voteResponse(false, Optional.empty(), 1)
);
context.client.poll();
@ -1434,11 +1466,14 @@ public class KafkaRaftClientTest {
context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH));
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId);
@ -1450,27 +1485,39 @@ public class KafkaRaftClientTest {
int leaderId = 1;
int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build();
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(),
context.fetchResponse(-1, -1, MemoryRecords.EMPTY, -1, Errors.UNKNOWN_SERVER_ERROR));
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(-1, -1, MemoryRecords.EMPTY, -1, Errors.UNKNOWN_SERVER_ERROR)
);
context.client.poll();
context.time.sleep(context.retryBackoffMs);
context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH));
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId);
@ -1483,27 +1530,169 @@ public class KafkaRaftClientTest {
int otherNodeId = 2;
int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build();
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH));
context.client.poll();
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId);
context.time.sleep(context.fetchTimeoutMs);
context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertNotEquals(leaderId, fetchRequest.destination().id());
assertTrue(context.bootstrapIds.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
}
@Test
public void testObserverHandleRetryFetchtToBootstrapServer() throws Exception {
// This test checks that KRaft is able to handle a retried Fetch request to
// a bootstrap server after a Fetch request to the leader has timed out, including
// late responses arriving for both requests.
int localId = 0;
int leaderId = 1;
int otherNodeId = 2;
int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
// Use one mock bootstrap address per voter id
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
// Expect a fetch request to one of the bootstrap servers. The destination must be a
// bootstrap node id and not a voter id.
// NOTE(review): presumably bootstrap node ids are negative (starting at -2) so they never
// collide with voter ids - confirm against RaftClientTestContext.
context.pollUntilRequest();
RaftRequest.Outbound discoveryFetchRequest = context.assertSentFetchRequest();
assertFalse(voters.contains(discoveryFetchRequest.destination().id()));
assertTrue(context.bootstrapIds.contains(discoveryFetchRequest.destination().id()));
context.assertFetchRequestData(discoveryFetchRequest, 0, 0L, 0);
// Send a response with the leader and epoch
context.deliverResponse(
discoveryFetchRequest.correlationId(),
discoveryFetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId);
// Expect a fetch request to the leader
context.pollUntilRequest();
RaftRequest.Outbound toLeaderFetchRequest = context.assertSentFetchRequest();
assertEquals(leaderId, toLeaderFetchRequest.destination().id());
context.assertFetchRequestData(toLeaderFetchRequest, epoch, 0L, 0);
// Let the request timeout elapse without delivering the leader's response
context.time.sleep(context.requestTimeoutMs());
// After the request timeout expect a request to a bootstrap server
context.pollUntilRequest();
RaftRequest.Outbound retryToBootstrapServerFetchRequest = context.assertSentFetchRequest();
assertFalse(voters.contains(retryToBootstrapServerFetchRequest.destination().id()));
assertTrue(context.bootstrapIds.contains(retryToBootstrapServerFetchRequest.destination().id()));
context.assertFetchRequestData(retryToBootstrapServerFetchRequest, epoch, 0L, 0);
// Deliver the delayed response from the leader
Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
context.deliverResponse(
toLeaderFetchRequest.correlationId(),
toLeaderFetchRequest.destination(),
context.fetchResponse(epoch, leaderId, records, 0L, Errors.NONE)
);
context.client.poll();
// Deliver the same delayed response from the bootstrap server and assume that it is the leader
records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
context.deliverResponse(
retryToBootstrapServerFetchRequest.correlationId(),
retryToBootstrapServerFetchRequest.destination(),
context.fetchResponse(epoch, leaderId, records, 0L, Errors.NONE)
);
// This poll should not fail when handling the duplicate response from the bootstrap server
context.client.poll();
}
@Test
public void testObserverHandleRetryFetchToLeader() throws Exception {
// This test checks that, once the Fetch request to the leader has timed out and the
// retry has been sent to a bootstrap server, KRaft does not send yet another Fetch
// request while that retry is still in flight.
int localId = 0;
int leaderId = 1;
int otherNodeId = 2;
int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
// Use one mock bootstrap address per voter id
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
// Expect a fetch request to one of the bootstrap servers. The destination must be a
// bootstrap node id and not a voter id.
context.pollUntilRequest();
RaftRequest.Outbound discoveryFetchRequest = context.assertSentFetchRequest();
assertFalse(voters.contains(discoveryFetchRequest.destination().id()));
assertTrue(context.bootstrapIds.contains(discoveryFetchRequest.destination().id()));
context.assertFetchRequestData(discoveryFetchRequest, 0, 0L, 0);
// Send a response with the leader and epoch
context.deliverResponse(
discoveryFetchRequest.correlationId(),
discoveryFetchRequest.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId);
// Expect a fetch request to the leader
context.pollUntilRequest();
RaftRequest.Outbound toLeaderFetchRequest = context.assertSentFetchRequest();
assertEquals(leaderId, toLeaderFetchRequest.destination().id());
context.assertFetchRequestData(toLeaderFetchRequest, epoch, 0L, 0);
// Let the request timeout elapse without delivering the leader's response
context.time.sleep(context.requestTimeoutMs());
// After the request timeout expect a request to a bootstrap server
context.pollUntilRequest();
RaftRequest.Outbound retryToBootstrapServerFetchRequest = context.assertSentFetchRequest();
assertFalse(voters.contains(retryToBootstrapServerFetchRequest.destination().id()));
assertTrue(context.bootstrapIds.contains(retryToBootstrapServerFetchRequest.destination().id()));
context.assertFetchRequestData(retryToBootstrapServerFetchRequest, epoch, 0L, 0);
// At this point toLeaderFetchRequest has timed out but retryToBootstrapServerFetchRequest
// is still waiting for a response.
// Confirm that no new fetch request has been sent
context.client.poll();
assertFalse(context.channel.hasSentRequests());
}
@Test
public void testInvalidFetchRequest() throws Exception {
int localId = 0;
@ -1828,7 +2017,7 @@ public class KafkaRaftClientTest {
// Wait until we have a Fetch inflight to the leader
context.pollUntilRequest();
int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0);
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(epoch, 0L, 0);
// Now await the fetch timeout and become a candidate
context.time.sleep(context.fetchTimeoutMs);
@ -1837,8 +2026,11 @@ public class KafkaRaftClientTest {
// The fetch response from the old leader returns, but it should be ignored
Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
context.deliverResponse(fetchCorrelationId, otherNodeId,
context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE));
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE)
);
context.client.poll();
assertEquals(0, context.log.endOffset().offset);
@ -1862,7 +2054,7 @@ public class KafkaRaftClientTest {
// Wait until we have a Fetch inflight to the leader
context.pollUntilRequest();
int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0);
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(epoch, 0L, 0);
// Now receive a BeginEpoch from `voter3`
context.deliverRequest(context.beginEpochRequest(epoch + 1, voter3));
@ -1872,7 +2064,11 @@ public class KafkaRaftClientTest {
// The fetch response from the old leader returns, but it should be ignored
Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
FetchResponseData response = context.fetchResponse(epoch, voter2, records, 0L, Errors.NONE);
context.deliverResponse(fetchCorrelationId, voter2, response);
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
response
);
context.client.poll();
assertEquals(0, context.log.endOffset().offset);
@ -1909,10 +2105,18 @@ public class KafkaRaftClientTest {
// The vote requests now return and should be ignored
VoteResponseData voteResponse1 = context.voteResponse(false, Optional.empty(), epoch);
context.deliverResponse(voteRequests.get(0).correlationId, voter2, voteResponse1);
context.deliverResponse(
voteRequests.get(0).correlationId(),
voteRequests.get(0).destination(),
voteResponse1
);
VoteResponseData voteResponse2 = context.voteResponse(false, Optional.of(voter3), epoch);
context.deliverResponse(voteRequests.get(1).correlationId, voter3, voteResponse2);
context.deliverResponse(
voteRequests.get(1).correlationId(),
voteRequests.get(1).destination(),
voteResponse2
);
context.client.poll();
context.assertElectedLeader(epoch, voter3);
@ -1925,31 +2129,43 @@ public class KafkaRaftClientTest {
int otherNodeId = 2;
int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build();
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
context.discoverLeaderAsObserver(leaderId, epoch);
context.pollUntilRequest();
RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest();
assertEquals(leaderId, fetchRequest1.destinationId());
assertEquals(leaderId, fetchRequest1.destination().id());
context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0);
context.deliverResponse(fetchRequest1.correlationId, fetchRequest1.destinationId(),
context.fetchResponse(epoch, -1, MemoryRecords.EMPTY, -1, Errors.BROKER_NOT_AVAILABLE));
context.deliverResponse(
fetchRequest1.correlationId(),
fetchRequest1.destination(),
context.fetchResponse(epoch, -1, MemoryRecords.EMPTY, -1, Errors.BROKER_NOT_AVAILABLE)
);
context.pollUntilRequest();
// We should retry the Fetch against one of the bootstrap servers since the
// connection to the original leader will be backing off.
RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest();
assertNotEquals(leaderId, fetchRequest2.destinationId());
assertTrue(voters.contains(fetchRequest2.destinationId()));
assertNotEquals(leaderId, fetchRequest2.destination().id());
assertTrue(context.bootstrapIds.contains(fetchRequest2.destination().id()));
context.assertFetchRequestData(fetchRequest2, epoch, 0L, 0);
Errors error = fetchRequest2.destinationId() == leaderId ?
Errors error = fetchRequest2.destination().id() == leaderId ?
Errors.NONE : Errors.NOT_LEADER_OR_FOLLOWER;
context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, error));
context.deliverResponse(
fetchRequest2.correlationId(),
fetchRequest2.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, error)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId);
@ -1962,14 +2178,20 @@ public class KafkaRaftClientTest {
int otherNodeId = 2;
int epoch = 5;
Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build();
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
.withBootstrapServers(bootstrapServers)
.build();
context.discoverLeaderAsObserver(leaderId, epoch);
context.pollUntilRequest();
RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest();
assertEquals(leaderId, fetchRequest1.destinationId());
assertEquals(leaderId, fetchRequest1.destination().id());
context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0);
context.time.sleep(context.requestTimeoutMs());
@ -1978,12 +2200,15 @@ public class KafkaRaftClientTest {
// We should retry the Fetch against one of the bootstrap servers since the
// connection to the original leader will be backing off.
RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest();
assertNotEquals(leaderId, fetchRequest2.destinationId());
assertTrue(voters.contains(fetchRequest2.destinationId()));
assertNotEquals(leaderId, fetchRequest2.destination().id());
assertTrue(context.bootstrapIds.contains(fetchRequest2.destination().id()));
context.assertFetchRequestData(fetchRequest2, epoch, 0L, 0);
context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH));
context.deliverResponse(
fetchRequest2.correlationId(),
fetchRequest2.destination(),
context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(epoch, leaderId);
@ -2273,10 +2498,14 @@ public class KafkaRaftClientTest {
context.pollUntilRequest();
int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0);
RaftRequest.Outbound fetchQuorumRequest = context.assertSentFetchRequest(epoch, 0L, 0);
Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
FetchResponseData response = context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE);
context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, response);
context.deliverResponse(
fetchQuorumRequest.correlationId(),
fetchQuorumRequest.destination(),
response
);
context.client.poll();
assertEquals(2L, context.log.endOffset().offset);
@ -2297,10 +2526,19 @@ public class KafkaRaftClientTest {
// Receive an empty fetch response
context.pollUntilRequest();
int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0);
FetchResponseData fetchResponse = context.fetchResponse(epoch, otherNodeId,
MemoryRecords.EMPTY, 0L, Errors.NONE);
context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse);
RaftRequest.Outbound fetchQuorumRequest = context.assertSentFetchRequest(epoch, 0L, 0);
FetchResponseData fetchResponse = context.fetchResponse(
epoch,
otherNodeId,
MemoryRecords.EMPTY,
0L,
Errors.NONE
);
context.deliverResponse(
fetchQuorumRequest.correlationId(),
fetchQuorumRequest.destination(),
fetchResponse
);
context.client.poll();
assertEquals(0L, context.log.endOffset().offset);
assertEquals(OptionalLong.of(0L), context.client.highWatermark());
@ -2308,20 +2546,32 @@ public class KafkaRaftClientTest {
// Receive some records in the next poll, but do not advance high watermark
context.pollUntilRequest();
Records records = context.buildBatch(0L, epoch, Arrays.asList("a", "b"));
fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0);
fetchResponse = context.fetchResponse(epoch, otherNodeId,
records, 0L, Errors.NONE);
context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse);
fetchQuorumRequest = context.assertSentFetchRequest(epoch, 0L, 0);
fetchResponse = context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE);
context.deliverResponse(
fetchQuorumRequest.correlationId(),
fetchQuorumRequest.destination(),
fetchResponse
);
context.client.poll();
assertEquals(2L, context.log.endOffset().offset);
assertEquals(OptionalLong.of(0L), context.client.highWatermark());
// The next fetch response is empty, but should still advance the high watermark
context.pollUntilRequest();
fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 2L, epoch);
fetchResponse = context.fetchResponse(epoch, otherNodeId,
MemoryRecords.EMPTY, 2L, Errors.NONE);
context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse);
fetchQuorumRequest = context.assertSentFetchRequest(epoch, 2L, epoch);
fetchResponse = context.fetchResponse(
epoch,
otherNodeId,
MemoryRecords.EMPTY,
2L,
Errors.NONE
);
context.deliverResponse(
fetchQuorumRequest.correlationId(),
fetchQuorumRequest.destination(),
fetchResponse
);
context.client.poll();
assertEquals(2L, context.log.endOffset().offset);
assertEquals(OptionalLong.of(2L), context.client.highWatermark());
@ -2454,11 +2704,11 @@ public class KafkaRaftClientTest {
context.pollUntilRequest();
int correlationId = context.assertSentFetchRequest(epoch, 3L, lastEpoch);
RaftRequest.Outbound request = context.assertSentFetchRequest(epoch, 3L, lastEpoch);
FetchResponseData response = context.divergingFetchResponse(epoch, otherNodeId, 2L,
lastEpoch, 1L);
context.deliverResponse(correlationId, otherNodeId, response);
context.deliverResponse(request.correlationId(), request.destination(), response);
// Poll again to complete truncation
context.client.poll();
@ -2530,10 +2780,14 @@ public class KafkaRaftClientTest {
context.pollUntilRequest();
int correlationId = context.assertSentFetchRequest(epoch, 0, 0);
RaftRequest.Outbound request = context.assertSentFetchRequest(epoch, 0, 0);
FetchResponseData response = new FetchResponseData()
.setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code());
context.deliverResponse(correlationId, otherNodeId, response);
context.deliverResponse(
request.correlationId(),
request.destination(),
response
);
assertThrows(ClusterAuthorizationException.class, context.client::poll);
}
@ -2553,11 +2807,11 @@ public class KafkaRaftClientTest {
context.expectAndGrantVotes(epoch);
context.pollUntilRequest();
int correlationId = context.assertSentBeginQuorumEpochRequest(epoch, 1);
RaftRequest.Outbound request = context.assertSentBeginQuorumEpochRequest(epoch, 1);
BeginQuorumEpochResponseData response = new BeginQuorumEpochResponseData()
.setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code());
context.deliverResponse(correlationId, otherNodeId, response);
context.deliverResponse(request.correlationId(), request.destination(), response);
assertThrows(ClusterAuthorizationException.class, context.client::poll);
}
@ -2577,11 +2831,11 @@ public class KafkaRaftClientTest {
context.pollUntilRequest();
context.assertVotedCandidate(epoch, localId);
int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1);
RaftRequest.Outbound request = context.assertSentVoteRequest(epoch, 0, 0L, 1);
VoteResponseData response = new VoteResponseData()
.setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code());
context.deliverResponse(correlationId, otherNodeId, response);
context.deliverResponse(request.correlationId(), request.destination(), response);
assertThrows(ClusterAuthorizationException.class, context.client::poll);
}
@ -2597,11 +2851,11 @@ public class KafkaRaftClientTest {
context.client.shutdown(5000);
context.pollUntilRequest();
int correlationId = context.assertSentEndQuorumEpochRequest(epoch, otherNodeId);
RaftRequest.Outbound request = context.assertSentEndQuorumEpochRequest(epoch, otherNodeId);
EndQuorumEpochResponseData response = new EndQuorumEpochResponseData()
.setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code());
context.deliverResponse(correlationId, otherNodeId, response);
context.deliverResponse(request.correlationId(), request.destination(), response);
assertThrows(ClusterAuthorizationException.class, context.client::poll);
}
@ -2810,14 +3064,17 @@ public class KafkaRaftClientTest {
// Poll for our first fetch request
context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
// The response does not advance the high watermark
List<String> records1 = Arrays.asList("a", "b", "c");
MemoryRecords batch1 = context.buildBatch(0L, 3, records1);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(),
context.fetchResponse(epoch, otherNodeId, batch1, 0L, Errors.NONE));
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, otherNodeId, batch1, 0L, Errors.NONE)
);
context.client.poll();
// The listener should not have seen any data
@ -2828,14 +3085,17 @@ public class KafkaRaftClientTest {
// Now look for the next fetch request
context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertTrue(voters.contains(fetchRequest.destination().id()));
context.assertFetchRequestData(fetchRequest, epoch, 3L, 3);
// The high watermark advances to include the first batch we fetched
List<String> records2 = Arrays.asList("d", "e", "f");
MemoryRecords batch2 = context.buildBatch(3L, 3, records2);
context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(),
context.fetchResponse(epoch, otherNodeId, batch2, 3L, Errors.NONE));
context.deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
context.fetchResponse(epoch, otherNodeId, batch2, 3L, Errors.NONE)
);
context.client.poll();
// The listener should have seen only the data from the first batch
@ -3012,21 +3272,30 @@ public class KafkaRaftClientTest {
// This is designed for tooling/debugging use cases.
Set<Integer> voters = Utils.mkSet(1, 2);
List<InetSocketAddress> bootstrapServers = voters
.stream()
.map(RaftClientTestContext::mockAddress)
.collect(Collectors.toList());
RaftClientTestContext context = new RaftClientTestContext.Builder(OptionalInt.empty(), voters)
.withBootstrapServers(bootstrapServers)
.build();
// First fetch discovers the current leader and epoch
context.pollUntilRequest();
RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest1.destinationId()));
assertTrue(context.bootstrapIds.contains(fetchRequest1.destination().id()));
context.assertFetchRequestData(fetchRequest1, 0, 0L, 0);
int leaderEpoch = 5;
int leaderId = 1;
context.deliverResponse(fetchRequest1.correlationId, fetchRequest1.destinationId(),
context.fetchResponse(5, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH));
context.deliverResponse(
fetchRequest1.correlationId(),
fetchRequest1.destination(),
context.fetchResponse(5, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)
);
context.client.poll();
context.assertElectedLeader(leaderEpoch, leaderId);
@ -3034,13 +3303,16 @@ public class KafkaRaftClientTest {
context.pollUntilRequest();
RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest();
assertEquals(leaderId, fetchRequest2.destinationId());
assertEquals(leaderId, fetchRequest2.destination().id());
context.assertFetchRequestData(fetchRequest2, leaderEpoch, 0L, 0);
List<String> records = Arrays.asList("a", "b", "c");
MemoryRecords batch1 = context.buildBatch(0L, 3, records);
context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(),
context.fetchResponse(leaderEpoch, leaderId, batch1, 0L, Errors.NONE));
context.deliverResponse(
fetchRequest2.correlationId(),
fetchRequest2.destination(),
context.fetchResponse(leaderEpoch, leaderId, batch1, 0L, Errors.NONE)
);
context.client.poll();
assertEquals(3L, context.log.endOffset().offset);
assertEquals(3, context.log.lastFetchedEpoch());

View File

@ -16,31 +16,29 @@
*/
package org.apache.kafka.raft;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.protocol.ApiKeys;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
public class MockNetworkChannel implements NetworkChannel {
private final AtomicInteger correlationIdCounter;
private final Set<Integer> nodeCache;
private final List<RaftRequest.Outbound> sendQueue = new ArrayList<>();
private final Map<Integer, RaftRequest.Outbound> awaitingResponse = new HashMap<>();
private final ListenerName listenerName = ListenerName.normalised("CONTROLLER");
public MockNetworkChannel(AtomicInteger correlationIdCounter, Set<Integer> destinationIds) {
public MockNetworkChannel(AtomicInteger correlationIdCounter) {
this.correlationIdCounter = correlationIdCounter;
this.nodeCache = destinationIds;
}
public MockNetworkChannel(Set<Integer> destinationIds) {
this(new AtomicInteger(0), destinationIds);
public MockNetworkChannel() {
this(new AtomicInteger(0));
}
@Override
@ -50,16 +48,12 @@ public class MockNetworkChannel implements NetworkChannel {
@Override
public void send(RaftRequest.Outbound request) {
if (!nodeCache.contains(request.destinationId())) {
throw new IllegalArgumentException("Attempted to send to destination " +
request.destinationId() + ", but its address is not yet known");
}
sendQueue.add(request);
}
@Override
public void updateEndpoint(int id, InetSocketAddress address) {
// empty
public ListenerName listenerName() {
return listenerName;
}
public List<RaftRequest.Outbound> drainSendQueue() {
@ -72,7 +66,7 @@ public class MockNetworkChannel implements NetworkChannel {
while (iterator.hasNext()) {
RaftRequest.Outbound request = iterator.next();
if (!apiKeyFilter.isPresent() || request.data().apiKey() == apiKeyFilter.get().id) {
awaitingResponse.put(request.correlationId, request);
awaitingResponse.put(request.correlationId(), request);
requests.add(request);
iterator.remove();
}
@ -80,17 +74,15 @@ public class MockNetworkChannel implements NetworkChannel {
return requests;
}
/**
 * Reports whether any outbound request is still sitting in the send queue,
 * i.e. it was handed to {@code send} and has not yet been removed by a drain.
 *
 * @return {@code true} if the send queue is not empty
 */
public boolean hasSentRequests() {
return !sendQueue.isEmpty();
}
public void mockReceive(RaftResponse.Inbound response) {
RaftRequest.Outbound request = awaitingResponse.get(response.correlationId);
RaftRequest.Outbound request = awaitingResponse.get(response.correlationId());
if (request == null) {
throw new IllegalStateException("Received response for a request which is not being awaited");
}
request.completion.complete(response);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -16,6 +16,7 @@
*/
package org.apache.kafka.raft;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.compress.Compression;
@ -79,6 +80,7 @@ import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.apache.kafka.raft.LeaderState.CHECK_QUORUM_TIMEOUT_FACTOR;
import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition;
@ -114,6 +116,7 @@ public final class RaftClientTestContext {
final MockTime time;
final MockListener listener;
final Set<Integer> voters;
final Set<Integer> bootstrapIds;
private final List<RaftResponse.Outbound> sentResponses = new ArrayList<>();
@ -146,6 +149,7 @@ public final class RaftClientTestContext {
private int electionTimeoutMs = DEFAULT_ELECTION_TIMEOUT_MS;
private int appendLingerMs = DEFAULT_APPEND_LINGER_MS;
private MemoryPool memoryPool = MemoryPool.NONE;
private List<InetSocketAddress> bootstrapServers = Collections.emptyList();
public Builder(int localId, Set<Integer> voters) {
this(OptionalInt.of(localId), voters);
@ -240,9 +244,14 @@ public final class RaftClientTestContext {
return this;
}
/**
 * Sets the bootstrap server addresses for the context being built.
 * The list is stored by reference (no defensive copy); when this method is
 * never called the builder keeps its default, an empty list.
 *
 * @param bootstrapServers the bootstrap server addresses
 * @return this builder, to allow call chaining
 */
Builder withBootstrapServers(List<InetSocketAddress> bootstrapServers) {
this.bootstrapServers = bootstrapServers;
return this;
}
public RaftClientTestContext build() throws IOException {
Metrics metrics = new Metrics(time);
MockNetworkChannel channel = new MockNetworkChannel(voters);
MockNetworkChannel channel = new MockNetworkChannel();
MockListener listener = new MockListener(localId);
Map<Integer, InetSocketAddress> voterAddressMap = voters
.stream()
@ -269,6 +278,7 @@ public final class RaftClientTestContext {
new MockExpirationService(time),
FETCH_MAX_WAIT_MS,
clusterId.toString(),
bootstrapServers,
logContext,
random,
quorumConfig
@ -277,7 +287,6 @@ public final class RaftClientTestContext {
client.register(listener);
client.initialize(
voterAddressMap,
"CONTROLLER",
quorumStateStore,
metrics
);
@ -292,6 +301,11 @@ public final class RaftClientTestContext {
time,
quorumStateStore,
voters,
IntStream
.iterate(-2, id -> id - 1)
.limit(bootstrapServers.size())
.boxed()
.collect(Collectors.toSet()),
metrics,
listener
);
@ -314,6 +328,7 @@ public final class RaftClientTestContext {
MockTime time,
QuorumStateStore quorumStateStore,
Set<Integer> voters,
Set<Integer> bootstrapIds,
Metrics metrics,
MockListener listener
) {
@ -326,6 +341,7 @@ public final class RaftClientTestContext {
this.time = time;
this.quorumStateStore = quorumStateStore;
this.voters = voters;
this.bootstrapIds = bootstrapIds;
this.metrics = metrics;
this.listener = listener;
}
@ -417,7 +433,7 @@ public final class RaftClientTestContext {
for (RaftRequest.Outbound request : voteRequests) {
VoteResponseData voteResponse = voteResponse(true, Optional.empty(), epoch);
deliverResponse(request.correlationId, request.destinationId(), voteResponse);
deliverResponse(request.correlationId(), request.destination(), voteResponse);
}
client.poll();
@ -432,7 +448,7 @@ public final class RaftClientTestContext {
pollUntilRequest();
for (RaftRequest.Outbound request : collectBeginEpochRequests(epoch)) {
BeginQuorumEpochResponseData beginEpochResponse = beginEpochResponse(epoch, localIdOrThrow());
deliverResponse(request.correlationId, request.destinationId(), beginEpochResponse);
deliverResponse(request.correlationId(), request.destination(), beginEpochResponse);
}
client.poll();
}
@ -519,10 +535,10 @@ public final class RaftClientTestContext {
assertEquals(expectedResponse, response);
}
int assertSentVoteRequest(int epoch, int lastEpoch, long lastEpochOffset, int numVoteReceivers) {
RaftRequest.Outbound assertSentVoteRequest(int epoch, int lastEpoch, long lastEpochOffset, int numVoteReceivers) {
List<RaftRequest.Outbound> voteRequests = collectVoteRequests(epoch, lastEpoch, lastEpochOffset);
assertEquals(numVoteReceivers, voteRequests.size());
return voteRequests.iterator().next().correlationId();
return voteRequests.iterator().next();
}
void assertSentVoteResponse(Errors error) {
@ -590,14 +606,14 @@ public final class RaftClientTestContext {
client.handle(inboundRequest);
}
void deliverResponse(int correlationId, int sourceId, ApiMessage response) {
channel.mockReceive(new RaftResponse.Inbound(correlationId, response, sourceId));
/**
 * Hands a response to the mock network channel as if it had just arrived.
 *
 * @param correlationId correlation id of the request being answered
 * @param source the node the response is reported as coming from
 * @param response the response payload
 */
void deliverResponse(int correlationId, Node source, ApiMessage response) {
    RaftResponse.Inbound inbound = new RaftResponse.Inbound(correlationId, response, source);
    channel.mockReceive(inbound);
}
int assertSentBeginQuorumEpochRequest(int epoch, int numBeginEpochRequests) {
RaftRequest.Outbound assertSentBeginQuorumEpochRequest(int epoch, int numBeginEpochRequests) {
List<RaftRequest.Outbound> requests = collectBeginEpochRequests(epoch);
assertEquals(numBeginEpochRequests, requests.size());
return requests.get(0).correlationId;
return requests.get(0);
}
private List<RaftResponse.Outbound> drainSentResponses(
@ -607,7 +623,7 @@ public final class RaftClientTestContext {
Iterator<RaftResponse.Outbound> iterator = sentResponses.iterator();
while (iterator.hasNext()) {
RaftResponse.Outbound response = iterator.next();
if (response.data.apiKey() == apiKey.id) {
if (response.data().apiKey() == apiKey.id) {
res.add(response);
iterator.remove();
}
@ -646,11 +662,14 @@ public final class RaftClientTestContext {
assertEquals(partitionError, Errors.forCode(partitionResponse.errorCode()));
}
int assertSentEndQuorumEpochRequest(int epoch, int destinationId) {
RaftRequest.Outbound assertSentEndQuorumEpochRequest(int epoch, int destinationId) {
List<RaftRequest.Outbound> endQuorumRequests = collectEndQuorumRequests(
epoch, Collections.singleton(destinationId), Optional.empty());
epoch,
Collections.singleton(destinationId),
Optional.empty()
);
assertEquals(1, endQuorumRequests.size());
return endQuorumRequests.get(0).correlationId();
return endQuorumRequests.get(0);
}
void assertSentEndQuorumEpochResponse(
@ -690,7 +709,7 @@ public final class RaftClientTestContext {
return sentRequests.get(0);
}
int assertSentFetchRequest(
RaftRequest.Outbound assertSentFetchRequest(
int epoch,
long fetchOffset,
int lastFetchedEpoch
@ -700,7 +719,7 @@ public final class RaftClientTestContext {
RaftRequest.Outbound raftRequest = sentMessages.get(0);
assertFetchRequestData(raftRequest, epoch, fetchOffset, lastFetchedEpoch);
return raftRequest.correlationId();
return raftRequest;
}
FetchResponseData.PartitionData assertSentFetchPartitionResponse() {
@ -708,7 +727,7 @@ public final class RaftClientTestContext {
assertEquals(
1, sentMessages.size(), "Found unexpected sent messages " + sentMessages);
RaftResponse.Outbound raftMessage = sentMessages.get(0);
assertEquals(ApiKeys.FETCH.id, raftMessage.data.apiKey());
assertEquals(ApiKeys.FETCH.id, raftMessage.data().apiKey());
FetchResponseData response = (FetchResponseData) raftMessage.data();
assertEquals(Errors.NONE, Errors.forCode(response.errorCode()));
@ -723,7 +742,7 @@ public final class RaftClientTestContext {
assertEquals(
1, sentMessages.size(), "Found unexpected sent messages " + sentMessages);
RaftResponse.Outbound raftMessage = sentMessages.get(0);
assertEquals(ApiKeys.FETCH.id, raftMessage.data.apiKey());
assertEquals(ApiKeys.FETCH.id, raftMessage.data().apiKey());
FetchResponseData response = (FetchResponseData) raftMessage.data();
assertEquals(topLevelError, Errors.forCode(response.errorCode()));
}
@ -811,7 +830,7 @@ public final class RaftClientTestContext {
assertEquals(preferredSuccessors, partitionRequest.preferredSuccessors());
});
collectedDestinationIdSet.add(raftMessage.destinationId());
collectedDestinationIdSet.add(raftMessage.destination().id());
endQuorumRequests.add(raftMessage);
}
}
@ -825,11 +844,18 @@ public final class RaftClientTestContext {
) throws Exception {
pollUntilRequest();
RaftRequest.Outbound fetchRequest = assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
int destinationId = fetchRequest.destination().id();
assertTrue(
voters.contains(destinationId) || bootstrapIds.contains(destinationId),
String.format("id %d is not in sets %s or %s", destinationId, voters, bootstrapIds)
);
assertFetchRequestData(fetchRequest, 0, 0L, 0);
deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(),
fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.NONE));
deliverResponse(
fetchRequest.correlationId(),
fetchRequest.destination(),
fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.NONE)
);
client.poll();
assertElectedLeader(epoch, leaderId);
}
@ -850,7 +876,7 @@ public final class RaftClientTestContext {
return requests;
}
private static InetSocketAddress mockAddress(int id) {
/**
 * Builds the socket address used for a node in tests: host {@code localhost}
 * and port {@code 9990 + id}.
 *
 * @param id the node id
 * @return the address derived from {@code id}
 */
public static InetSocketAddress mockAddress(int id) {
    int port = 9990 + id;
    return new InetSocketAddress("localhost", port);
}

View File

@ -21,6 +21,7 @@ import net.jqwik.api.ForAll;
import net.jqwik.api.Property;
import net.jqwik.api.Tag;
import net.jqwik.api.constraints.IntRange;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.memory.MemoryPool;
@ -45,6 +46,7 @@ import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@ -59,7 +61,6 @@ import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
@ -189,7 +190,7 @@ public class RaftEventSimulationTest {
// they are able to elect a leader and continue making progress
cluster.killAll();
Iterator<Integer> nodeIdsIterator = cluster.nodes().iterator();
Iterator<Integer> nodeIdsIterator = cluster.nodeIds().iterator();
for (int i = 0; i < cluster.majoritySize(); i++) {
Integer nodeId = nodeIdsIterator.next();
cluster.start(nodeId);
@ -224,7 +225,7 @@ public class RaftEventSimulationTest {
);
router.filter(leaderId, new DropAllTraffic());
Set<Integer> nonPartitionedNodes = new HashSet<>(cluster.nodes());
Set<Integer> nonPartitionedNodes = new HashSet<>(cluster.nodeIds());
nonPartitionedNodes.remove(leaderId);
scheduler.runUntil(() -> cluster.allReachedHighWatermark(20, nonPartitionedNodes));
@ -252,11 +253,17 @@ public class RaftEventSimulationTest {
// Partition the nodes into two sets. Nodes are reachable within each set,
// but the two sets cannot communicate with each other. We should be able
// to make progress even if an election is needed in the larger set.
router.filter(0, new DropOutboundRequestsFrom(Utils.mkSet(2, 3, 4)));
router.filter(1, new DropOutboundRequestsFrom(Utils.mkSet(2, 3, 4)));
router.filter(2, new DropOutboundRequestsFrom(Utils.mkSet(0, 1)));
router.filter(3, new DropOutboundRequestsFrom(Utils.mkSet(0, 1)));
router.filter(4, new DropOutboundRequestsFrom(Utils.mkSet(0, 1)));
router.filter(
0,
new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(2, 3, 4)))
);
router.filter(
1,
new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(2, 3, 4)))
);
router.filter(2, new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(0, 1))));
router.filter(3, new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(0, 1))));
router.filter(4, new DropOutboundRequestsTo(cluster.endpointsFromIds(Utils.mkSet(0, 1))));
long partitionLogEndOffset = cluster.maxLogEndOffset();
scheduler.runUntil(() -> cluster.anyReachedHighWatermark(2 * partitionLogEndOffset));
@ -374,7 +381,7 @@ public class RaftEventSimulationTest {
int pollIntervalMs,
int pollJitterMs) {
int delayMs = 0;
for (int nodeId : cluster.nodes()) {
for (int nodeId : cluster.nodeIds()) {
scheduler.schedule(() -> cluster.pollIfRunning(nodeId), delayMs, pollIntervalMs, pollJitterMs);
delayMs++;
}
@ -527,25 +534,37 @@ public class RaftEventSimulationTest {
final AtomicInteger correlationIdCounter = new AtomicInteger();
final MockTime time = new MockTime();
final Uuid clusterId = Uuid.randomUuid();
final Set<Integer> voters = new HashSet<>();
final Map<Integer, Node> voters = new HashMap<>();
final Map<Integer, PersistentState> nodes = new HashMap<>();
final Map<Integer, RaftNode> running = new HashMap<>();
private Cluster(int numVoters, int numObservers, Random random) {
this.random = random;
int nodeId = 0;
for (; nodeId < numVoters; nodeId++) {
voters.add(nodeId);
for (int nodeId = 0; nodeId < numVoters; nodeId++) {
voters.put(
nodeId,
new Node(nodeId, String.format("host-node-%d", nodeId), 1234)
);
nodes.put(nodeId, new PersistentState(nodeId));
}
for (; nodeId < numVoters + numObservers; nodeId++) {
for (int nodeIdDelta = 0; nodeIdDelta < numObservers; nodeIdDelta++) {
int nodeId = numVoters + nodeIdDelta;
nodes.put(nodeId, new PersistentState(nodeId));
}
}
Set<Integer> nodes() {
/**
 * Translates a set of node ids into the socket addresses of the matching voters.
 * Ids that do not belong to a voter contribute nothing to the result.
 *
 * @param nodeIds ids of the nodes whose endpoints are wanted
 * @return the addresses of every voter whose id is contained in {@code nodeIds}
 */
Set<InetSocketAddress> endpointsFromIds(Set<Integer> nodeIds) {
    Set<InetSocketAddress> endpoints = new HashSet<>();
    for (Node voter : voters.values()) {
        if (nodeIds.contains(voter.id())) {
            endpoints.add(nodeAddress(voter));
        }
    }
    return endpoints;
}
/**
 * Returns the ids of every node registered in the {@code nodes} map.
 * This is the map's own key-set view rather than a copy.
 *
 * @return the ids of all nodes known to this cluster
 */
Set<Integer> nodeIds() {
return nodes.keySet();
}
@ -710,18 +729,19 @@ public class RaftEventSimulationTest {
nodes.put(nodeId, new PersistentState(nodeId));
}
private static InetSocketAddress nodeAddress(int id) {
return new InetSocketAddress("localhost", 9990 + id);
/**
 * Converts a node into an unresolved socket address built from the node's
 * host and port; no name lookup is attempted.
 *
 * @param node the node whose endpoint is wanted
 * @return an unresolved address for the node's host and port
 */
private static InetSocketAddress nodeAddress(Node node) {
    String host = node.host();
    int port = node.port();
    return InetSocketAddress.createUnresolved(host, port);
}
void start(int nodeId) {
LogContext logContext = new LogContext("[Node " + nodeId + "] ");
PersistentState persistentState = nodes.get(nodeId);
MockNetworkChannel channel = new MockNetworkChannel(correlationIdCounter, voters);
MockNetworkChannel channel = new MockNetworkChannel(correlationIdCounter);
MockMessageQueue messageQueue = new MockMessageQueue();
Map<Integer, InetSocketAddress> voterAddressMap = voters
.values()
.stream()
.collect(Collectors.toMap(Function.identity(), Cluster::nodeAddress));
.collect(Collectors.toMap(Node::id, Cluster::nodeAddress));
QuorumConfig quorumConfig = new QuorumConfig(
REQUEST_TIMEOUT_MS,
@ -750,6 +770,7 @@ public class RaftEventSimulationTest {
new MockExpirationService(time),
FETCH_MAX_WAIT_MS,
clusterId.toString(),
Collections.emptyList(),
logContext,
random,
quorumConfig
@ -808,7 +829,6 @@ public class RaftEventSimulationTest {
client.register(counter);
client.initialize(
voterAddresses,
"CONTROLLER",
store,
metrics
);
@ -847,9 +867,11 @@ public class RaftEventSimulationTest {
private static class InflightRequest {
final int sourceId;
final Node destination;
private InflightRequest(int sourceId) {
private InflightRequest(int sourceId, Node destination) {
this.sourceId = sourceId;
this.destination = destination;
}
}
@ -884,11 +906,15 @@ public class RaftEventSimulationTest {
}
}
private static class DropOutboundRequestsFrom implements NetworkFilter {
private static class DropOutboundRequestsTo implements NetworkFilter {
private final Set<InetSocketAddress> unreachable;
private final Set<Integer> unreachable;
private DropOutboundRequestsFrom(Set<Integer> unreachable) {
/**
* This network filter drops any outbound message sent to the {@code unreachable} nodes.
*
* @param unreachable the set of destination addresses that are not reachable
*/
private DropOutboundRequestsTo(Set<InetSocketAddress> unreachable) {
this.unreachable = unreachable;
}
@ -897,11 +923,25 @@ public class RaftEventSimulationTest {
return true;
}
/**
* Returns whether the message should be sent to the destination.
*
* Returns false when an outbound request message contains a destination {@code Node} that
* matches the set of unreachable {@code InetSocketAddress}. Note that the {@code Node.id()}
* and {@code Node.rack()} are not compared.
*
* @param message the raft message
* @return true if the message should be delivered, otherwise false
*/
@Override
public boolean acceptOutbound(RaftMessage message) {
if (message instanceof RaftRequest.Outbound) {
RaftRequest.Outbound request = (RaftRequest.Outbound) message;
return !unreachable.contains(request.destinationId());
InetSocketAddress destination = InetSocketAddress.createUnresolved(
request.destination().host(),
request.destination().port()
);
return !unreachable.contains(destination);
}
return true;
}
@ -955,7 +995,7 @@ public class RaftEventSimulationTest {
public void verify() {
cluster.leaderHighWatermark().ifPresent(highWatermark -> {
long numReachedHighWatermark = cluster.nodes.entrySet().stream()
.filter(entry -> cluster.voters.contains(entry.getKey()))
.filter(entry -> cluster.voters.containsKey(entry.getKey()))
.filter(entry -> entry.getValue().log.endOffset().offset >= highWatermark)
.count();
assertTrue(
@ -1194,19 +1234,19 @@ public class RaftEventSimulationTest {
return;
int correlationId = outbound.correlationId();
int destinationId = outbound.destinationId();
Node destination = outbound.destination();
RaftRequest.Inbound inbound = new RaftRequest.Inbound(correlationId, outbound.data(),
cluster.time.milliseconds());
if (!filters.get(destinationId).acceptInbound(inbound))
if (!filters.get(destination.id()).acceptInbound(inbound))
return;
cluster.nodeIfRunning(destinationId).ifPresent(node -> {
inflight.put(correlationId, new InflightRequest(senderId));
cluster.nodeIfRunning(destination.id()).ifPresent(node -> {
inflight.put(correlationId, new InflightRequest(senderId, destination));
inbound.completion.whenComplete((response, exception) -> {
if (response != null && filters.get(destinationId).acceptOutbound(response)) {
deliver(destinationId, response);
if (response != null && filters.get(destination.id()).acceptOutbound(response)) {
deliver(response);
}
});
@ -1214,11 +1254,17 @@ public class RaftEventSimulationTest {
});
}
void deliver(int senderId, RaftResponse.Outbound outbound) {
void deliver(RaftResponse.Outbound outbound) {
int correlationId = outbound.correlationId();
RaftResponse.Inbound inbound = new RaftResponse.Inbound(correlationId, outbound.data(), senderId);
InflightRequest inflightRequest = inflight.remove(correlationId);
RaftResponse.Inbound inbound = new RaftResponse.Inbound(
correlationId,
outbound.data(),
// The source of the response is the destination of the request
inflightRequest.destination
);
if (!filters.get(inflightRequest.sourceId).acceptInbound(inbound))
return;

View File

@ -16,14 +16,20 @@
*/
package org.apache.kafka.raft;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Utils;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
public class RequestManagerTest {
private final MockTime time = new MockTime();
@ -33,105 +39,247 @@ public class RequestManagerTest {
@Test
public void testResetAllConnections() {
Node node1 = new Node(1, "mock-host-1", 4321);
Node node2 = new Node(2, "mock-host-2", 4321);
RequestManager cache = new RequestManager(
Utils.mkSet(1, 2, 3),
makeBootstrapList(3),
retryBackoffMs,
requestTimeoutMs,
random);
random
);
// One host has an inflight request
RequestManager.ConnectionState connectionState1 = cache.getOrCreate(1);
connectionState1.onRequestSent(1, time.milliseconds());
assertFalse(connectionState1.isReady(time.milliseconds()));
cache.onRequestSent(node1, 1, time.milliseconds());
assertFalse(cache.isReady(node1, time.milliseconds()));
// Another is backing off
RequestManager.ConnectionState connectionState2 = cache.getOrCreate(2);
connectionState2.onRequestSent(2, time.milliseconds());
connectionState2.onResponseError(2, time.milliseconds());
assertFalse(connectionState2.isReady(time.milliseconds()));
cache.onRequestSent(node2, 2, time.milliseconds());
cache.onResponseResult(node2, 2, false, time.milliseconds());
assertFalse(cache.isReady(node2, time.milliseconds()));
cache.resetAll();
// Now both should be ready
assertTrue(connectionState1.isReady(time.milliseconds()));
assertTrue(connectionState2.isReady(time.milliseconds()));
assertTrue(cache.isReady(node1, time.milliseconds()));
assertTrue(cache.isReady(node2, time.milliseconds()));
}
@Test
public void testBackoffAfterFailure() {
Node node = new Node(1, "mock-host-1", 4321);
RequestManager cache = new RequestManager(
Utils.mkSet(1, 2, 3),
makeBootstrapList(3),
retryBackoffMs,
requestTimeoutMs,
random);
random
);
RequestManager.ConnectionState connectionState = cache.getOrCreate(1);
assertTrue(connectionState.isReady(time.milliseconds()));
assertTrue(cache.isReady(node, time.milliseconds()));
long correlationId = 1;
connectionState.onRequestSent(correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds()));
cache.onRequestSent(node, correlationId, time.milliseconds());
assertFalse(cache.isReady(node, time.milliseconds()));
connectionState.onResponseError(correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds()));
cache.onResponseResult(node, correlationId, false, time.milliseconds());
assertFalse(cache.isReady(node, time.milliseconds()));
time.sleep(retryBackoffMs);
assertTrue(connectionState.isReady(time.milliseconds()));
assertTrue(cache.isReady(node, time.milliseconds()));
}
@Test
public void testSuccessfulResponse() {
Node node = new Node(1, "mock-host-1", 4321);
RequestManager cache = new RequestManager(
Utils.mkSet(1, 2, 3),
makeBootstrapList(3),
retryBackoffMs,
requestTimeoutMs,
random);
RequestManager.ConnectionState connectionState = cache.getOrCreate(1);
random
);
long correlationId = 1;
connectionState.onRequestSent(correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds()));
connectionState.onResponseReceived(correlationId);
assertTrue(connectionState.isReady(time.milliseconds()));
cache.onRequestSent(node, correlationId, time.milliseconds());
assertFalse(cache.isReady(node, time.milliseconds()));
cache.onResponseResult(node, correlationId, true, time.milliseconds());
assertTrue(cache.isReady(node, time.milliseconds()));
}
@Test
public void testIgnoreUnexpectedResponse() {
Node node = new Node(1, "mock-host-1", 4321);
RequestManager cache = new RequestManager(
Utils.mkSet(1, 2, 3),
makeBootstrapList(3),
retryBackoffMs,
requestTimeoutMs,
random);
RequestManager.ConnectionState connectionState = cache.getOrCreate(1);
random
);
long correlationId = 1;
connectionState.onRequestSent(correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds()));
connectionState.onResponseReceived(correlationId + 1);
assertFalse(connectionState.isReady(time.milliseconds()));
cache.onRequestSent(node, correlationId, time.milliseconds());
assertFalse(cache.isReady(node, time.milliseconds()));
cache.onResponseResult(node, correlationId + 1, true, time.milliseconds());
assertFalse(cache.isReady(node, time.milliseconds()));
}
@Test
public void testRequestTimeout() {
Node node = new Node(1, "mock-host-1", 4321);
RequestManager cache = new RequestManager(
Utils.mkSet(1, 2, 3),
makeBootstrapList(3),
retryBackoffMs,
requestTimeoutMs,
random);
RequestManager.ConnectionState connectionState = cache.getOrCreate(1);
random
);
long correlationId = 1;
connectionState.onRequestSent(correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds()));
cache.onRequestSent(node, correlationId, time.milliseconds());
assertFalse(cache.isReady(node, time.milliseconds()));
time.sleep(requestTimeoutMs - 1);
assertFalse(connectionState.isReady(time.milliseconds()));
assertFalse(cache.isReady(node, time.milliseconds()));
time.sleep(1);
assertTrue(connectionState.isReady(time.milliseconds()));
assertTrue(cache.isReady(node, time.milliseconds()));
}
/**
 * Walks two bootstrap servers through the send, fail and backoff cycle and checks which
 * server, if any, is reported as ready at each step together with the reported wait time.
 */
@Test
public void testRequestToBootstrapList() {
    List<Node> bootstrapServers = makeBootstrapList(2);
    RequestManager manager = new RequestManager(
        bootstrapServers,
        retryBackoffMs,
        requestTimeoutMs,
        random
    );

    // In the starting state a bootstrap server is available without any wait
    Node firstServer = manager.findReadyBootstrapServer(time.milliseconds()).get();
    String missingServerMessage = String.format("%s is not in %s", firstServer, bootstrapServers);
    assertTrue(bootstrapServers.contains(firstServer), missingServerMessage);
    assertEquals(0, manager.backoffBeforeAvailableBootstrapServer(time.milliseconds()));

    // After sending a request no bootstrap server is reported as ready and the
    // reported wait equals the request timeout
    manager.onRequestSent(firstServer, 1, time.milliseconds());
    assertFalse(manager.findReadyBootstrapServer(time.milliseconds()).isPresent());
    assertEquals(requestTimeoutMs, manager.backoffBeforeAvailableBootstrapServer(time.milliseconds()));

    // Failing that request 100 ms later makes a different server ready immediately
    time.sleep(100);
    manager.onResponseResult(firstServer, 1, false, time.milliseconds());
    Node secondServer = manager.findReadyBootstrapServer(time.milliseconds()).get();
    assertNotEquals(firstServer, secondServer);
    assertEquals(0, manager.backoffBeforeAvailableBootstrapServer(time.milliseconds()));

    // A request to the second server again leaves no bootstrap server ready
    manager.onRequestSent(secondServer, 2, time.milliseconds());
    assertFalse(manager.findReadyBootstrapServer(time.milliseconds()).isPresent());
    assertEquals(requestTimeoutMs, manager.backoffBeforeAvailableBootstrapServer(time.milliseconds()));

    // Fail the second request 1 ms before the first server's retry backoff elapses
    time.sleep(retryBackoffMs - 1);
    manager.onResponseResult(secondServer, 2, false, time.milliseconds());
    assertFalse(manager.findReadyBootstrapServer(time.milliseconds()).isPresent());
    assertEquals(1, manager.backoffBeforeAvailableBootstrapServer(time.milliseconds()));

    // Once the remaining 1 ms has passed the first server is ready again
    time.sleep(1);
    Node readyAfterBackoff = manager.findReadyBootstrapServer(time.milliseconds()).get();
    assertEquals(firstServer, readyAfterBackoff);
    assertEquals(0, manager.backoffBeforeAvailableBootstrapServer(time.milliseconds()));
}
/**
 * Verifies that an inflight request to a node outside of the bootstrap list prevents any
 * bootstrap server from being reported as ready.
 */
@Test
public void testFindReadyWithInflightRequest() {
    RequestManager manager = new RequestManager(
        makeBootstrapList(3),
        retryBackoffMs,
        requestTimeoutMs,
        random
    );
    Node nonBootstrapNode = new Node(1, "other-node", 1234);

    // The destination of this request is not one of the bootstrap servers
    manager.onRequestSent(nonBootstrapNode, 1, time.milliseconds());

    assertFalse(manager.findReadyBootstrapServer(time.milliseconds()).isPresent());
}
/**
 * Verifies that a bootstrap server becomes available again once an inflight request to a
 * node outside of the bootstrap list has timed out, and that the response for the timed out
 * request is no longer expected afterwards.
 */
@Test
public void testFindReadyWithRequestTimedout() {
    List<Node> bootstrapServers = makeBootstrapList(3);
    RequestManager manager = new RequestManager(
        bootstrapServers,
        retryBackoffMs,
        requestTimeoutMs,
        random
    );
    Node nonBootstrapNode = new Node(1, "other-node", 1234);

    // While the request to the non-bootstrap node is pending nothing is ready
    manager.onRequestSent(nonBootstrapNode, 1, time.milliseconds());
    assertTrue(manager.isResponseExpected(nonBootstrapNode, 1));
    assertFalse(manager.findReadyBootstrapServer(time.milliseconds()).isPresent());

    // Let the full request timeout elapse
    time.sleep(requestTimeoutMs);

    Node readyServer = manager.findReadyBootstrapServer(time.milliseconds()).get();
    assertTrue(bootstrapServers.contains(readyServer));
    assertFalse(manager.isResponseExpected(nonBootstrapNode, 1));
}
/**
 * Verifies that {@code hasAnyInflightRequest} is true only while a request is pending: it
 * is false before any request is sent, after a timeout, after a failed response and after
 * a successful response.
 */
@Test
public void testAnyInflightRequestWithAnyRequest() {
    RequestManager manager = new RequestManager(
        makeBootstrapList(3),
        retryBackoffMs,
        requestTimeoutMs,
        random
    );
    Node destination = new Node(1, "other-node", 1234);

    // Nothing has been sent yet
    assertFalse(manager.hasAnyInflightRequest(time.milliseconds()));

    // A pending request is reported as inflight
    manager.onRequestSent(destination, 11, time.milliseconds());
    assertTrue(manager.hasAnyInflightRequest(time.milliseconds()));

    // The request stops counting as inflight once the request timeout has elapsed
    time.sleep(requestTimeoutMs);
    assertFalse(manager.hasAnyInflightRequest(time.milliseconds()));

    // A request that received a failed response is not inflight
    manager.onRequestSent(destination, 12, time.milliseconds());
    manager.onResponseResult(destination, 12, false, time.milliseconds());
    assertFalse(manager.hasAnyInflightRequest(time.milliseconds()));

    // A request that received a successful response is not inflight either
    manager.onRequestSent(destination, 12, time.milliseconds());
    manager.onResponseResult(destination, 12, true, time.milliseconds());
    assertFalse(manager.hasAnyInflightRequest(time.milliseconds()));
}
/**
 * Creates a list of bootstrap nodes for the tests.
 *
 * The ids count down from -2 (-2, -3, ...), each host name is {@code mock-boot-host}
 * followed by the id, and every node uses port 1234.
 *
 * @param numberOfNodes the number of bootstrap nodes to create
 * @return the generated bootstrap nodes, in decreasing id order
 */
private List<Node> makeBootstrapList(int numberOfNodes) {
    IntStream bootstrapIds = IntStream
        .iterate(-2, previousId -> previousId - 1)
        .limit(numberOfNodes);

    return bootstrapIds
        .mapToObj(bootstrapId -> {
            String host = String.format("mock-boot-host%d", bootstrapId);
            return new Node(bootstrapId, host, 1234);
        })
        .collect(Collectors.toList());
}
}

View File

@ -16,8 +16,8 @@
*/
package org.apache.kafka.raft.internals;
import java.util.Arrays;
import java.util.Optional;
import java.util.stream.IntStream;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.message.KRaftVersionRecord;
@ -52,7 +52,7 @@ final class KRaftControlRecordStateMachineTest {
@Test
void testEmptyPartition() {
MockLog log = buildLog();
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
KRaftControlRecordStateMachine partitionState = buildPartitionListener(log, Optional.of(voterSet));
@ -65,7 +65,7 @@ final class KRaftControlRecordStateMachineTest {
@Test
void testUpdateWithoutSnapshot() {
MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
int epoch = 1;
@ -85,7 +85,7 @@ final class KRaftControlRecordStateMachineTest {
);
// Append the voter set control record
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
log.appendAsLeader(
MemoryRecords.withVotersRecord(
log.endOffset().offset,
@ -108,7 +108,7 @@ final class KRaftControlRecordStateMachineTest {
@Test
void testUpdateWithEmptySnapshot() {
MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
int epoch = 1;
@ -136,7 +136,7 @@ final class KRaftControlRecordStateMachineTest {
);
// Append the voter set control record
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
log.appendAsLeader(
MemoryRecords.withVotersRecord(
log.endOffset().offset,
@ -159,14 +159,14 @@ final class KRaftControlRecordStateMachineTest {
@Test
void testUpdateWithSnapshot() {
MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
int epoch = 1;
KRaftControlRecordStateMachine partitionState = buildPartitionListener(log, Optional.of(staticVoterSet));
// Create a snapshot that has kraft.version and voter set control records
short kraftVersion = 1;
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
VoterSet voterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
.setRawSnapshotWriter(log.createNewSnapshotUnchecked(new OffsetAndEpoch(10, epoch)).get())
@ -188,7 +188,7 @@ final class KRaftControlRecordStateMachineTest {
@Test
void testUpdateWithSnapshotAndLogOverride() {
MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
int epoch = 1;
@ -196,7 +196,7 @@ final class KRaftControlRecordStateMachineTest {
// Create a snapshot that has kraft.version and voter set control records
short kraftVersion = 1;
VoterSet snapshotVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
VoterSet snapshotVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
OffsetAndEpoch snapshotId = new OffsetAndEpoch(10, epoch);
RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
@ -235,7 +235,7 @@ final class KRaftControlRecordStateMachineTest {
@Test
void testTruncateTo() {
MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
int epoch = 1;
@ -256,7 +256,7 @@ final class KRaftControlRecordStateMachineTest {
// Append the voter set control record
long firstVoterSetOffset = log.endOffset().offset;
VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
log.appendAsLeader(
MemoryRecords.withVotersRecord(
firstVoterSetOffset,
@ -303,7 +303,7 @@ final class KRaftControlRecordStateMachineTest {
@Test
void testTrimPrefixTo() {
MockLog log = buildLog();
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
VoterSet staticVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
BufferSupplier bufferSupplier = BufferSupplier.NO_CACHING;
int epoch = 1;
@ -325,7 +325,7 @@ final class KRaftControlRecordStateMachineTest {
// Append the voter set control record
long firstVoterSetOffset = log.endOffset().offset;
VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(Arrays.asList(4, 5, 6), true));
VoterSet firstVoterSet = VoterSetTest.voterSet(VoterSetTest.voterMap(IntStream.of(4, 5, 6), true));
log.appendAsLeader(
MemoryRecords.withVotersRecord(
firstVoterSetOffset,

View File

@ -22,7 +22,6 @@ import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.raft.LogOffsetMetadata;
import org.apache.kafka.raft.MockQuorumStateStore;
import org.apache.kafka.raft.OffsetAndEpoch;
@ -32,13 +31,13 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mockito;
import java.util.Map;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
import java.util.Random;
import java.util.Set;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
@ -64,19 +63,11 @@ public class KafkaRaftMetricsTest {
metrics.close();
}
private QuorumState buildQuorumState(Set<Integer> voters, short kraftVersion) {
boolean withDirectoryId = kraftVersion > 0;
return buildQuorumState(
VoterSetTest.voterSet(VoterSetTest.voterMap(voters, withDirectoryId)),
kraftVersion
);
}
private QuorumState buildQuorumState(VoterSet voterSet, short kraftVersion) {
return new QuorumState(
OptionalInt.of(localId),
localDirectoryId,
VoterSetTest.DEFAULT_LISTENER_NAME,
() -> voterSet,
() -> kraftVersion,
electionTimeoutMs,
@ -88,11 +79,26 @@ public class KafkaRaftMetricsTest {
);
}
/**
 * Creates a voter set whose only voter is the local replica.
 *
 * @param kraftVersion the kraft version under test; the local directory id is included in
 *        the voter's replica key only when this is greater than 0
 * @return a voter set containing just the local replica
 */
private VoterSet localStandaloneVoterSet(short kraftVersion) {
    ReplicaKey localKey = kraftVersion > 0
        ? ReplicaKey.of(localId, Optional.of(localDirectoryId))
        : ReplicaKey.of(localId, Optional.empty());

    return VoterSetTest.voterSet(
        Collections.singletonMap(localId, VoterSetTest.voterNode(localKey))
    );
}
@ParameterizedTest
@ValueSource(shorts = {0, 1})
public void shouldRecordVoterQuorumState(short kraftVersion) {
boolean withDirectoryId = kraftVersion > 0;
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Utils.mkSet(1, 2), withDirectoryId);
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2), withDirectoryId);
voterMap.put(
localId,
VoterSetTest.voterNode(
@ -102,7 +108,8 @@ public class KafkaRaftMetricsTest {
)
)
);
QuorumState state = buildQuorumState(VoterSetTest.voterSet(voterMap), kraftVersion);
VoterSet voters = VoterSetTest.voterSet(voterMap);
QuorumState state = buildQuorumState(voters, kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -144,7 +151,7 @@ public class KafkaRaftMetricsTest {
state.leaderStateOrThrow().updateReplicaState(1, 0, new LogOffsetMetadata(5L));
assertEquals((double) 5L, getMetric(metrics, "high-watermark").metricValue());
state.transitionToFollower(2, 1);
state.transitionToFollower(2, voters.voterNode(1, VoterSetTest.DEFAULT_LISTENER_NAME).get());
assertEquals("follower", getMetric(metrics, "current-state").metricValue());
assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue());
assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue());
@ -184,7 +191,11 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest
@ValueSource(shorts = {0, 1})
public void shouldRecordNonVoterQuorumState(short kraftVersion) {
QuorumState state = buildQuorumState(Utils.mkSet(1, 2, 3), kraftVersion);
boolean withDirectoryId = kraftVersion > 0;
VoterSet voters = VoterSetTest.voterSet(
VoterSetTest.voterMap(IntStream.of(1, 2, 3), withDirectoryId)
);
QuorumState state = buildQuorumState(voters, kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -198,7 +209,7 @@ public class KafkaRaftMetricsTest {
assertEquals((double) 0, getMetric(metrics, "current-epoch").metricValue());
assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue());
state.transitionToFollower(2, 1);
state.transitionToFollower(2, voters.voterNode(1, VoterSetTest.DEFAULT_LISTENER_NAME).get());
assertEquals("observer", getMetric(metrics, "current-state").metricValue());
assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue());
assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue());
@ -227,7 +238,7 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest
@ValueSource(shorts = {0, 1})
public void shouldRecordLogEnd(short kraftVersion) {
QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion);
QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -243,7 +254,7 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest
@ValueSource(shorts = {0, 1})
public void shouldRecordNumUnknownVoterConnections(short kraftVersion) {
QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion);
QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -257,7 +268,7 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest
@ValueSource(shorts = {0, 1})
public void shouldRecordPollIdleRatio(short kraftVersion) {
QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion);
QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -330,7 +341,7 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest
@ValueSource(shorts = {0, 1})
public void shouldRecordLatency(short kraftVersion) {
QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion);
QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);
@ -362,7 +373,7 @@ public class KafkaRaftMetricsTest {
@ParameterizedTest
@ValueSource(shorts = {0, 1})
public void shouldRecordRate(short kraftVersion) {
QuorumState state = buildQuorumState(Collections.singleton(localId), kraftVersion);
QuorumState state = buildQuorumState(localStandaloneVoterSet(kraftVersion), kraftVersion);
state.initialize(new OffsetAndEpoch(0L, 0));
raftMetrics = new KafkaRaftMetrics(metrics, "raft", state);

View File

@ -21,7 +21,6 @@ import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.NoSuchElementException;
@ -31,6 +30,7 @@ import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import net.jqwik.api.ForAll;
import net.jqwik.api.Property;
@ -204,7 +204,7 @@ public final class RecordsIteratorTest {
public void testControlRecordIterationWithKraftVersion1() {
AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null);
VoterSet voterSet = new VoterSet(
new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))
VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)
);
RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
.setTime(new MockTime())

View File

@ -16,10 +16,10 @@
*/
package org.apache.kafka.raft.internals;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.stream.IntStream;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
@ -27,7 +27,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
final public class VoterSetHistoryTest {
@Test
void testStaticVoterSet() {
VoterSet staticVoterSet = new VoterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true));
VoterSet staticVoterSet = new VoterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
assertEquals(Optional.empty(), votersHistory.valueAtOrBefore(0));
@ -58,13 +58,13 @@ final public class VoterSetHistoryTest {
@Test
void testAddAt() {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
assertThrows(
IllegalArgumentException.class,
() -> votersHistory.addAt(-1, new VoterSet(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true)))
() -> votersHistory.addAt(-1, new VoterSet(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true)))
);
assertEquals(staticVoterSet, votersHistory.lastValue());
@ -90,7 +90,7 @@ final public class VoterSetHistoryTest {
void testAddAtNonOverlapping() {
VoterSetHistory votersHistory = new VoterSetHistory(Optional.empty());
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet voterSet = new VoterSet(new HashMap<>(voterMap));
// Add a starting voter to the history
@ -122,7 +122,7 @@ final public class VoterSetHistoryTest {
@Test
void testNonoverlappingFromStaticVoterSet() {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.empty());
@ -137,7 +137,7 @@ final public class VoterSetHistoryTest {
@Test
void testTruncateTo() {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
@ -163,7 +163,7 @@ final public class VoterSetHistoryTest {
@Test
void testTrimPrefixTo() {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));
@ -196,7 +196,7 @@ final public class VoterSetHistoryTest {
@Test
void testClear() {
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true);
Map<Integer, VoterSet.VoterNode> voterMap = VoterSetTest.voterMap(IntStream.of(1, 2, 3), true);
VoterSet staticVoterSet = new VoterSet(new HashMap<>(voterMap));
VoterSetHistory votersHistory = new VoterSetHistory(Optional.of(staticVoterSet));

View File

@ -18,7 +18,6 @@ package org.apache.kafka.raft.internals;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@ -26,8 +25,13 @@ import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.feature.SupportedVersionRange;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.utils.Utils;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@ -41,22 +45,45 @@ final public class VoterSetTest {
}
// NOTE(review): this span looks like the removed testVoterAddress() lines followed by the added
// testVoterNode() lines of one diff hunk; both method headers share the single closing brace
// at the end, so only one of the two variants can exist in compilable source.
@Test
void testVoterAddress() {
// Earlier variant: looks up an InetSocketAddress by voter id and a plain-string listener name.
VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true));
assertEquals(Optional.of(new InetSocketAddress("replica-1", 1234)), voterSet.voterAddress(1, "LISTENER"));
assertEquals(Optional.empty(), voterSet.voterAddress(1, "MISSING"));
assertEquals(Optional.empty(), voterSet.voterAddress(4, "LISTENER"));
// Later variant: looks up a Node by voter id and a ListenerName.
void testVoterNode() {
VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true));
// Voter 1 on the default listener is expected at host "replica-1", port 1234.
assertEquals(
Optional.of(new Node(1, "replica-1", 1234)),
voterSet.voterNode(1, DEFAULT_LISTENER_NAME)
);
// A listener name other than the default one, or an id (4) outside the 1..3 voter set,
// is expected to produce an empty Optional.
assertEquals(Optional.empty(), voterSet.voterNode(1, ListenerName.normalised("MISSING")));
assertEquals(Optional.empty(), voterSet.voterNode(4, DEFAULT_LISTENER_NAME));
}
/**
 * Checks the bulk lookup {@code voterNodes(ids, listener)} on a voter set built from
 * replicas 1, 2 and 3: known ids on the default listener resolve to their nodes, while an
 * unregistered listener name or an id outside the voter set is rejected with
 * {@link IllegalArgumentException}.
 */
@Test
void testVoterNodes() {
    VoterSet voters = new VoterSet(voterMap(IntStream.of(1, 2, 3), true));
    Node firstReplica = new Node(1, "replica-1", 1234);
    Node secondReplica = new Node(2, "replica-2", 1234);

    // Both requested ids are voters and expose the default listener.
    assertEquals(
        Utils.mkSet(firstReplica, secondReplica),
        voters.voterNodes(Stream.of(1, 2), DEFAULT_LISTENER_NAME)
    );

    // Same ids, but a listener name that the test voters do not register.
    assertThrows(
        IllegalArgumentException.class,
        () -> voters.voterNodes(Stream.of(1, 2), ListenerName.normalised("MISSING"))
    );

    // Id 4 is not part of the voter set.
    assertThrows(
        IllegalArgumentException.class,
        () -> voters.voterNodes(Stream.of(1, 4), DEFAULT_LISTENER_NAME)
    );
}
// Verifies that voterIds() reports exactly the replica ids (1, 2, 3) used to build the voter set.
@Test
void testVoterIds() {
// NOTE(review): the two consecutive `voterSet` declarations look like the removed/added pair
// of a diff hunk (list argument replaced by an IntStream argument); compilable source can
// contain only one of them.
VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true));
VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true));
assertEquals(new HashSet<>(Arrays.asList(1, 2, 3)), voterSet.voterIds());
}
@Test
void testAddVoter() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), true);
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1, 2, 3), true);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertEquals(Optional.empty(), voterSet.addVoter(voterNode(1, true)));
@ -68,7 +95,7 @@ final public class VoterSetTest {
@Test
void testRemoveVoter() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), true);
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1, 2, 3), true);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertEquals(Optional.empty(), voterSet.removeVoter(ReplicaKey.of(4, Optional.empty())));
@ -83,7 +110,7 @@ final public class VoterSetTest {
@Test
void testIsVoterWithDirectoryId() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), true);
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1, 2, 3), true);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertTrue(voterSet.isVoter(aVoterMap.get(1).voterKey()));
@ -100,7 +127,7 @@ final public class VoterSetTest {
@Test
void testIsVoterWithoutDirectoryId() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2, 3), false);
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1, 2, 3), false);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertTrue(voterSet.isVoter(ReplicaKey.of(1, Optional.empty())));
@ -111,7 +138,7 @@ final public class VoterSetTest {
@Test
void testIsOnlyVoterInStandalone() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1), true);
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1), true);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertTrue(voterSet.isOnlyVoter(aVoterMap.get(1).voterKey()));
@ -125,7 +152,7 @@ final public class VoterSetTest {
@Test
void testIsOnlyVoterInNotStandalone() {
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(Arrays.asList(1, 2), true);
Map<Integer, VoterSet.VoterNode> aVoterMap = voterMap(IntStream.of(1, 2), true);
VoterSet voterSet = new VoterSet(new HashMap<>(aVoterMap));
assertFalse(voterSet.isOnlyVoter(aVoterMap.get(1).voterKey()));
@ -142,14 +169,14 @@ final public class VoterSetTest {
// Verifies that converting a voter set to a voters record (version 0) and reading it back with
// VoterSet.fromVotersRecord yields a voter set equal to the original.
@Test
void testRecordRoundTrip() {
// NOTE(review): the two consecutive `voterSet` declarations look like the removed/added pair
// of a diff hunk (list argument replaced by an IntStream argument); compilable source can
// contain only one of them.
VoterSet voterSet = new VoterSet(voterMap(Arrays.asList(1, 2, 3), true));
VoterSet voterSet = new VoterSet(voterMap(IntStream.of(1, 2, 3), true));
assertEquals(voterSet, VoterSet.fromVotersRecord(voterSet.toVotersRecord((short) 0)));
}
@Test
void testOverlappingMajority() {
Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(Arrays.asList(1, 2, 3), true);
Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(IntStream.of(1, 2, 3), true);
VoterSet startingVoterSet = voterSet(startingVoterMap);
VoterSet biggerVoterSet = startingVoterSet
@ -172,7 +199,7 @@ final public class VoterSetTest {
@Test
void testNonoverlappingMajority() {
Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(Arrays.asList(1, 2, 3, 4, 5), true);
Map<Integer, VoterSet.VoterNode> startingVoterMap = voterMap(IntStream.of(1, 2, 3, 4, 5), true);
VoterSet startingVoterSet = voterSet(startingVoterMap);
// Two additions don't have an overlapping majority
@ -217,20 +244,27 @@ final public class VoterSetTest {
);
}
public static final ListenerName DEFAULT_LISTENER_NAME = ListenerName.normalised("LISTENER");
/**
 * Test helper that builds a map from replica id to the voter node created by
 * {@code voterNode(id, withDirectoryId)} for every id in {@code replicas}.
 *
 * NOTE(review): several lines below appear twice in slightly different forms (parameter type,
 * {@code .stream()} vs {@code .boxed()}, qualified vs unqualified {@code voterNode} call).
 * They look like removed/added pairs of a diff hunk; compilable source keeps one of each.
 *
 * @param replicas the replica ids to create voters for
 * @param withDirectoryId forwarded unchanged to {@code voterNode(int, boolean)}
 * @return a map keyed by replica id
 */
public static Map<Integer, VoterSet.VoterNode> voterMap(
Collection<Integer> replicas,
IntStream replicas,
boolean withDirectoryId
) {
return replicas
.stream()
.boxed()
.collect(
Collectors.toMap(
// The replica id itself is the map key.
Function.identity(),
id -> VoterSetTest.voterNode(id, withDirectoryId)
id -> voterNode(id, withDirectoryId)
)
);
}
/**
 * Test helper that turns a stream of replica keys into a map from replica id to the voter
 * node produced by {@code voterNode(ReplicaKey)}.
 *
 * @param replicas the replica keys to create voters for
 * @return a map keyed by {@code ReplicaKey.id()}; as with any {@code Collectors.toMap}
 *         without a merge function, two keys with the same id cause an
 *         {@link IllegalStateException}
 */
public static Map<Integer, VoterSet.VoterNode> voterMap(Stream<ReplicaKey> replicas) {
    return replicas.collect(
        Collectors.toMap(
            key -> key.id(),
            key -> voterNode(key)
        )
    );
}
public static VoterSet.VoterNode voterNode(int id, boolean withDirectoryId) {
return voterNode(
ReplicaKey.of(
@ -244,7 +278,7 @@ final public class VoterSetTest {
return new VoterSet.VoterNode(
replicaKey,
Collections.singletonMap(
"LISTENER",
DEFAULT_LISTENER_NAME,
InetSocketAddress.createUnresolved(
String.format("replica-%d", replicaKey.id()),
1234
@ -257,4 +291,8 @@ final public class VoterSetTest {
/**
 * Test helper that wraps the given id-to-voter map in a new {@code VoterSet}.
 *
 * @param voters the voters, keyed by replica id, passed straight to the VoterSet constructor
 * @return the newly constructed voter set
 */
public static VoterSet voterSet(Map<Integer, VoterSet.VoterNode> voters) {
    VoterSet created = new VoterSet(voters);
    return created;
}
/**
 * Test helper that builds a {@code VoterSet} from a stream of replica keys by first mapping
 * them to voter nodes with {@code voterMap(Stream)}.
 *
 * @param voterKeys the replica keys of the voters to include
 * @return a voter set containing one voter per supplied key
 */
public static VoterSet voterSet(Stream<ReplicaKey> voterKeys) {
    Map<Integer, VoterSet.VoterNode> votersById = voterMap(voterKeys);
    return voterSet(votersById);
}
}

View File

@ -17,12 +17,11 @@
package org.apache.kafka.snapshot;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;
import org.apache.kafka.common.message.KRaftVersionRecord;
import org.apache.kafka.common.message.SnapshotFooterRecord;
import org.apache.kafka.common.message.SnapshotHeaderRecord;
@ -97,7 +96,7 @@ final class RecordsSnapshotWriterTest {
OffsetAndEpoch snapshotId = new OffsetAndEpoch(100, 10);
int maxBatchSize = 1024;
VoterSet voterSet = VoterSetTest.voterSet(
new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))
new HashMap<>(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true))
);
AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null);
RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()
@ -117,7 +116,7 @@ final class RecordsSnapshotWriterTest {
OffsetAndEpoch snapshotId = new OffsetAndEpoch(100, 10);
int maxBatchSize = 1024;
VoterSet voterSet = VoterSetTest.voterSet(
new HashMap<>(VoterSetTest.voterMap(Arrays.asList(1, 2, 3), true))
new HashMap<>(VoterSetTest.voterMap(IntStream.of(1, 2, 3), true))
);
AtomicReference<ByteBuffer> buffer = new AtomicReference<>(null);
RecordsSnapshotWriter.Builder builder = new RecordsSnapshotWriter.Builder()