diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala index 9ed20fd7ac5..d7b35de63e5 100644 --- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala @@ -2427,6 +2427,332 @@ class KafkaApisTest extends Logging { } } + @Test + def testHandleShareFetchRequestQuotaTagsVerification(): Unit = { + val topicName = "foo" + val topicId = Uuid.randomUuid() + val partitionIndex = 0 + metadataCache = initializeMetadataCacheWithShareGroupsEnabled() + addTopicToMetadataCache(topicName, 1, topicId = topicId) + val memberId: Uuid = Uuid.randomUuid() + val groupId = "group" + + // Create test principal and client address to verify quota tags + val testPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user") + val testClientAddress = InetAddress.getByName("192.168.1.100") + val testClientId = "test-client-id" + + // Mock share partition manager responses + val records = memoryRecords(10, 0) + when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), anyInt(), anyInt(), any())).thenReturn( + CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareFetchResponseData.PartitionData]( + new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)), + new ShareFetchResponseData.PartitionData() + .setErrorCode(Errors.NONE.code) + .setAcknowledgeErrorCode(Errors.NONE.code) + .setRecords(records) + .setAcquiredRecords(new util.ArrayList(util.List.of( + new ShareFetchResponseData.AcquiredRecords() + .setFirstOffset(0) + .setLastOffset(9) + .setDeliveryCount(1) + )))))) + + when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), any())).thenReturn( + new ShareSessionContext(new ShareRequestMetadata(memberId, 0), util.List.of( + new TopicIdPartition(topicId, partitionIndex, topicName))) + ) + + // Create argument captors to verify session information passed to quota managers + val sessionCaptorFetch = ArgumentCaptor.forClass(classOf[Session]) + val clientIdCaptor = ArgumentCaptor.forClass(classOf[String]) + val requestCaptor = ArgumentCaptor.forClass(classOf[RequestChannel.Request]) + + // Mock quota manager responses and capture arguments + when(quotas.fetch.maybeRecordAndGetThrottleTimeMs( + sessionCaptorFetch.capture(), clientIdCaptor.capture(), anyDouble, anyLong)).thenReturn(0) + when(quotas.request.maybeRecordAndGetThrottleTimeMs( + requestCaptor.capture(), anyLong)).thenReturn(0) + + // Create ShareFetch request + val shareFetchRequestData = new ShareFetchRequestData() + .setGroupId(groupId) + .setMemberId(memberId.toString) + .setShareSessionEpoch(0) + .setTopics(new ShareFetchRequestData.FetchTopicCollection(util.List.of(new ShareFetchRequestData.FetchTopic() + .setTopicId(topicId) + .setPartitions(new ShareFetchRequestData.FetchPartitionCollection(util.List.of( + new ShareFetchRequestData.FetchPartition() + .setPartitionIndex(partitionIndex) + ).iterator)) + ).iterator)) + + val shareFetchRequest = new ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion) + + // Create request with custom principal and client address to test quota tags + val requestHeader = new RequestHeader(shareFetchRequest.apiKey, shareFetchRequest.version, testClientId, 0) + val request = buildRequest(shareFetchRequest, testPrincipal, testClientAddress, + ListenerName.forSecurityProtocol(SecurityProtocol.SSL), fromPrivilegedListener = false, Some(requestHeader), requestChannelMetrics) + + // Test that the request itself contains the proper tags and information + assertEquals(testClientId, request.header.clientId) + assertEquals(testPrincipal, request.context.principal) + assertEquals(testClientAddress, request.context.clientAddress) + assertEquals(ApiKeys.SHARE_FETCH, request.header.apiKey) + assertEquals("1", request.context.connectionId) + + kafkaApis = createKafkaApis() + kafkaApis.handleShareFetchRequest(request) + val response = verifyNoThrottling[ShareFetchResponse](request) + + // Verify response is successful + val responseData = response.data() + assertEquals(Errors.NONE.code, responseData.errorCode) + + // Verify that quota methods were called and captured session information + verify(quotas.fetch, times(1)).maybeRecordAndGetThrottleTimeMs( + any[Session](), anyString, anyDouble, anyLong) + verify(quotas.request, times(1)).maybeRecordAndGetThrottleTimeMs( + any[RequestChannel.Request](), anyLong) + + // Verify the Session data passed to fetch quota manager is exactly what was defined in the test + val capturedSession = sessionCaptorFetch.getValue + assertNotNull(capturedSession) + assertNotNull(capturedSession.principal) + assertEquals(KafkaPrincipal.USER_TYPE, capturedSession.principal.getPrincipalType) + assertEquals("test-user", capturedSession.principal.getName) + assertEquals(testClientAddress, capturedSession.clientAddress) + assertEquals("test-user", capturedSession.sanitizedUser) + + // Verify client ID passed to fetch quota manager matches what was defined + val capturedClientId = clientIdCaptor.getValue + assertEquals(testClientId, capturedClientId) + + // Verify the Request data passed to request quota manager is exactly what was defined + val capturedRequest = requestCaptor.getValue + assertNotNull(capturedRequest) + assertEquals(testClientId, capturedRequest.header.clientId) + assertEquals(testPrincipal, capturedRequest.context.principal) + assertEquals(testClientAddress, capturedRequest.context.clientAddress) + assertEquals(ApiKeys.SHARE_FETCH, capturedRequest.header.apiKey) + } + + @Test + def testHandleShareAcknowledgeRequestQuotaTagsVerification(): Unit = { + val topicName = "foo" + val topicId = Uuid.randomUuid() + val partitionIndex = 0 + metadataCache = initializeMetadataCacheWithShareGroupsEnabled() + addTopicToMetadataCache(topicName, 1, topicId = topicId) + val memberId: Uuid = Uuid.randomUuid() + val groupId = "group" + + // Create test principal and client address to verify quota tags + val testPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user") + val testClientAddress = InetAddress.getByName("192.168.1.100") + val testClientId = "test-client-id" + + // Mock share partition manager acknowledge response + when(sharePartitionManager.acknowledge(any(), any(), any())).thenReturn( + CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareAcknowledgeResponseData.PartitionData]( + new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)), + new ShareAcknowledgeResponseData.PartitionData() + .setPartitionIndex(partitionIndex) + .setErrorCode(Errors.NONE.code)))) + + // Create argument captors to verify session information passed to quota managers + val requestCaptor = ArgumentCaptor.forClass(classOf[RequestChannel.Request]) + + // Mock quota manager responses and capture arguments + // For ShareAcknowledge, we only verify Request quota (not fetch quota) + when(quotas.request.maybeRecordAndGetThrottleTimeMs( + requestCaptor.capture(), anyLong)).thenReturn(0) + + // Create ShareAcknowledge request + val shareAcknowledgeRequestData = new ShareAcknowledgeRequestData() + .setGroupId(groupId) + .setMemberId(memberId.toString) + .setShareSessionEpoch(1) + .setTopics(new ShareAcknowledgeRequestData.AcknowledgeTopicCollection( + util.List.of(new ShareAcknowledgeRequestData.AcknowledgeTopic() + .setTopicId(topicId) + .setPartitions(new ShareAcknowledgeRequestData.AcknowledgePartitionCollection( + util.List.of(new ShareAcknowledgeRequestData.AcknowledgePartition() + .setPartitionIndex(partitionIndex) + .setAcknowledgementBatches(util.List.of( + new ShareAcknowledgeRequestData.AcknowledgementBatch() + .setFirstOffset(0) + .setLastOffset(9) + .setAcknowledgeTypes(util.List.of(1.toByte)) + )) + ).iterator)) + ).iterator)) + + val shareAcknowledgeRequest = new ShareAcknowledgeRequest.Builder(shareAcknowledgeRequestData).build(ApiKeys.SHARE_ACKNOWLEDGE.latestVersion) + + // Create request with custom principal and client address to test quota tags + val requestHeader = new RequestHeader(shareAcknowledgeRequest.apiKey, shareAcknowledgeRequest.version, testClientId, 0) + val request = buildRequest(shareAcknowledgeRequest, testPrincipal, testClientAddress, + ListenerName.forSecurityProtocol(SecurityProtocol.SSL), fromPrivilegedListener = false, Some(requestHeader), requestChannelMetrics) + + // Test that the request itself contains the proper tags and information + assertEquals(testClientId, request.header.clientId) + assertEquals(testPrincipal, request.context.principal) + assertEquals(testClientAddress, request.context.clientAddress) + assertEquals(ApiKeys.SHARE_ACKNOWLEDGE, request.header.apiKey) + assertEquals("1", request.context.connectionId) + + kafkaApis = createKafkaApis() + kafkaApis.handleShareAcknowledgeRequest(request) + val response = verifyNoThrottling[ShareAcknowledgeResponse](request) + + // Verify response is successful + val responseData = response.data() + assertEquals(Errors.NONE.code, responseData.errorCode) + + // Verify that request quota method was called + verify(quotas.request, times(1)).maybeRecordAndGetThrottleTimeMs( + any[RequestChannel.Request](), anyLong) + + // Verify that fetch quota method was NOT called (ShareAcknowledge only uses request quota) + verify(quotas.fetch, times(0)).maybeRecordAndGetThrottleTimeMs( + any[Session](), anyString, anyDouble, anyLong) + + // Verify the Request data passed to request quota manager is exactly what was defined + val capturedRequest = requestCaptor.getValue + assertNotNull(capturedRequest) + assertEquals(testClientId, capturedRequest.header.clientId) + assertEquals(testPrincipal, capturedRequest.context.principal) + assertEquals(testClientAddress, capturedRequest.context.clientAddress) + assertEquals(ApiKeys.SHARE_ACKNOWLEDGE, capturedRequest.header.apiKey) + } + + @Test + def testHandleShareFetchWithAcknowledgementQuotaTagsVerification(): Unit = { + val topicName = "foo" + val topicId = Uuid.randomUuid() + val partitionIndex = 0 + metadataCache = initializeMetadataCacheWithShareGroupsEnabled() + addTopicToMetadataCache(topicName, 1, topicId = topicId) + val memberId: Uuid = Uuid.randomUuid() + val groupId = "group" + + // Create test principal and client address to verify quota tags + val testPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user") + val testClientAddress = InetAddress.getByName("192.168.1.100") + val testClientId = "test-client-id" + + // Mock share partition manager responses for both fetch and acknowledge + val records = memoryRecords(10, 0) + when(sharePartitionManager.fetchMessages(any(), any(), any(), anyInt(), anyInt(), anyInt(), any())).thenReturn( + CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareFetchResponseData.PartitionData]( + new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)), + new ShareFetchResponseData.PartitionData() + .setErrorCode(Errors.NONE.code) + .setAcknowledgeErrorCode(Errors.NONE.code) + .setRecords(records) + .setAcquiredRecords(new util.ArrayList(util.List.of( + new ShareFetchResponseData.AcquiredRecords() + .setFirstOffset(0) + .setLastOffset(9) + .setDeliveryCount(1) + )))))) + + when(sharePartitionManager.acknowledge(any(), any(), any())).thenReturn( + CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, ShareAcknowledgeResponseData.PartitionData]( + new TopicIdPartition(topicId, new TopicPartition(topicName, partitionIndex)), + new ShareAcknowledgeResponseData.PartitionData() + .setPartitionIndex(partitionIndex) + .setErrorCode(Errors.NONE.code)))) + + when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), any())).thenReturn( + new ShareSessionContext(new ShareRequestMetadata(memberId, 1), util.List.of( + new TopicIdPartition(topicId, partitionIndex, topicName))) + ) + + // Create argument captors to verify session information passed to quota managers + val sessionCaptorFetch = ArgumentCaptor.forClass(classOf[Session]) + val clientIdCaptor = ArgumentCaptor.forClass(classOf[String]) + val requestCaptor = ArgumentCaptor.forClass(classOf[RequestChannel.Request]) + + // Mock quota manager responses and capture arguments + when(quotas.fetch.maybeRecordAndGetThrottleTimeMs( + sessionCaptorFetch.capture(), clientIdCaptor.capture(), anyDouble, anyLong)).thenReturn(0) + when(quotas.request.maybeRecordAndGetThrottleTimeMs( + requestCaptor.capture(), anyLong)).thenReturn(0) + + // Create ShareFetch request with acknowledgement data + val shareFetchRequestData = new ShareFetchRequestData() + .setGroupId(groupId) + .setMemberId(memberId.toString) + .setShareSessionEpoch(1) + .setMaxWaitMs(100) + .setMinBytes(1) + .setMaxBytes(1000000) + .setTopics(new ShareFetchRequestData.FetchTopicCollection(util.List.of(new ShareFetchRequestData.FetchTopic() + .setTopicId(topicId) + .setPartitions(new ShareFetchRequestData.FetchPartitionCollection(util.List.of( + new ShareFetchRequestData.FetchPartition() + .setPartitionIndex(partitionIndex) + .setAcknowledgementBatches(util.List.of( + new ShareFetchRequestData.AcknowledgementBatch() + .setFirstOffset(0) + .setLastOffset(9) + .setAcknowledgeTypes(util.List.of(1.toByte)) + )) + ).iterator)) + ).iterator)) + + val shareFetchRequest = new ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion) + + // Create request with custom principal and client address to test quota tags + val requestHeader = new RequestHeader(shareFetchRequest.apiKey, shareFetchRequest.version, testClientId, 0) + val request = buildRequest(shareFetchRequest, testPrincipal, testClientAddress, + ListenerName.forSecurityProtocol(SecurityProtocol.SSL), fromPrivilegedListener = false, Some(requestHeader), requestChannelMetrics) + + // Test that the request itself contains the proper tags and information + assertEquals(testClientId, request.header.clientId) + assertEquals(testPrincipal, request.context.principal) + assertEquals(testClientAddress, request.context.clientAddress) + assertEquals(ApiKeys.SHARE_FETCH, request.header.apiKey) + assertEquals("1", request.context.connectionId) + + kafkaApis = createKafkaApis() + kafkaApis.handleShareFetchRequest(request) + val response = verifyNoThrottling[ShareFetchResponse](request) + + // Verify response is successful + val responseData = response.data() + assertEquals(Errors.NONE.code, responseData.errorCode) + + // Verify that quota methods were called exactly once each (not twice despite having acknowledgements) + verify(quotas.fetch, times(1)).maybeRecordAndGetThrottleTimeMs( + any[Session](), anyString, anyDouble, anyLong) + verify(quotas.request, times(1)).maybeRecordAndGetThrottleTimeMs( + any[RequestChannel.Request](), anyLong) + + // Verify the Session data passed to fetch quota manager is exactly what was defined in the test + val capturedSession = sessionCaptorFetch.getValue + assertNotNull(capturedSession) + assertNotNull(capturedSession.principal) + assertEquals(KafkaPrincipal.USER_TYPE, capturedSession.principal.getPrincipalType) + assertEquals("test-user", capturedSession.principal.getName) + assertEquals(testClientAddress, capturedSession.clientAddress) + assertEquals("test-user", capturedSession.sanitizedUser) + + // Verify client ID passed to fetch quota manager matches what was defined + val capturedClientId = clientIdCaptor.getValue + assertEquals(testClientId, capturedClientId) + + // Verify the Request data passed to request quota manager is exactly what was defined + val capturedRequest = requestCaptor.getValue + assertNotNull(capturedRequest) + assertEquals(testClientId, capturedRequest.header.clientId) + assertEquals(testPrincipal, capturedRequest.context.principal) + assertEquals(testClientAddress, capturedRequest.context.clientAddress) + assertEquals(ApiKeys.SHARE_FETCH, capturedRequest.header.apiKey) + } + @Test def testProduceResponseContainsNewLeaderOnNotLeaderOrFollower(): Unit = { val topic = "topic" @@ -9494,6 +9820,17 @@ class KafkaApisTest extends Logging { fromPrivilegedListener: Boolean = false, requestHeader: Option[RequestHeader] = None, requestMetrics: RequestChannelMetrics = requestChannelMetrics): RequestChannel.Request = { + buildRequest(request, new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "Alice"), InetAddress.getLocalHost, listenerName, + fromPrivilegedListener, requestHeader, requestMetrics) + } + + private def buildRequest(request: AbstractRequest, + principal: KafkaPrincipal, + clientAddress: InetAddress, + listenerName: ListenerName, + fromPrivilegedListener: Boolean, + requestHeader: Option[RequestHeader], + requestMetrics: RequestChannelMetrics): RequestChannel.Request = { val buffer = request.serializeWithHeader( requestHeader.getOrElse(new RequestHeader(request.apiKey, request.version, clientId, 0))) @@ -9503,8 +9840,8 @@ class KafkaApisTest extends Logging { // and have a non KafkaPrincipal.ANONYMOUS principal. This test is done before the check // for forwarding because after forwarding the context will have a different context. // We validate the context authenticated failure case in other integration tests. - val context = new RequestContext(header, "1", InetAddress.getLocalHost, Optional.empty(), - new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "Alice"), listenerName, SecurityProtocol.SSL, + val context = new RequestContext(header, "1", clientAddress, Optional.empty(), + principal, listenerName, SecurityProtocol.SSL, ClientInformation.EMPTY, fromPrivilegedListener, Optional.of(kafkaPrincipalSerde)) new RequestChannel.Request(processor = 1, context = context, startTimeNanos = 0, MemoryPool.NONE, buffer, requestMetrics, envelope = None)