mirror of https://github.com/apache/kafka.git
MINOR: Add ShareFetch quota session verification test (#20164)
### Background

As part of the KIP-932 implementation, ShareFetch requests need to integrate properly with Kafka's quota system. This requires that ShareFetch requests extract and pass the correct session information (principal, client address, client ID) to the quota managers, ensuring consistent quota enforcement between ShareFetch and traditional Fetch requests.

### Changes

This PR adds `testHandleShareFetchRequestQuotaTagsVerification()`, `testHandleShareAcknowledgeRequestQuotaTagsVerification()` and `testHandleShareFetchWithAcknowledgementQuotaTagsVerification()` to `KafkaApisTest`, which verify quota tag extraction and session handling for ShareFetch and ShareAcknowledge requests. The tests:

- Ensure ShareFetch/ShareAcknowledge requests are constructed with the correct client ID, principal, client address, and API key
- Verify the request context contains the expected session information
- Use `ArgumentCaptor` to capture the exact `Session` and `RequestChannel.Request` objects passed to the quota managers (a condensed sketch of this pattern follows the commit metadata below)
- Verify that `quotas.fetch.maybeRecordAndGetThrottleTimeMs()` and `quotas.request.maybeRecordAndGetThrottleTimeMs()` are called with the correct parameters when applicable
- Validate that the captured `RequestChannel.Request` object retains the correct request context information
- Ensure the client ID passed to the quota managers matches the test-defined value
- Verify that when acknowledgements are piggybacked on fetch requests, quotas are applied only once, not twice

Reviewers: Apoorv Mittal <apoorvmittal10@gmail.com>
parent 2658f25238 · commit 65a9337739
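In outline, all three tests follow the same capture-and-verify pattern: stub the quota managers with Mockito `ArgumentCaptor`s, invoke the handler, then assert that the captured `Session` and client ID carry the principal, client address and client ID of the incoming request. The following is a condensed sketch taken from the first test in the diff below; it assumes the `KafkaApisTest` fixtures already set up there (`quotas`, `kafkaApis`, `request`, `testClientAddress`, `testClientId`) rather than being a self-contained test.

```scala
// Condensed from testHandleShareFetchRequestQuotaTagsVerification (see diff below).
val sessionCaptor = ArgumentCaptor.forClass(classOf[Session])
val clientIdCaptor = ArgumentCaptor.forClass(classOf[String])

// Capture the Session and client ID that KafkaApis hands to the fetch quota manager.
when(quotas.fetch.maybeRecordAndGetThrottleTimeMs(
  sessionCaptor.capture(), clientIdCaptor.capture(), anyDouble, anyLong)).thenReturn(0)

kafkaApis.handleShareFetchRequest(request)

// The quota manager must see the principal, address and client ID of the ShareFetch request.
assertEquals("test-user", sessionCaptor.getValue.principal.getName)
assertEquals(testClientAddress, sessionCaptor.getValue.clientAddress)
assertEquals(testClientId, clientIdCaptor.getValue)
```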
@@ -2427,6 +2427,332 @@ class KafkaApisTest extends Logging {
     }
   }
 
+  @Test
+  def testHandleShareFetchRequestQuotaTagsVerification(): Unit = {
+    val topicName = "foo"
+    val topicId = Uuid.randomUuid()
+    val partitionIndex = 0
+    metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
+    addTopicToMetadataCache(topicName, 1, topicId = topicId)
+    val memberId: Uuid = Uuid.randomUuid()
+    val groupId = "group"
+
+    // Create test principal and client address to verify quota tags
+    val testPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user")
+    val testClientAddress = InetAddress.getByName("192.168.1.100")
+    val testClientId = "test-client-id"
+
+    // Mock share partition manager responses
+    val records = memoryRecords(10, 0)
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), anyInt(), anyInt(), any())).thenReturn(
+      CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareFetchResponseData.PartitionData](
+        new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)),
+        new ShareFetchResponseData.PartitionData()
+          .setErrorCode(Errors.NONE.code)
+          .setAcknowledgeErrorCode(Errors.NONE.code)
+          .setRecords(records)
+          .setAcquiredRecords(new util.ArrayList(util.List.of(
+            new ShareFetchResponseData.AcquiredRecords()
+              .setFirstOffset(0)
+              .setLastOffset(9)
+              .setDeliveryCount(1)
+          ))))))
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), any())).thenReturn(
+      new ShareSessionContext(new ShareRequestMetadata(memberId, 0), util.List.of(
+        new TopicIdPartition(topicId, partitionIndex, topicName)))
+    )
+
+    // Create argument captors to verify session information passed to quota managers
+    val sessionCaptorFetch = ArgumentCaptor.forClass(classOf[Session])
+    val clientIdCaptor = ArgumentCaptor.forClass(classOf[String])
+    val requestCaptor = ArgumentCaptor.forClass(classOf[RequestChannel.Request])
+
+    // Mock quota manager responses and capture arguments
+    when(quotas.fetch.maybeRecordAndGetThrottleTimeMs(
+      sessionCaptorFetch.capture(), clientIdCaptor.capture(), anyDouble, anyLong)).thenReturn(0)
+    when(quotas.request.maybeRecordAndGetThrottleTimeMs(
+      requestCaptor.capture(), anyLong)).thenReturn(0)
+
+    // Create ShareFetch request
+    val shareFetchRequestData = new ShareFetchRequestData()
+      .setGroupId(groupId)
+      .setMemberId(memberId.toString)
+      .setShareSessionEpoch(0)
+      .setTopics(new ShareFetchRequestData.FetchTopicCollection(util.List.of(new ShareFetchRequestData.FetchTopic()
+        .setTopicId(topicId)
+        .setPartitions(new ShareFetchRequestData.FetchPartitionCollection(util.List.of(
+          new ShareFetchRequestData.FetchPartition()
+            .setPartitionIndex(partitionIndex)
+        ).iterator))
+      ).iterator))
+
+    val shareFetchRequest = new ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
+
+    // Create request with custom principal and client address to test quota tags
+    val requestHeader = new RequestHeader(shareFetchRequest.apiKey, shareFetchRequest.version, testClientId, 0)
+    val request = buildRequest(shareFetchRequest, testPrincipal, testClientAddress,
+      ListenerName.forSecurityProtocol(SecurityProtocol.SSL), fromPrivilegedListener = false, Some(requestHeader), requestChannelMetrics)
+
+    // Test that the request itself contains the proper tags and information
+    assertEquals(testClientId, request.header.clientId)
+    assertEquals(testPrincipal, request.context.principal)
+    assertEquals(testClientAddress, request.context.clientAddress)
+    assertEquals(ApiKeys.SHARE_FETCH, request.header.apiKey)
+    assertEquals("1", request.context.connectionId)
+
+    kafkaApis = createKafkaApis()
+    kafkaApis.handleShareFetchRequest(request)
+    val response = verifyNoThrottling[ShareFetchResponse](request)
+
+    // Verify response is successful
+    val responseData = response.data()
+    assertEquals(Errors.NONE.code, responseData.errorCode)
+
+    // Verify that quota methods were called and captured session information
+    verify(quotas.fetch, times(1)).maybeRecordAndGetThrottleTimeMs(
+      any[Session](), anyString, anyDouble, anyLong)
+    verify(quotas.request, times(1)).maybeRecordAndGetThrottleTimeMs(
+      any[RequestChannel.Request](), anyLong)
+
+    // Verify the Session data passed to fetch quota manager is exactly what was defined in the test
+    val capturedSession = sessionCaptorFetch.getValue
+    assertNotNull(capturedSession)
+    assertNotNull(capturedSession.principal)
+    assertEquals(KafkaPrincipal.USER_TYPE, capturedSession.principal.getPrincipalType)
+    assertEquals("test-user", capturedSession.principal.getName)
+    assertEquals(testClientAddress, capturedSession.clientAddress)
+    assertEquals("test-user", capturedSession.sanitizedUser)
+
+    // Verify client ID passed to fetch quota manager matches what was defined
+    val capturedClientId = clientIdCaptor.getValue
+    assertEquals(testClientId, capturedClientId)
+
+    // Verify the Request data passed to request quota manager is exactly what was defined
+    val capturedRequest = requestCaptor.getValue
+    assertNotNull(capturedRequest)
+    assertEquals(testClientId, capturedRequest.header.clientId)
+    assertEquals(testPrincipal, capturedRequest.context.principal)
+    assertEquals(testClientAddress, capturedRequest.context.clientAddress)
+    assertEquals(ApiKeys.SHARE_FETCH, capturedRequest.header.apiKey)
+  }
+
+  @Test
+  def testHandleShareAcknowledgeRequestQuotaTagsVerification(): Unit = {
+    val topicName = "foo"
+    val topicId = Uuid.randomUuid()
+    val partitionIndex = 0
+    metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
+    addTopicToMetadataCache(topicName, 1, topicId = topicId)
+    val memberId: Uuid = Uuid.randomUuid()
+    val groupId = "group"
+
+    // Create test principal and client address to verify quota tags
+    val testPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user")
+    val testClientAddress = InetAddress.getByName("192.168.1.100")
+    val testClientId = "test-client-id"
+
+    // Mock share partition manager acknowledge response
+    when(sharePartitionManager.acknowledge(any(), any(), any())).thenReturn(
+      CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareAcknowledgeResponseData.PartitionData](
+        new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)),
+        new ShareAcknowledgeResponseData.PartitionData()
+          .setPartitionIndex(partitionIndex)
+          .setErrorCode(Errors.NONE.code))))
+
+    // Create argument captors to verify session information passed to quota managers
+    val requestCaptor = ArgumentCaptor.forClass(classOf[RequestChannel.Request])
+
+    // Mock quota manager responses and capture arguments
+    // For ShareAcknowledge, we only verify Request quota (not fetch quota)
+    when(quotas.request.maybeRecordAndGetThrottleTimeMs(
+      requestCaptor.capture(), anyLong)).thenReturn(0)
+
+    // Create ShareAcknowledge request
+    val shareAcknowledgeRequestData = new ShareAcknowledgeRequestData()
+      .setGroupId(groupId)
+      .setMemberId(memberId.toString)
+      .setShareSessionEpoch(1)
+      .setTopics(new ShareAcknowledgeRequestData.AcknowledgeTopicCollection(
+        util.List.of(new ShareAcknowledgeRequestData.AcknowledgeTopic()
+          .setTopicId(topicId)
+          .setPartitions(new ShareAcknowledgeRequestData.AcknowledgePartitionCollection(
+            util.List.of(new ShareAcknowledgeRequestData.AcknowledgePartition()
+              .setPartitionIndex(partitionIndex)
+              .setAcknowledgementBatches(util.List.of(
+                new ShareAcknowledgeRequestData.AcknowledgementBatch()
+                  .setFirstOffset(0)
+                  .setLastOffset(9)
+                  .setAcknowledgeTypes(util.List.of(1.toByte))
+              ))
+            ).iterator))
+        ).iterator))
+
+    val shareAcknowledgeRequest = new ShareAcknowledgeRequest.Builder(shareAcknowledgeRequestData).build(ApiKeys.SHARE_ACKNOWLEDGE.latestVersion)
+
+    // Create request with custom principal and client address to test quota tags
+    val requestHeader = new RequestHeader(shareAcknowledgeRequest.apiKey, shareAcknowledgeRequest.version, testClientId, 0)
+    val request = buildRequest(shareAcknowledgeRequest, testPrincipal, testClientAddress,
+      ListenerName.forSecurityProtocol(SecurityProtocol.SSL), fromPrivilegedListener = false, Some(requestHeader), requestChannelMetrics)
+
+    // Test that the request itself contains the proper tags and information
+    assertEquals(testClientId, request.header.clientId)
+    assertEquals(testPrincipal, request.context.principal)
+    assertEquals(testClientAddress, request.context.clientAddress)
+    assertEquals(ApiKeys.SHARE_ACKNOWLEDGE, request.header.apiKey)
+    assertEquals("1", request.context.connectionId)
+
+    kafkaApis = createKafkaApis()
+    kafkaApis.handleShareAcknowledgeRequest(request)
+    val response = verifyNoThrottling[ShareAcknowledgeResponse](request)
+
+    // Verify response is successful
+    val responseData = response.data()
+    assertEquals(Errors.NONE.code, responseData.errorCode)
+
+    // Verify that request quota method was called
+    verify(quotas.request, times(1)).maybeRecordAndGetThrottleTimeMs(
+      any[RequestChannel.Request](), anyLong)
+
+    // Verify that fetch quota method was NOT called (ShareAcknowledge only uses request quota)
+    verify(quotas.fetch, times(0)).maybeRecordAndGetThrottleTimeMs(
+      any[Session](), anyString, anyDouble, anyLong)
+
+    // Verify the Request data passed to request quota manager is exactly what was defined
+    val capturedRequest = requestCaptor.getValue
+    assertNotNull(capturedRequest)
+    assertEquals(testClientId, capturedRequest.header.clientId)
+    assertEquals(testPrincipal, capturedRequest.context.principal)
+    assertEquals(testClientAddress, capturedRequest.context.clientAddress)
+    assertEquals(ApiKeys.SHARE_ACKNOWLEDGE, capturedRequest.header.apiKey)
+  }
+
+  @Test
+  def testHandleShareFetchWithAcknowledgementQuotaTagsVerification(): Unit = {
+    val topicName = "foo"
+    val topicId = Uuid.randomUuid()
+    val partitionIndex = 0
+    metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
+    addTopicToMetadataCache(topicName, 1, topicId = topicId)
+    val memberId: Uuid = Uuid.randomUuid()
+    val groupId = "group"
+
+    // Create test principal and client address to verify quota tags
+    val testPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user")
+    val testClientAddress = InetAddress.getByName("192.168.1.100")
+    val testClientId = "test-client-id"
+
+    // Mock share partition manager responses for both fetch and acknowledge
+    val records = memoryRecords(10, 0)
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), anyInt(), anyInt(), any())).thenReturn(
+      CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareFetchResponseData.PartitionData](
+        new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)),
+        new ShareFetchResponseData.PartitionData()
+          .setErrorCode(Errors.NONE.code)
+          .setAcknowledgeErrorCode(Errors.NONE.code)
+          .setRecords(records)
+          .setAcquiredRecords(new util.ArrayList(util.List.of(
+            new ShareFetchResponseData.AcquiredRecords()
+              .setFirstOffset(0)
+              .setLastOffset(9)
+              .setDeliveryCount(1)
+          ))))))
+
+    when(sharePartitionManager.acknowledge(any(), any(), any())).thenReturn(
+      CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareAcknowledgeResponseData.PartitionData](
+        new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)),
+        new ShareAcknowledgeResponseData.PartitionData()
+          .setPartitionIndex(partitionIndex)
+          .setErrorCode(Errors.NONE.code))))
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), any())).thenReturn(
+      new ShareSessionContext(new ShareRequestMetadata(memberId, 1), util.List.of(
+        new TopicIdPartition(topicId, partitionIndex, topicName)))
+    )
+
+    // Create argument captors to verify session information passed to quota managers
+    val sessionCaptorFetch = ArgumentCaptor.forClass(classOf[Session])
+    val clientIdCaptor = ArgumentCaptor.forClass(classOf[String])
+    val requestCaptor = ArgumentCaptor.forClass(classOf[RequestChannel.Request])
+
+    // Mock quota manager responses and capture arguments
+    when(quotas.fetch.maybeRecordAndGetThrottleTimeMs(
+      sessionCaptorFetch.capture(), clientIdCaptor.capture(), anyDouble, anyLong)).thenReturn(0)
+    when(quotas.request.maybeRecordAndGetThrottleTimeMs(
+      requestCaptor.capture(), anyLong)).thenReturn(0)
+
+    // Create ShareFetch request with acknowledgement data
+    val shareFetchRequestData = new ShareFetchRequestData()
+      .setGroupId(groupId)
+      .setMemberId(memberId.toString)
+      .setShareSessionEpoch(1)
+      .setMaxWaitMs(100)
+      .setMinBytes(1)
+      .setMaxBytes(1000000)
+      .setTopics(new ShareFetchRequestData.FetchTopicCollection(util.List.of(new ShareFetchRequestData.FetchTopic()
+        .setTopicId(topicId)
+        .setPartitions(new ShareFetchRequestData.FetchPartitionCollection(util.List.of(
+          new ShareFetchRequestData.FetchPartition()
+            .setPartitionIndex(partitionIndex)
+            .setAcknowledgementBatches(util.List.of(
+              new ShareFetchRequestData.AcknowledgementBatch()
+                .setFirstOffset(0)
+                .setLastOffset(9)
+                .setAcknowledgeTypes(util.List.of(1.toByte))
+            ))
+        ).iterator))
+      ).iterator))
+
+    val shareFetchRequest = new ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
+
+    // Create request with custom principal and client address to test quota tags
+    val requestHeader = new RequestHeader(shareFetchRequest.apiKey, shareFetchRequest.version, testClientId, 0)
+    val request = buildRequest(shareFetchRequest, testPrincipal, testClientAddress,
+      ListenerName.forSecurityProtocol(SecurityProtocol.SSL), fromPrivilegedListener = false, Some(requestHeader), requestChannelMetrics)
+
+    // Test that the request itself contains the proper tags and information
+    assertEquals(testClientId, request.header.clientId)
+    assertEquals(testPrincipal, request.context.principal)
+    assertEquals(testClientAddress, request.context.clientAddress)
+    assertEquals(ApiKeys.SHARE_FETCH, request.header.apiKey)
+    assertEquals("1", request.context.connectionId)
+
+    kafkaApis = createKafkaApis()
+    kafkaApis.handleShareFetchRequest(request)
+    val response = verifyNoThrottling[ShareFetchResponse](request)
+
+    // Verify response is successful
+    val responseData = response.data()
+    assertEquals(Errors.NONE.code, responseData.errorCode)
+
+    // Verify that quota methods were called exactly once each (not twice despite having acknowledgements)
+    verify(quotas.fetch, times(1)).maybeRecordAndGetThrottleTimeMs(
+      any[Session](), anyString, anyDouble, anyLong)
+    verify(quotas.request, times(1)).maybeRecordAndGetThrottleTimeMs(
+      any[RequestChannel.Request](), anyLong)
+
+    // Verify the Session data passed to fetch quota manager is exactly what was defined in the test
+    val capturedSession = sessionCaptorFetch.getValue
+    assertNotNull(capturedSession)
+    assertNotNull(capturedSession.principal)
+    assertEquals(KafkaPrincipal.USER_TYPE, capturedSession.principal.getPrincipalType)
+    assertEquals("test-user", capturedSession.principal.getName)
+    assertEquals(testClientAddress, capturedSession.clientAddress)
+    assertEquals("test-user", capturedSession.sanitizedUser)
+
+    // Verify client ID passed to fetch quota manager matches what was defined
+    val capturedClientId = clientIdCaptor.getValue
+    assertEquals(testClientId, capturedClientId)
+
+    // Verify the Request data passed to request quota manager is exactly what was defined
+    val capturedRequest = requestCaptor.getValue
+    assertNotNull(capturedRequest)
+    assertEquals(testClientId, capturedRequest.header.clientId)
+    assertEquals(testPrincipal, capturedRequest.context.principal)
+    assertEquals(testClientAddress, capturedRequest.context.clientAddress)
+    assertEquals(ApiKeys.SHARE_FETCH, capturedRequest.header.apiKey)
+  }
+
   @Test
   def testProduceResponseContainsNewLeaderOnNotLeaderOrFollower(): Unit = {
     val topic = "topic"
@@ -9494,6 +9820,17 @@ class KafkaApisTest extends Logging {
                            fromPrivilegedListener: Boolean = false,
                            requestHeader: Option[RequestHeader] = None,
                            requestMetrics: RequestChannelMetrics = requestChannelMetrics): RequestChannel.Request = {
+    buildRequest(request, new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "Alice"), InetAddress.getLocalHost, listenerName,
+      fromPrivilegedListener, requestHeader, requestMetrics)
+  }
+
+  private def buildRequest(request: AbstractRequest,
+                           principal: KafkaPrincipal,
+                           clientAddress: InetAddress,
+                           listenerName: ListenerName,
+                           fromPrivilegedListener: Boolean,
+                           requestHeader: Option[RequestHeader],
+                           requestMetrics: RequestChannelMetrics): RequestChannel.Request = {
     val buffer = request.serializeWithHeader(
       requestHeader.getOrElse(new RequestHeader(request.apiKey, request.version, clientId, 0)))
 
@@ -9503,8 +9840,8 @@ class KafkaApisTest extends Logging {
     // and have a non KafkaPrincipal.ANONYMOUS principal. This test is done before the check
     // for forwarding because after forwarding the context will have a different context.
     // We validate the context authenticated failure case in other integration tests.
-    val context = new RequestContext(header, "1", InetAddress.getLocalHost, Optional.empty(),
-      new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "Alice"), listenerName, SecurityProtocol.SSL,
+    val context = new RequestContext(header, "1", clientAddress, Optional.empty(),
+      principal, listenerName, SecurityProtocol.SSL,
       ClientInformation.EMPTY, fromPrivilegedListener, Optional.of(kafkaPrincipalSerde))
     new RequestChannel.Request(processor = 1, context = context, startTimeNanos = 0, MemoryPool.NONE, buffer,
       requestMetrics, envelope = None)
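For reference, this is how the new tests above call the added `buildRequest` overload; the values are taken from the tests themselves, and this is a usage sketch of test plumbing, not additional production code.

```scala
// Build a request whose RequestContext carries a caller-supplied principal and client
// address, so the quota managers see those values instead of the previously hard-coded
// "Alice" / InetAddress.getLocalHost.
val request = buildRequest(
  shareFetchRequest,
  new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user"),
  InetAddress.getByName("192.168.1.100"),
  ListenerName.forSecurityProtocol(SecurityProtocol.SSL),
  fromPrivilegedListener = false,
  Some(requestHeader),
  requestChannelMetrics)
```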