KAFKA-15661: KIP-951: Server side changes (#14444)

This is the server side changes to populate the fields in KIP-951. On NOT_LEADER_OR_FOLLOWER errors in both FETCH and PRODUCE the new leader ID and epoch are retrieved from the local cache through ReplicaManager and included in the response, falling back to the metadata cache if they are unavailable there. The endpoint for the new leader is retrieved from the metadata cache. The new fields are all optional (tagged) and an IBP bump was required.

https://cwiki.apache.org/confluence/display/KAFKA/KIP-951%3A+Leader+discovery+optimisations+for+the+client

https://issues.apache.org/jira/browse/KAFKA-15661

Protocol changes: #14627

Testing
Benchmarking described here https://cwiki.apache.org/confluence/display/KAFKA/KIP-951%3A+Leader+discovery+optimisations+for+the+client#KIP951:Leaderdiscoveryoptimisationsfortheclient-BenchmarkResults
./gradlew core:test --tests kafka.server.KafkaApisTest

Reviewers: Justine Olshan <jolshan@confluent.io>, David Jacot <djacot@confluent.io>, Jason Gustafson <jason@confluent.io>, Fred Zheng <zhengyd2014@gmail.com>, Mayank Shekhar Narula <mayanks.narula@gmail.com>,  Yang Yang <yayang@uber.com>, David Mao <dmao@confluent.io>, Kirk True <ktrue@confluent.io>
This commit is contained in:
Crispin Bernier 2023-11-10 00:07:21 -05:00 committed by GitHub
parent 809694a9f6
commit f38b0d886c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 342 additions and 7 deletions

View File

@ -16,6 +16,7 @@
*/
package org.apache.kafka.common.requests;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.message.ProduceResponseData;
import org.apache.kafka.common.message.ProduceResponseData.LeaderIdAndEpoch;
@ -72,7 +73,7 @@ public class ProduceResponse extends AbstractResponse {
*/
@Deprecated
public ProduceResponse(Map<TopicPartition, PartitionResponse> responses) {
this(responses, DEFAULT_THROTTLE_TIME);
this(responses, DEFAULT_THROTTLE_TIME, Collections.emptyList());
}
/**
@ -83,10 +84,23 @@ public class ProduceResponse extends AbstractResponse {
*/
@Deprecated
public ProduceResponse(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs) {
this(toData(responses, throttleTimeMs));
this(toData(responses, throttleTimeMs, Collections.emptyList()));
}
private static ProduceResponseData toData(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs) {
/**
* Constructor for the latest version
* This is deprecated in favor of using the ProduceResponseData constructor, KafkaApis should switch to that
* in KAFKA-10730
* @param responses Produced data grouped by topic-partition
* @param throttleTimeMs Time in milliseconds the response was throttled
* @param nodeEndpoints List of node endpoints
*/
@Deprecated
public ProduceResponse(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs, List<Node> nodeEndpoints) {
this(toData(responses, throttleTimeMs, nodeEndpoints));
}
private static ProduceResponseData toData(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs, List<Node> nodeEndpoints) {
ProduceResponseData data = new ProduceResponseData().setThrottleTimeMs(throttleTimeMs);
responses.forEach((tp, response) -> {
ProduceResponseData.TopicProduceResponse tpr = data.responses().find(tp.topic());
@ -110,6 +124,12 @@ public class ProduceResponse extends AbstractResponse {
.setBatchIndexErrorMessage(e.message))
.collect(Collectors.toList())));
});
nodeEndpoints.forEach(endpoint -> data.nodeEndpoints()
.add(new ProduceResponseData.NodeEndpoint()
.setNodeId(endpoint.id())
.setHost(endpoint.host())
.setPort(endpoint.port())
.setRack(endpoint.rack())));
return data;
}

View File

@ -562,6 +562,23 @@ class KafkaApis(val requestChannel: RequestChannel,
}
}
case class LeaderNode(leaderId: Int, leaderEpoch: Int, node: Option[Node])
private def getCurrentLeader(tp: TopicPartition, ln: ListenerName): LeaderNode = {
val partitionInfoOrError = replicaManager.getPartitionOrError(tp)
val (leaderId, leaderEpoch) = partitionInfoOrError match {
case Right(x) =>
(x.leaderReplicaIdOpt.getOrElse(-1), x.getLeaderEpoch)
case Left(x) =>
debug(s"Unable to retrieve local leaderId and Epoch with error $x, falling back to metadata cache")
metadataCache.getPartitionInfo(tp.topic, tp.partition) match {
case Some(pinfo) => (pinfo.leader(), pinfo.leaderEpoch())
case None => (-1, -1)
}
}
LeaderNode(leaderId, leaderEpoch, metadataCache.getAliveBrokerNode(leaderId, ln))
}
/**
* Handle a produce request
*/
@ -614,6 +631,7 @@ class KafkaApis(val requestChannel: RequestChannel,
val mergedResponseStatus = responseStatus ++ unauthorizedTopicResponses ++ nonExistingTopicResponses ++ invalidRequestResponses
var errorInResponse = false
val nodeEndpoints = new mutable.HashMap[Int, Node]
mergedResponseStatus.forKeyValue { (topicPartition, status) =>
if (status.error != Errors.NONE) {
errorInResponse = true
@ -622,6 +640,20 @@ class KafkaApis(val requestChannel: RequestChannel,
request.header.clientId,
topicPartition,
status.error.exceptionName))
if (request.header.apiVersion >= 10) {
status.error match {
case Errors.NOT_LEADER_OR_FOLLOWER =>
val leaderNode = getCurrentLeader(topicPartition, request.context.listenerName)
leaderNode.node.foreach { node =>
nodeEndpoints.put(node.id(), node)
}
status.currentLeader
.setLeaderId(leaderNode.leaderId)
.setLeaderEpoch(leaderNode.leaderEpoch)
case _ =>
}
}
}
}
@ -665,7 +697,7 @@ class KafkaApis(val requestChannel: RequestChannel,
requestHelper.sendNoOpResponseExemptThrottle(request)
}
} else {
requestChannel.sendResponse(request, new ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs), None)
requestChannel.sendResponse(request, new ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs, nodeEndpoints.values.toList.asJava), None)
}
}
@ -843,6 +875,7 @@ class KafkaApis(val requestChannel: RequestChannel,
.setRecords(unconvertedRecords)
.setPreferredReadReplica(partitionData.preferredReadReplica)
.setDivergingEpoch(partitionData.divergingEpoch)
.setCurrentLeader(partitionData.currentLeader())
}
}
}
@ -851,6 +884,7 @@ class KafkaApis(val requestChannel: RequestChannel,
def processResponseCallback(responsePartitionData: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
val partitions = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
val reassigningPartitions = mutable.Set[TopicIdPartition]()
val nodeEndpoints = new mutable.HashMap[Int, Node]
responsePartitionData.foreach { case (tp, data) =>
val abortedTransactions = data.abortedTransactions.orElse(null)
val lastStableOffset: Long = data.lastStableOffset.orElse(FetchResponse.INVALID_LAST_STABLE_OFFSET)
@ -864,6 +898,21 @@ class KafkaApis(val requestChannel: RequestChannel,
.setAbortedTransactions(abortedTransactions)
.setRecords(data.records)
.setPreferredReadReplica(data.preferredReadReplica.orElse(FetchResponse.INVALID_PREFERRED_REPLICA_ID))
if (versionId >= 16) {
data.error match {
case Errors.NOT_LEADER_OR_FOLLOWER | Errors.FENCED_LEADER_EPOCH =>
val leaderNode = getCurrentLeader(tp.topicPartition(), request.context.listenerName)
leaderNode.node.foreach { node =>
nodeEndpoints.put(node.id(), node)
}
partitionData.currentLeader()
.setLeaderId(leaderNode.leaderId)
.setLeaderEpoch(leaderNode.leaderEpoch)
case _ =>
}
}
data.divergingEpoch.ifPresent(partitionData.setDivergingEpoch(_))
partitions.put(tp, partitionData)
}
@ -887,7 +936,7 @@ class KafkaApis(val requestChannel: RequestChannel,
// Prepare fetch response from converted data
val response =
FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, unconvertedFetchResponse.sessionId, convertedData)
FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, unconvertedFetchResponse.sessionId, convertedData, nodeEndpoints.values.toList.asJava)
// record the bytes out metrics only when the response is being sent
response.data.responses.forEach { topicResponse =>
topicResponse.partitions.forEach { data =>

View File

@ -24,9 +24,10 @@ import java.util.Arrays.asList
import java.util.concurrent.{CompletableFuture, TimeUnit}
import java.util.{Collections, Optional, OptionalInt, OptionalLong, Properties}
import kafka.api.LeaderAndIsr
import kafka.cluster.Broker
import kafka.cluster.{Broker, Partition}
import kafka.controller.{ControllerContext, KafkaController}
import kafka.coordinator.transaction.{InitProducerIdResult, TransactionCoordinator}
import kafka.log.UnifiedLog
import kafka.metrics.ClientMetricsTestUtils
import kafka.network.{RequestChannel, RequestMetrics}
import kafka.server.QuotaFactory.QuotaManagers
@ -98,7 +99,7 @@ import org.apache.kafka.coordinator.group.GroupCoordinator
import org.apache.kafka.server.common.{Features, MetadataVersion}
import org.apache.kafka.server.common.MetadataVersion.{IBP_0_10_2_IV0, IBP_2_2_IV1}
import org.apache.kafka.server.util.MockTime
import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchParams, FetchPartitionData}
import org.apache.kafka.storage.internals.log.{AppendOrigin, FetchParams, FetchPartitionData, LogConfig}
class KafkaApisTest {
private val requestChannel: RequestChannel = mock(classOf[RequestChannel])
@ -2475,6 +2476,204 @@ class KafkaApisTest {
}
}
@Test
def testProduceResponseContainsNewLeaderOnNotLeaderOrFollower(): Unit = {
val topic = "topic"
addTopicToMetadataCache(topic, numPartitions = 2, numBrokers = 3)
for (version <- 10 to ApiKeys.PRODUCE.latestVersion) {
reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
val responseCallback: ArgumentCaptor[Map[TopicPartition, PartitionResponse] => Unit] = ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit])
val tp = new TopicPartition(topic, 0)
val partition = mock(classOf[Partition])
val newLeaderId = 2
val newLeaderEpoch = 5
val produceRequest = ProduceRequest.forCurrentMagic(new ProduceRequestData()
.setTopicData(new ProduceRequestData.TopicProduceDataCollection(
Collections.singletonList(new ProduceRequestData.TopicProduceData()
.setName(tp.topic).setPartitionData(Collections.singletonList(
new ProduceRequestData.PartitionProduceData()
.setIndex(tp.partition)
.setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("test".getBytes))))))
.iterator))
.setAcks(1.toShort)
.setTimeoutMs(5000))
.build(version.toShort)
val request = buildRequest(produceRequest)
when(replicaManager.appendRecords(anyLong,
anyShort,
ArgumentMatchers.eq(false),
ArgumentMatchers.eq(AppendOrigin.CLIENT),
any(),
responseCallback.capture(),
any(),
any(),
any(),
any(),
any())
).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new PartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER))))
when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => Right(partition))
when(partition.leaderReplicaIdOpt).thenAnswer(_ => Some(newLeaderId))
when(partition.getLeaderEpoch).thenAnswer(_ => newLeaderEpoch)
when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
any[Long])).thenReturn(0)
when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
createKafkaApis().handleProduceRequest(request, RequestLocal.withThreadConfinedCaching)
val response = verifyNoThrottling[ProduceResponse](request)
assertEquals(1, response.data.responses.size)
val topicProduceResponse = response.data.responses.asScala.head
assertEquals(1, topicProduceResponse.partitionResponses.size)
val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head
assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, Errors.forCode(partitionProduceResponse.errorCode))
assertEquals(newLeaderId, partitionProduceResponse.currentLeader.leaderId())
assertEquals(newLeaderEpoch, partitionProduceResponse.currentLeader.leaderEpoch())
assertEquals(1, response.data.nodeEndpoints.size)
val node = response.data.nodeEndpoints.asScala.head
assertEquals(2, node.nodeId)
assertEquals("broker2", node.host)
}
}
@Test
def testProduceResponseReplicaManagerLookupErrorOnNotLeaderOrFollower(): Unit = {
val topic = "topic"
addTopicToMetadataCache(topic, numPartitions = 2, numBrokers = 3)
for (version <- 10 to ApiKeys.PRODUCE.latestVersion) {
reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
val responseCallback: ArgumentCaptor[Map[TopicPartition, PartitionResponse] => Unit] = ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit])
val tp = new TopicPartition(topic, 0)
val produceRequest = ProduceRequest.forCurrentMagic(new ProduceRequestData()
.setTopicData(new ProduceRequestData.TopicProduceDataCollection(
Collections.singletonList(new ProduceRequestData.TopicProduceData()
.setName(tp.topic).setPartitionData(Collections.singletonList(
new ProduceRequestData.PartitionProduceData()
.setIndex(tp.partition)
.setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("test".getBytes))))))
.iterator))
.setAcks(1.toShort)
.setTimeoutMs(5000))
.build(version.toShort)
val request = buildRequest(produceRequest)
when(replicaManager.appendRecords(anyLong,
anyShort,
ArgumentMatchers.eq(false),
ArgumentMatchers.eq(AppendOrigin.CLIENT),
any(),
responseCallback.capture(),
any(),
any(),
any(),
any(),
any())
).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new PartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER))))
when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => Left(Errors.UNKNOWN_TOPIC_OR_PARTITION))
when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
any[Long])).thenReturn(0)
when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
createKafkaApis().handleProduceRequest(request, RequestLocal.withThreadConfinedCaching)
val response = verifyNoThrottling[ProduceResponse](request)
assertEquals(1, response.data.responses.size)
val topicProduceResponse = response.data.responses.asScala.head
assertEquals(1, topicProduceResponse.partitionResponses.size)
val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head
assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, Errors.forCode(partitionProduceResponse.errorCode))
// LeaderId and epoch should be the same values inserted into the metadata cache
assertEquals(0, partitionProduceResponse.currentLeader.leaderId())
assertEquals(1, partitionProduceResponse.currentLeader.leaderEpoch())
assertEquals(1, response.data.nodeEndpoints.size)
val node = response.data.nodeEndpoints.asScala.head
assertEquals(0, node.nodeId)
assertEquals("broker0", node.host)
}
}
@Test
def testProduceResponseMetadataLookupErrorOnNotLeaderOrFollower(): Unit = {
val topic = "topic"
metadataCache = mock(classOf[ZkMetadataCache])
for (version <- 10 to ApiKeys.PRODUCE.latestVersion) {
reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
val responseCallback: ArgumentCaptor[Map[TopicPartition, PartitionResponse] => Unit] = ArgumentCaptor.forClass(classOf[Map[TopicPartition, PartitionResponse] => Unit])
val tp = new TopicPartition(topic, 0)
val produceRequest = ProduceRequest.forCurrentMagic(new ProduceRequestData()
.setTopicData(new ProduceRequestData.TopicProduceDataCollection(
Collections.singletonList(new ProduceRequestData.TopicProduceData()
.setName(tp.topic).setPartitionData(Collections.singletonList(
new ProduceRequestData.PartitionProduceData()
.setIndex(tp.partition)
.setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("test".getBytes))))))
.iterator))
.setAcks(1.toShort)
.setTimeoutMs(5000))
.build(version.toShort)
val request = buildRequest(produceRequest)
when(replicaManager.appendRecords(anyLong,
anyShort,
ArgumentMatchers.eq(false),
ArgumentMatchers.eq(AppendOrigin.CLIENT),
any(),
responseCallback.capture(),
any(),
any(),
any(),
any(),
any())
).thenAnswer(_ => responseCallback.getValue.apply(Map(tp -> new PartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER))))
when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => Left(Errors.UNKNOWN_TOPIC_OR_PARTITION))
when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
any[Long])).thenReturn(0)
when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
when(metadataCache.contains(tp)).thenAnswer(_ => true)
when(metadataCache.getPartitionInfo(tp.topic(), tp.partition())).thenAnswer(_ => Option.empty)
when(metadataCache.getAliveBrokerNode(any(), any())).thenReturn(Option.empty)
createKafkaApis().handleProduceRequest(request, RequestLocal.withThreadConfinedCaching)
val response = verifyNoThrottling[ProduceResponse](request)
assertEquals(1, response.data.responses.size)
val topicProduceResponse = response.data.responses.asScala.head
assertEquals(1, topicProduceResponse.partitionResponses.size)
val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head
assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, Errors.forCode(partitionProduceResponse.errorCode))
assertEquals(-1, partitionProduceResponse.currentLeader.leaderId())
assertEquals(-1, partitionProduceResponse.currentLeader.leaderEpoch())
assertEquals(0, response.data.nodeEndpoints.size)
}
}
@Test
def testTransactionalParametersSetCorrectly(): Unit = {
val topic = "topic"
@ -3786,6 +3985,73 @@ class KafkaApisTest {
assertEquals(MemoryRecords.EMPTY, FetchResponse.recordsOrFail(partitionData))
}
@Test
def testFetchResponseContainsNewLeaderOnNotLeaderOrFollower(): Unit = {
val topicId = Uuid.randomUuid()
val tidp = new TopicIdPartition(topicId, new TopicPartition("foo", 0))
val tp = tidp.topicPartition
addTopicToMetadataCache(tp.topic, numPartitions = 1, numBrokers = 3, topicId)
when(replicaManager.getLogConfig(ArgumentMatchers.eq(tp))).thenReturn(Some(LogConfig.fromProps(
Collections.emptyMap(),
new Properties()
)))
val partition = mock(classOf[Partition])
val newLeaderId = 2
val newLeaderEpoch = 5
when(replicaManager.getPartitionOrError(tp)).thenAnswer(_ => Right(partition))
when(partition.leaderReplicaIdOpt).thenAnswer(_ => Some(newLeaderId))
when(partition.getLeaderEpoch).thenAnswer(_ => newLeaderEpoch)
when(replicaManager.fetchMessages(
any[FetchParams],
any[Seq[(TopicIdPartition, FetchRequest.PartitionData)]],
any[ReplicaQuota],
any[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]()
)).thenAnswer(invocation => {
val callback = invocation.getArgument(3).asInstanceOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]
callback(Seq(tidp -> new FetchPartitionData(Errors.NOT_LEADER_OR_FOLLOWER, UnifiedLog.UnknownOffset, UnifiedLog.UnknownOffset, MemoryRecords.EMPTY,
Optional.empty(), OptionalLong.empty(), Optional.empty(), OptionalInt.empty(), false)))
})
val fetchData = Map(tidp -> new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 1000,
Optional.empty())).asJava
val fetchDataBuilder = Map(tp -> new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 1000,
Optional.empty())).asJava
val fetchMetadata = new JFetchMetadata(0, 0)
val fetchContext = new FullFetchContext(time, new FetchSessionCache(1000, 100),
fetchMetadata, fetchData, false, false)
when(fetchManager.newContext(
any[Short],
any[JFetchMetadata],
any[Boolean],
any[util.Map[TopicIdPartition, FetchRequest.PartitionData]],
any[util.List[TopicIdPartition]],
any[util.Map[Uuid, String]])).thenReturn(fetchContext)
when(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
any[RequestChannel.Request](), anyDouble, anyLong)).thenReturn(0)
val fetchRequest = new FetchRequest.Builder(16, 16, -1, -1, 100, 0, fetchDataBuilder)
.build()
val request = buildRequest(fetchRequest)
createKafkaApis().handleFetchRequest(request)
val response = verifyNoThrottling[FetchResponse](request)
val responseData = response.responseData(metadataCache.topicIdsToNames(), 16)
val partitionData = responseData.get(tp)
assertEquals(Errors.NOT_LEADER_OR_FOLLOWER.code, partitionData.errorCode)
assertEquals(newLeaderId, partitionData.currentLeader.leaderId())
assertEquals(newLeaderEpoch, partitionData.currentLeader.leaderEpoch())
val node = response.data.nodeEndpoints.asScala.head
assertEquals(2, node.nodeId)
assertEquals("broker2", node.host)
}
@ParameterizedTest
@ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP)
def testHandleJoinGroupRequest(version: Short): Unit = {