MINOR: Add ShareFetch quota session verification test (#20164)
CI / build (push) Waiting to run Details

### Background
As part of KIP-932 implementation, ShareFetch requests need to properly
integrate with Kafka's quota system. This requires that ShareFetch
requests extract and pass the correct session information (Principal,
client address, client ID) to quota managers, ensuring consistent quota
enforcement between ShareFetch and traditional Fetch requests.

### Changes
This PR adds `testHandleShareFetchRequestQuotaTagsVerification()`,
`testHandleShareAcknowledgeRequestQuotaTagsVerification` and
`testHandleShareFetchWithAcknowledgementQuotaTagsVerification` to
`KafkaApisTest`, which provides verification of quota tag extraction and
session handling for ShareFetch and ShareAcknowledge requests.
   - Ensures ShareFetch/ShareAck requests are properly constructed with
the correct client ID, principal, client address, and API key
   - Verifies the request context contains the expected session
information
   - Uses `ArgumentCaptor` to capture the exact `Session` and
`RequestChannel.Request` objects passed to quota managers
   - Verifies both `quotas.fetch.maybeRecordAndGetThrottleTimeMs()` and
`quotas.request.maybeRecordAndGetThrottleTimeMs()` are called with
correct parameters as and when needed.
   - Validates that the captured `RequestChannel.Request` object
maintains the correct request context information
   - Ensures the client ID passed to quota managers matches the
test-defined value
   - Verifies that in case of Acks being piggybacked on the fetch
requests, the quotas are applied only once and not twice.

Reviewers: Apoorv Mittal <apoorvmittal10@gmail.com>
This commit is contained in:
Sanskar Jhajharia 2025-07-16 14:26:01 +05:30 committed by GitHub
parent 2658f25238
commit 65a9337739
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 339 additions and 2 deletions

View File

@ -2427,6 +2427,332 @@ class KafkaApisTest extends Logging {
}
}
@Test
def testHandleShareFetchRequestQuotaTagsVerification(): Unit = {
val topicName = "foo"
val topicId = Uuid.randomUuid()
val partitionIndex = 0
metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
addTopicToMetadataCache(topicName, 1, topicId = topicId)
val memberId: Uuid = Uuid.randomUuid()
val groupId = "group"
// Create test principal and client address to verify quota tags
val testPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user")
val testClientAddress = InetAddress.getByName("192.168.1.100")
val testClientId = "test-client-id"
// Mock share partition manager responses
val records = memoryRecords(10, 0)
when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), anyInt(), anyInt(), any())).thenReturn(
CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareFetchResponseData.PartitionData](
new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)),
new ShareFetchResponseData.PartitionData()
.setErrorCode(Errors.NONE.code)
.setAcknowledgeErrorCode(Errors.NONE.code)
.setRecords(records)
.setAcquiredRecords(new util.ArrayList(util.List.of(
new ShareFetchResponseData.AcquiredRecords()
.setFirstOffset(0)
.setLastOffset(9)
.setDeliveryCount(1)
))))))
when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), any())).thenReturn(
new ShareSessionContext(new ShareRequestMetadata(memberId, 0), util.List.of(
new TopicIdPartition(topicId, partitionIndex, topicName)))
)
// Create argument captors to verify session information passed to quota managers
val sessionCaptorFetch = ArgumentCaptor.forClass(classOf[Session])
val clientIdCaptor = ArgumentCaptor.forClass(classOf[String])
val requestCaptor = ArgumentCaptor.forClass(classOf[RequestChannel.Request])
// Mock quota manager responses and capture arguments
when(quotas.fetch.maybeRecordAndGetThrottleTimeMs(
sessionCaptorFetch.capture(), clientIdCaptor.capture(), anyDouble, anyLong)).thenReturn(0)
when(quotas.request.maybeRecordAndGetThrottleTimeMs(
requestCaptor.capture(), anyLong)).thenReturn(0)
// Create ShareFetch request
val shareFetchRequestData = new ShareFetchRequestData()
.setGroupId(groupId)
.setMemberId(memberId.toString)
.setShareSessionEpoch(0)
.setTopics(new ShareFetchRequestData.FetchTopicCollection(util.List.of(new ShareFetchRequestData.FetchTopic()
.setTopicId(topicId)
.setPartitions(new ShareFetchRequestData.FetchPartitionCollection(util.List.of(
new ShareFetchRequestData.FetchPartition()
.setPartitionIndex(partitionIndex)
).iterator))
).iterator))
val shareFetchRequest = new ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
// Create request with custom principal and client address to test quota tags
val requestHeader = new RequestHeader(shareFetchRequest.apiKey, shareFetchRequest.version, testClientId, 0)
val request = buildRequest(shareFetchRequest, testPrincipal, testClientAddress,
ListenerName.forSecurityProtocol(SecurityProtocol.SSL), fromPrivilegedListener = false, Some(requestHeader), requestChannelMetrics)
// Test that the request itself contains the proper tags and information
assertEquals(testClientId, request.header.clientId)
assertEquals(testPrincipal, request.context.principal)
assertEquals(testClientAddress, request.context.clientAddress)
assertEquals(ApiKeys.SHARE_FETCH, request.header.apiKey)
assertEquals("1", request.context.connectionId)
kafkaApis = createKafkaApis()
kafkaApis.handleShareFetchRequest(request)
val response = verifyNoThrottling[ShareFetchResponse](request)
// Verify response is successful
val responseData = response.data()
assertEquals(Errors.NONE.code, responseData.errorCode)
// Verify that quota methods were called and captured session information
verify(quotas.fetch, times(1)).maybeRecordAndGetThrottleTimeMs(
any[Session](), anyString, anyDouble, anyLong)
verify(quotas.request, times(1)).maybeRecordAndGetThrottleTimeMs(
any[RequestChannel.Request](), anyLong)
// Verify the Session data passed to fetch quota manager is exactly what was defined in the test
val capturedSession = sessionCaptorFetch.getValue
assertNotNull(capturedSession)
assertNotNull(capturedSession.principal)
assertEquals(KafkaPrincipal.USER_TYPE, capturedSession.principal.getPrincipalType)
assertEquals("test-user", capturedSession.principal.getName)
assertEquals(testClientAddress, capturedSession.clientAddress)
assertEquals("test-user", capturedSession.sanitizedUser)
// Verify client ID passed to fetch quota manager matches what was defined
val capturedClientId = clientIdCaptor.getValue
assertEquals(testClientId, capturedClientId)
// Verify the Request data passed to request quota manager is exactly what was defined
val capturedRequest = requestCaptor.getValue
assertNotNull(capturedRequest)
assertEquals(testClientId, capturedRequest.header.clientId)
assertEquals(testPrincipal, capturedRequest.context.principal)
assertEquals(testClientAddress, capturedRequest.context.clientAddress)
assertEquals(ApiKeys.SHARE_FETCH, capturedRequest.header.apiKey)
}
@Test
def testHandleShareAcknowledgeRequestQuotaTagsVerification(): Unit = {
val topicName = "foo"
val topicId = Uuid.randomUuid()
val partitionIndex = 0
metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
addTopicToMetadataCache(topicName, 1, topicId = topicId)
val memberId: Uuid = Uuid.randomUuid()
val groupId = "group"
// Create test principal and client address to verify quota tags
val testPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user")
val testClientAddress = InetAddress.getByName("192.168.1.100")
val testClientId = "test-client-id"
// Mock share partition manager acknowledge response
when(sharePartitionManager.acknowledge(any(), any(), any())).thenReturn(
CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareAcknowledgeResponseData.PartitionData](
new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)),
new ShareAcknowledgeResponseData.PartitionData()
.setPartitionIndex(partitionIndex)
.setErrorCode(Errors.NONE.code))))
// Create argument captors to verify session information passed to quota managers
val requestCaptor = ArgumentCaptor.forClass(classOf[RequestChannel.Request])
// Mock quota manager responses and capture arguments
// For ShareAcknowledge, we only verify Request quota (not fetch quota)
when(quotas.request.maybeRecordAndGetThrottleTimeMs(
requestCaptor.capture(), anyLong)).thenReturn(0)
// Create ShareAcknowledge request
val shareAcknowledgeRequestData = new ShareAcknowledgeRequestData()
.setGroupId(groupId)
.setMemberId(memberId.toString)
.setShareSessionEpoch(1)
.setTopics(new ShareAcknowledgeRequestData.AcknowledgeTopicCollection(
util.List.of(new ShareAcknowledgeRequestData.AcknowledgeTopic()
.setTopicId(topicId)
.setPartitions(new ShareAcknowledgeRequestData.AcknowledgePartitionCollection(
util.List.of(new ShareAcknowledgeRequestData.AcknowledgePartition()
.setPartitionIndex(partitionIndex)
.setAcknowledgementBatches(util.List.of(
new ShareAcknowledgeRequestData.AcknowledgementBatch()
.setFirstOffset(0)
.setLastOffset(9)
.setAcknowledgeTypes(util.List.of(1.toByte))
))
).iterator))
).iterator))
val shareAcknowledgeRequest = new ShareAcknowledgeRequest.Builder(shareAcknowledgeRequestData).build(ApiKeys.SHARE_ACKNOWLEDGE.latestVersion)
// Create request with custom principal and client address to test quota tags
val requestHeader = new RequestHeader(shareAcknowledgeRequest.apiKey, shareAcknowledgeRequest.version, testClientId, 0)
val request = buildRequest(shareAcknowledgeRequest, testPrincipal, testClientAddress,
ListenerName.forSecurityProtocol(SecurityProtocol.SSL), fromPrivilegedListener = false, Some(requestHeader), requestChannelMetrics)
// Test that the request itself contains the proper tags and information
assertEquals(testClientId, request.header.clientId)
assertEquals(testPrincipal, request.context.principal)
assertEquals(testClientAddress, request.context.clientAddress)
assertEquals(ApiKeys.SHARE_ACKNOWLEDGE, request.header.apiKey)
assertEquals("1", request.context.connectionId)
kafkaApis = createKafkaApis()
kafkaApis.handleShareAcknowledgeRequest(request)
val response = verifyNoThrottling[ShareAcknowledgeResponse](request)
// Verify response is successful
val responseData = response.data()
assertEquals(Errors.NONE.code, responseData.errorCode)
// Verify that request quota method was called
verify(quotas.request, times(1)).maybeRecordAndGetThrottleTimeMs(
any[RequestChannel.Request](), anyLong)
// Verify that fetch quota method was NOT called (ShareAcknowledge only uses request quota)
verify(quotas.fetch, times(0)).maybeRecordAndGetThrottleTimeMs(
any[Session](), anyString, anyDouble, anyLong)
// Verify the Request data passed to request quota manager is exactly what was defined
val capturedRequest = requestCaptor.getValue
assertNotNull(capturedRequest)
assertEquals(testClientId, capturedRequest.header.clientId)
assertEquals(testPrincipal, capturedRequest.context.principal)
assertEquals(testClientAddress, capturedRequest.context.clientAddress)
assertEquals(ApiKeys.SHARE_ACKNOWLEDGE, capturedRequest.header.apiKey)
}
@Test
def testHandleShareFetchWithAcknowledgementQuotaTagsVerification(): Unit = {
val topicName = "foo"
val topicId = Uuid.randomUuid()
val partitionIndex = 0
metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
addTopicToMetadataCache(topicName, 1, topicId = topicId)
val memberId: Uuid = Uuid.randomUuid()
val groupId = "group"
// Create test principal and client address to verify quota tags
val testPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user")
val testClientAddress = InetAddress.getByName("192.168.1.100")
val testClientId = "test-client-id"
// Mock share partition manager responses for both fetch and acknowledge
val records = memoryRecords(10, 0)
when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), anyInt(), anyInt(), any())).thenReturn(
CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareFetchResponseData.PartitionData](
new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)),
new ShareFetchResponseData.PartitionData()
.setErrorCode(Errors.NONE.code)
.setAcknowledgeErrorCode(Errors.NONE.code)
.setRecords(records)
.setAcquiredRecords(new util.ArrayList(util.List.of(
new ShareFetchResponseData.AcquiredRecords()
.setFirstOffset(0)
.setLastOffset(9)
.setDeliveryCount(1)
))))))
when(sharePartitionManager.acknowledge(any(), any(), any())).thenReturn(
CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareAcknowledgeResponseData.PartitionData](
new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)),
new ShareAcknowledgeResponseData.PartitionData()
.setPartitionIndex(partitionIndex)
.setErrorCode(Errors.NONE.code))))
when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), any())).thenReturn(
new ShareSessionContext(new ShareRequestMetadata(memberId, 1), util.List.of(
new TopicIdPartition(topicId, partitionIndex, topicName)))
)
// Create argument captors to verify session information passed to quota managers
val sessionCaptorFetch = ArgumentCaptor.forClass(classOf[Session])
val clientIdCaptor = ArgumentCaptor.forClass(classOf[String])
val requestCaptor = ArgumentCaptor.forClass(classOf[RequestChannel.Request])
// Mock quota manager responses and capture arguments
when(quotas.fetch.maybeRecordAndGetThrottleTimeMs(
sessionCaptorFetch.capture(), clientIdCaptor.capture(), anyDouble, anyLong)).thenReturn(0)
when(quotas.request.maybeRecordAndGetThrottleTimeMs(
requestCaptor.capture(), anyLong)).thenReturn(0)
// Create ShareFetch request with acknowledgement data
val shareFetchRequestData = new ShareFetchRequestData()
.setGroupId(groupId)
.setMemberId(memberId.toString)
.setShareSessionEpoch(1)
.setMaxWaitMs(100)
.setMinBytes(1)
.setMaxBytes(1000000)
.setTopics(new ShareFetchRequestData.FetchTopicCollection(util.List.of(new ShareFetchRequestData.FetchTopic()
.setTopicId(topicId)
.setPartitions(new ShareFetchRequestData.FetchPartitionCollection(util.List.of(
new ShareFetchRequestData.FetchPartition()
.setPartitionIndex(partitionIndex)
.setAcknowledgementBatches(util.List.of(
new ShareFetchRequestData.AcknowledgementBatch()
.setFirstOffset(0)
.setLastOffset(9)
.setAcknowledgeTypes(util.List.of(1.toByte))
))
).iterator))
).iterator))
val shareFetchRequest = new ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
// Create request with custom principal and client address to test quota tags
val requestHeader = new RequestHeader(shareFetchRequest.apiKey, shareFetchRequest.version, testClientId, 0)
val request = buildRequest(shareFetchRequest, testPrincipal, testClientAddress,
ListenerName.forSecurityProtocol(SecurityProtocol.SSL), fromPrivilegedListener = false, Some(requestHeader), requestChannelMetrics)
// Test that the request itself contains the proper tags and information
assertEquals(testClientId, request.header.clientId)
assertEquals(testPrincipal, request.context.principal)
assertEquals(testClientAddress, request.context.clientAddress)
assertEquals(ApiKeys.SHARE_FETCH, request.header.apiKey)
assertEquals("1", request.context.connectionId)
kafkaApis = createKafkaApis()
kafkaApis.handleShareFetchRequest(request)
val response = verifyNoThrottling[ShareFetchResponse](request)
// Verify response is successful
val responseData = response.data()
assertEquals(Errors.NONE.code, responseData.errorCode)
// Verify that quota methods were called exactly once each (not twice despite having acknowledgements)
verify(quotas.fetch, times(1)).maybeRecordAndGetThrottleTimeMs(
any[Session](), anyString, anyDouble, anyLong)
verify(quotas.request, times(1)).maybeRecordAndGetThrottleTimeMs(
any[RequestChannel.Request](), anyLong)
// Verify the Session data passed to fetch quota manager is exactly what was defined in the test
val capturedSession = sessionCaptorFetch.getValue
assertNotNull(capturedSession)
assertNotNull(capturedSession.principal)
assertEquals(KafkaPrincipal.USER_TYPE, capturedSession.principal.getPrincipalType)
assertEquals("test-user", capturedSession.principal.getName)
assertEquals(testClientAddress, capturedSession.clientAddress)
assertEquals("test-user", capturedSession.sanitizedUser)
// Verify client ID passed to fetch quota manager matches what was defined
val capturedClientId = clientIdCaptor.getValue
assertEquals(testClientId, capturedClientId)
// Verify the Request data passed to request quota manager is exactly what was defined
val capturedRequest = requestCaptor.getValue
assertNotNull(capturedRequest)
assertEquals(testClientId, capturedRequest.header.clientId)
assertEquals(testPrincipal, capturedRequest.context.principal)
assertEquals(testClientAddress, capturedRequest.context.clientAddress)
assertEquals(ApiKeys.SHARE_FETCH, capturedRequest.header.apiKey)
}
@Test
def testProduceResponseContainsNewLeaderOnNotLeaderOrFollower(): Unit = {
val topic = "topic"
@ -9494,6 +9820,17 @@ class KafkaApisTest extends Logging {
fromPrivilegedListener: Boolean = false,
requestHeader: Option[RequestHeader] = None,
requestMetrics: RequestChannelMetrics = requestChannelMetrics): RequestChannel.Request = {
buildRequest(request, new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "Alice"), InetAddress.getLocalHost, listenerName,
fromPrivilegedListener, requestHeader, requestMetrics)
}
private def buildRequest(request: AbstractRequest,
principal: KafkaPrincipal,
clientAddress: InetAddress,
listenerName: ListenerName,
fromPrivilegedListener: Boolean,
requestHeader: Option[RequestHeader],
requestMetrics: RequestChannelMetrics): RequestChannel.Request = {
val buffer = request.serializeWithHeader(
requestHeader.getOrElse(new RequestHeader(request.apiKey, request.version, clientId, 0)))
@ -9503,8 +9840,8 @@ class KafkaApisTest extends Logging {
// and have a non KafkaPrincipal.ANONYMOUS principal. This test is done before the check
// for forwarding because after forwarding the context will have a different context.
// We validate the context authenticated failure case in other integration tests.
val context = new RequestContext(header, "1", InetAddress.getLocalHost, Optional.empty(),
new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "Alice"), listenerName, SecurityProtocol.SSL,
val context = new RequestContext(header, "1", clientAddress, Optional.empty(),
principal, listenerName, SecurityProtocol.SSL,
ClientInformation.EMPTY, fromPrivilegedListener, Optional.of(kafkaPrincipalSerde))
new RequestChannel.Request(processor = 1, context = context, startTimeNanos = 0, MemoryPool.NONE, buffer,
requestMetrics, envelope = None)