From df93571f503033c6b21034a1c815f7b21da96959 Mon Sep 17 00:00:00 2001 From: Sushant Mahajan Date: Tue, 3 Jun 2025 15:56:38 +0530 Subject: [PATCH] KAFKA-19338: Error on read/write of uninitialized share part. (#19861) - Currently, read and write share state requests were allowed on uninitialized share partitions (share partitions on which initializeState has NOT been called). This should not be the case. - This PR addresses the concern by adding error checks on read and write. Other requests are allowed (initialize, readSummary, alter). - Refactored `ShareCoordinatorShardTest` to reduce redundancy and added some new tests. - Some request/response classes have also been reformatted. Reviewers: Andrew Schofield --- .../DeleteShareGroupStateRequest.java | 20 +- .../DeleteShareGroupStateResponse.java | 3 +- .../requests/ReadShareGroupStateRequest.java | 22 +- .../requests/ReadShareGroupStateResponse.java | 63 ++-- .../requests/WriteShareGroupStateRequest.java | 22 +- .../WriteShareGroupStateResponse.java | 45 ++- .../ShareFetchAcknowledgeRequestTest.scala | 112 +++++- .../share/ShareCoordinatorShard.java | 88 ++--- .../share/ShareCoordinatorShardTest.java | 324 +++++++++--------- 9 files changed, 400 insertions(+), 299 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DeleteShareGroupStateRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DeleteShareGroupStateRequest.java index 0e399c3303e..c15e76328e1 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/DeleteShareGroupStateRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/DeleteShareGroupStateRequest.java @@ -59,15 +59,15 @@ public class DeleteShareGroupStateRequest extends AbstractRequest { public DeleteShareGroupStateResponse getErrorResponse(int throttleTimeMs, Throwable e) { List results = new ArrayList<>(); data.topics().forEach( - topicResult -> results.add(new DeleteShareGroupStateResponseData.DeleteStateResult() - .setTopicId(topicResult.topicId()) - .setPartitions(topicResult.partitions().stream() - .map(partitionData -> new DeleteShareGroupStateResponseData.PartitionResult() - .setPartition(partitionData.partition()) - .setErrorCode(Errors.forException(e).code())) - .collect(Collectors.toList())))); + topicResult -> results.add(new DeleteShareGroupStateResponseData.DeleteStateResult() + .setTopicId(topicResult.topicId()) + .setPartitions(topicResult.partitions().stream() + .map(partitionData -> new DeleteShareGroupStateResponseData.PartitionResult() + .setPartition(partitionData.partition()) + .setErrorCode(Errors.forException(e).code())) + .collect(Collectors.toList())))); return new DeleteShareGroupStateResponse(new DeleteShareGroupStateResponseData() - .setResults(results)); + .setResults(results)); } @Override @@ -77,8 +77,8 @@ public class DeleteShareGroupStateRequest extends AbstractRequest { public static DeleteShareGroupStateRequest parse(Readable readable, short version) { return new DeleteShareGroupStateRequest( - new DeleteShareGroupStateRequestData(readable, version), - version + new DeleteShareGroupStateRequestData(readable, version), + version ); } } diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DeleteShareGroupStateResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DeleteShareGroupStateResponse.java index 45cda17697c..e7da3e048c4 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/DeleteShareGroupStateResponse.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/DeleteShareGroupStateResponse.java @@ -25,7 +25,6 @@ import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.Readable; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -104,7 +103,7 @@ public class DeleteShareGroupStateResponse extends AbstractResponse { public static DeleteShareGroupStateResponseData toErrorResponseData(Uuid topicId, int partitionId, Errors error, String errorMessage) { return new DeleteShareGroupStateResponseData().setResults( - Collections.singletonList(new DeleteShareGroupStateResponseData.DeleteStateResult() + List.of(new DeleteShareGroupStateResponseData.DeleteStateResult() .setTopicId(topicId) .setPartitions(List.of(new DeleteShareGroupStateResponseData.PartitionResult() .setPartition(partitionId) diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ReadShareGroupStateRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ReadShareGroupStateRequest.java index 920f189ce79..3637da2ca1b 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/ReadShareGroupStateRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/ReadShareGroupStateRequest.java @@ -59,16 +59,16 @@ public class ReadShareGroupStateRequest extends AbstractRequest { public ReadShareGroupStateResponse getErrorResponse(int throttleTimeMs, Throwable e) { List results = new ArrayList<>(); data.topics().forEach( - topicResult -> results.add(new ReadShareGroupStateResponseData.ReadStateResult() - .setTopicId(topicResult.topicId()) - .setPartitions(topicResult.partitions().stream() - .map(partitionData -> new ReadShareGroupStateResponseData.PartitionResult() - .setPartition(partitionData.partition()) - .setErrorCode(Errors.forException(e).code()) - .setErrorMessage(Errors.forException(e).message())) - .collect(Collectors.toList())))); + topicResult -> results.add(new ReadShareGroupStateResponseData.ReadStateResult() + .setTopicId(topicResult.topicId()) + .setPartitions(topicResult.partitions().stream() + .map(partitionData -> new ReadShareGroupStateResponseData.PartitionResult() + .setPartition(partitionData.partition()) + .setErrorCode(Errors.forException(e).code()) + .setErrorMessage(Errors.forException(e).message())) + .collect(Collectors.toList())))); return new ReadShareGroupStateResponse(new ReadShareGroupStateResponseData() - .setResults(results)); + .setResults(results)); } @Override @@ -78,8 +78,8 @@ public class ReadShareGroupStateRequest extends AbstractRequest { public static ReadShareGroupStateRequest parse(Readable readable, short version) { return new ReadShareGroupStateRequest( - new ReadShareGroupStateRequestData(readable, version), - version + new ReadShareGroupStateRequestData(readable, version), + version ); } } diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ReadShareGroupStateResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ReadShareGroupStateResponse.java index 6ee84c01992..2ab20e52e95 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/ReadShareGroupStateResponse.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/ReadShareGroupStateResponse.java @@ -25,7 +25,6 @@ import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.Readable; import java.util.ArrayList; -import java.util.Collections; import java.util.EnumMap; import java.util.List; import java.util.Map; @@ -47,9 +46,9 @@ public class ReadShareGroupStateResponse extends AbstractResponse { public Map errorCounts() { Map counts = new EnumMap<>(Errors.class); data.results().forEach( - result -> result.partitions().forEach( - partitionResult -> updateErrorCounts(counts, Errors.forCode(partitionResult.errorCode())) - ) + result -> result.partitions().forEach( + partitionResult -> updateErrorCounts(counts, Errors.forCode(partitionResult.errorCode())) + ) ); return counts; } @@ -66,52 +65,52 @@ public class ReadShareGroupStateResponse extends AbstractResponse { public static ReadShareGroupStateResponse parse(Readable readable, short version) { return new ReadShareGroupStateResponse( - new ReadShareGroupStateResponseData(readable, version) + new ReadShareGroupStateResponseData(readable, version) ); } public static ReadShareGroupStateResponseData toResponseData( - Uuid topicId, - int partition, - long startOffset, - int stateEpoch, - List stateBatches + Uuid topicId, + int partition, + long startOffset, + int stateEpoch, + List stateBatches ) { return new ReadShareGroupStateResponseData() - .setResults(Collections.singletonList( - new ReadShareGroupStateResponseData.ReadStateResult() - .setTopicId(topicId) - .setPartitions(Collections.singletonList( - new ReadShareGroupStateResponseData.PartitionResult() - .setPartition(partition) - .setStartOffset(startOffset) - .setStateEpoch(stateEpoch) - .setStateBatches(stateBatches) - )) - )); + .setResults(List.of( + new ReadShareGroupStateResponseData.ReadStateResult() + .setTopicId(topicId) + .setPartitions(List.of( + new ReadShareGroupStateResponseData.PartitionResult() + .setPartition(partition) + .setStartOffset(startOffset) + .setStateEpoch(stateEpoch) + .setStateBatches(stateBatches) + )) + )); } public static ReadShareGroupStateResponseData toErrorResponseData(Uuid topicId, int partitionId, Errors error, String errorMessage) { return new ReadShareGroupStateResponseData().setResults( - Collections.singletonList(new ReadShareGroupStateResponseData.ReadStateResult() - .setTopicId(topicId) - .setPartitions(Collections.singletonList(new ReadShareGroupStateResponseData.PartitionResult() - .setPartition(partitionId) - .setErrorCode(error.code()) - .setErrorMessage(errorMessage))))); + List.of(new ReadShareGroupStateResponseData.ReadStateResult() + .setTopicId(topicId) + .setPartitions(List.of(new ReadShareGroupStateResponseData.PartitionResult() + .setPartition(partitionId) + .setErrorCode(error.code()) + .setErrorMessage(errorMessage))))); } public static ReadShareGroupStateResponseData.PartitionResult toErrorResponsePartitionResult(int partitionId, Errors error, String errorMessage) { return new ReadShareGroupStateResponseData.PartitionResult() - .setPartition(partitionId) - .setErrorCode(error.code()) - .setErrorMessage(errorMessage); + .setPartition(partitionId) + .setErrorCode(error.code()) + .setErrorMessage(errorMessage); } public static ReadShareGroupStateResponseData.ReadStateResult toResponseReadStateResult(Uuid topicId, List partitionResults) { return new ReadShareGroupStateResponseData.ReadStateResult() - .setTopicId(topicId) - .setPartitions(partitionResults); + .setTopicId(topicId) + .setPartitions(partitionResults); } public static ReadShareGroupStateResponseData toGlobalErrorResponse(ReadShareGroupStateRequestData request, Errors error) { diff --git a/clients/src/main/java/org/apache/kafka/common/requests/WriteShareGroupStateRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/WriteShareGroupStateRequest.java index 4d8417c135e..35619791540 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/WriteShareGroupStateRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/WriteShareGroupStateRequest.java @@ -59,16 +59,16 @@ public class WriteShareGroupStateRequest extends AbstractRequest { public WriteShareGroupStateResponse getErrorResponse(int throttleTimeMs, Throwable e) { List results = new ArrayList<>(); data.topics().forEach( - topicResult -> results.add(new WriteShareGroupStateResponseData.WriteStateResult() - .setTopicId(topicResult.topicId()) - .setPartitions(topicResult.partitions().stream() - .map(partitionData -> new WriteShareGroupStateResponseData.PartitionResult() - .setPartition(partitionData.partition()) - .setErrorCode(Errors.forException(e).code()) - .setErrorMessage(Errors.forException(e).message())) - .collect(Collectors.toList())))); + topicResult -> results.add(new WriteShareGroupStateResponseData.WriteStateResult() + .setTopicId(topicResult.topicId()) + .setPartitions(topicResult.partitions().stream() + .map(partitionData -> new WriteShareGroupStateResponseData.PartitionResult() + .setPartition(partitionData.partition()) + .setErrorCode(Errors.forException(e).code()) + .setErrorMessage(Errors.forException(e).message())) + .collect(Collectors.toList())))); return new WriteShareGroupStateResponse(new WriteShareGroupStateResponseData() - .setResults(results)); + .setResults(results)); } @Override @@ -78,8 +78,8 @@ public class WriteShareGroupStateRequest extends AbstractRequest { public static WriteShareGroupStateRequest parse(Readable readable, short version) { return new WriteShareGroupStateRequest( - new WriteShareGroupStateRequestData(readable, version), - version + new WriteShareGroupStateRequestData(readable, version), + version ); } } diff --git a/clients/src/main/java/org/apache/kafka/common/requests/WriteShareGroupStateResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/WriteShareGroupStateResponse.java index f83992d6e63..799ec80d228 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/WriteShareGroupStateResponse.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/WriteShareGroupStateResponse.java @@ -25,7 +25,6 @@ import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.Readable; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -47,9 +46,9 @@ public class WriteShareGroupStateResponse extends AbstractResponse { public Map errorCounts() { Map counts = new HashMap<>(); data.results().forEach( - result -> result.partitions().forEach( - partitionResult -> updateErrorCounts(counts, Errors.forCode(partitionResult.errorCode())) - ) + result -> result.partitions().forEach( + partitionResult -> updateErrorCounts(counts, Errors.forCode(partitionResult.errorCode())) + ) ); return counts; } @@ -66,47 +65,47 @@ public class WriteShareGroupStateResponse extends AbstractResponse { public static WriteShareGroupStateResponse parse(Readable readable, short version) { return new WriteShareGroupStateResponse( - new WriteShareGroupStateResponseData(readable, version) + new WriteShareGroupStateResponseData(readable, version) ); } public static WriteShareGroupStateResponseData toResponseData(Uuid topicId, int partitionId) { return new WriteShareGroupStateResponseData() - .setResults(Collections.singletonList( - new WriteShareGroupStateResponseData.WriteStateResult() - .setTopicId(topicId) - .setPartitions(Collections.singletonList( - new WriteShareGroupStateResponseData.PartitionResult() - .setPartition(partitionId))))); + .setResults(List.of( + new WriteShareGroupStateResponseData.WriteStateResult() + .setTopicId(topicId) + .setPartitions(List.of( + new WriteShareGroupStateResponseData.PartitionResult() + .setPartition(partitionId))))); } public static WriteShareGroupStateResponseData toErrorResponseData(Uuid topicId, int partitionId, Errors error, String errorMessage) { WriteShareGroupStateResponseData responseData = new WriteShareGroupStateResponseData(); - responseData.setResults(Collections.singletonList(new WriteShareGroupStateResponseData.WriteStateResult() - .setTopicId(topicId) - .setPartitions(Collections.singletonList(new WriteShareGroupStateResponseData.PartitionResult() - .setPartition(partitionId) - .setErrorCode(error.code()) - .setErrorMessage(errorMessage))))); + responseData.setResults(List.of(new WriteShareGroupStateResponseData.WriteStateResult() + .setTopicId(topicId) + .setPartitions(List.of(new WriteShareGroupStateResponseData.PartitionResult() + .setPartition(partitionId) + .setErrorCode(error.code()) + .setErrorMessage(errorMessage))))); return responseData; } public static WriteShareGroupStateResponseData.PartitionResult toErrorResponsePartitionResult(int partitionId, Errors error, String errorMessage) { return new WriteShareGroupStateResponseData.PartitionResult() - .setPartition(partitionId) - .setErrorCode(error.code()) - .setErrorMessage(errorMessage); + .setPartition(partitionId) + .setErrorCode(error.code()) + .setErrorMessage(errorMessage); } public static WriteShareGroupStateResponseData.WriteStateResult toResponseWriteStateResult(Uuid topicId, List partitionResults) { return new WriteShareGroupStateResponseData.WriteStateResult() - .setTopicId(topicId) - .setPartitions(partitionResults); + .setTopicId(topicId) + .setPartitions(partitionResults); } public static WriteShareGroupStateResponseData.PartitionResult toResponsePartitionResult(int partitionId) { return new WriteShareGroupStateResponseData.PartitionResult() - .setPartition(partitionId); + .setPartition(partitionId); } public static WriteShareGroupStateResponseData toGlobalErrorResponse(WriteShareGroupStateRequestData request, Errors error) { diff --git a/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala b/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala index 8e185f7cd7f..1d0c022d2bf 100644 --- a/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala +++ b/core/src/test/scala/unit/kafka/server/ShareFetchAcknowledgeRequestTest.scala @@ -17,12 +17,13 @@ package kafka.server import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.DescribeShareGroupsOptions import org.apache.kafka.common.test.api.{ClusterConfigProperty, ClusterFeature, ClusterTest, ClusterTestDefaults, ClusterTests, Type} import org.apache.kafka.common.message.ShareFetchResponseData.AcquiredRecords -import org.apache.kafka.common.message.{ShareAcknowledgeRequestData, ShareAcknowledgeResponseData, ShareFetchRequestData, ShareFetchResponseData} +import org.apache.kafka.common.message.{FindCoordinatorRequestData, ShareAcknowledgeRequestData, ShareAcknowledgeResponseData, ShareFetchRequestData, ShareFetchResponseData, ShareGroupHeartbeatRequestData} import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid} -import org.apache.kafka.common.requests.{ShareAcknowledgeRequest, ShareAcknowledgeResponse, ShareFetchRequest, ShareFetchResponse, ShareRequestMetadata} +import org.apache.kafka.common.requests.{FindCoordinatorRequest, FindCoordinatorResponse, ShareAcknowledgeRequest, ShareAcknowledgeResponse, ShareFetchRequest, ShareFetchResponse, ShareGroupHeartbeatRequest, ShareGroupHeartbeatResponse, ShareRequestMetadata} import org.apache.kafka.common.test.ClusterInstance import org.apache.kafka.server.common.Feature import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} @@ -109,7 +110,8 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo ) def testShareFetchRequestToNonLeaderReplica(): Unit = { val groupId: String = "group" - val metadata: ShareRequestMetadata = new ShareRequestMetadata(Uuid.randomUuid(), ShareRequestMetadata.INITIAL_EPOCH) + val memberId: Uuid = Uuid.randomUuid() + val metadata: ShareRequestMetadata = new ShareRequestMetadata(memberId, ShareRequestMetadata.INITIAL_EPOCH) val topic = "topic" val partition = 0 @@ -129,6 +131,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connect(nonReplicaId) + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 1)) + // Send the share fetch request to the non-replica and verify the error code val shareFetchRequest = createShareFetchRequest(groupId, metadata, send, Seq.empty, Map.empty) val shareFetchResponse = IntegrationTestUtils.sendAndReceive[ShareFetchResponse](shareFetchRequest, socket) @@ -172,6 +177,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -238,6 +246,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partitions sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -349,6 +360,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket2: Socket = connect(leader2) val socket3: Socket = connect(leader3) + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partitions // Create different share fetch requests for different partitions as they may have leaders on separate brokers var shareFetchRequest1 = createShareFetchRequest(groupId, metadata, send1, Seq.empty, acknowledgementsMap) @@ -456,6 +470,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize share partitions sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -573,6 +590,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket, 15000) @@ -693,6 +713,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partiion sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -805,6 +828,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -921,6 +947,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -1036,6 +1065,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -1158,6 +1190,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the shar partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -1323,6 +1358,11 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket2: Socket = connectAny() val socket3: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId1, groupId, Map[String, Int](topic -> 3)) + shareHeartbeat(memberId2, groupId, Map[String, Int](topic -> 3)) + shareHeartbeat(memberId3, groupId, Map[String, Int](topic -> 3)) + // Sending a dummy share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId1, groupId, send, socket1) @@ -1420,6 +1460,11 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket2: Socket = connectAny() val socket3: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId1, groupId1, Map[String, Int](topic -> 3)) + shareHeartbeat(memberId2, groupId2, Map[String, Int](topic -> 3)) + shareHeartbeat(memberId3, groupId3, Map[String, Int](topic -> 3)) + // Sending 3 dummy share Fetch Requests with to inititlaize the share partitions for each share group\ sendFirstShareFetchRequest(memberId1, groupId1, send, socket1) sendFirstShareFetchRequest(memberId2, groupId2, send, socket2) @@ -1513,6 +1558,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -1623,6 +1671,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -1844,6 +1895,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -1918,6 +1972,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -1998,6 +2055,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -2161,6 +2221,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -2242,6 +2305,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -2339,6 +2405,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -2403,6 +2472,9 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo val socket: Socket = connectAny() + createOffsetsTopic() + shareHeartbeat(memberId, groupId, Map[String, Int](topic -> 3)) + // Send the first share fetch request to initialize the share partition sendFirstShareFetchRequest(memberId, groupId, send, socket) @@ -2455,6 +2527,40 @@ class ShareFetchAcknowledgeRequestTest(cluster: ClusterInstance) extends GroupCo }, "Share fetch request failed", 5000) } + private def shareHeartbeat(memberId: Uuid, groupId: String, topics: Map[String, Int]): Unit = { + val coordResp = connectAndReceive[FindCoordinatorResponse](new FindCoordinatorRequest.Builder(new FindCoordinatorRequestData() + .setKey(groupId) + .setKeyType(0.toByte) + ).build(0) + ) + + val shareGroupHeartbeatRequest = new ShareGroupHeartbeatRequest.Builder( + new ShareGroupHeartbeatRequestData() + .setMemberId(memberId.toString) + .setGroupId(groupId) + .setMemberEpoch(0) + .setSubscribedTopicNames(topics.keys.toList.asJava) + ).build() + + TestUtils.waitUntilTrue(() => { + val resp = connectAndReceive[ShareGroupHeartbeatResponse](shareGroupHeartbeatRequest, coordResp.node().id()) + resp.data().errorCode() == Errors.NONE.code() && assignment(memberId.toString, groupId) + }, "Heartbeat failed") + } + + private def assignment(memberId: String, groupId: String): Boolean = { + val admin = cluster.admin() + + val isAssigned = admin.describeShareGroups(List(groupId).asJava, new DescribeShareGroupsOptions().includeAuthorizedOperations(true)) + .describedGroups() + .get(groupId) + .get() + .members() + .asScala.count(desc => desc.consumerId() == memberId && !desc.assignment().topicPartitions().isEmpty) > 0 + admin.close() + isAssigned + } + private def expectedAcquiredRecords(firstOffsets: util.List[Long], lastOffsets: util.List[Long], deliveryCounts: util.List[Int]): util.List[AcquiredRecords] = { val acquiredRecordsList: util.List[AcquiredRecords] = new util.ArrayList() for (i <- firstOffsets.asScala.indices) { diff --git a/share-coordinator/src/main/java/org/apache/kafka/coordinator/share/ShareCoordinatorShard.java b/share-coordinator/src/main/java/org/apache/kafka/coordinator/share/ShareCoordinatorShard.java index 928192edc4f..f5d9a3cfb30 100644 --- a/share-coordinator/src/main/java/org/apache/kafka/coordinator/share/ShareCoordinatorShard.java +++ b/share-coordinator/src/main/java/org/apache/kafka/coordinator/share/ShareCoordinatorShard.java @@ -89,6 +89,8 @@ public class ShareCoordinatorShard implements CoordinatorShard { private final ShareCoordinatorConfig config; @@ -363,38 +365,24 @@ public class ShareCoordinatorShard implements CoordinatorShard stateBatches = (offsetValue.stateBatches() != null && !offsetValue.stateBatches().isEmpty()) ? + offsetValue.stateBatches().stream() + .map( + stateBatch -> new ReadShareGroupStateResponseData.StateBatch() + .setFirstOffset(stateBatch.firstOffset()) + .setLastOffset(stateBatch.lastOffset()) + .setDeliveryState(stateBatch.deliveryState()) + .setDeliveryCount(stateBatch.deliveryCount()) + ).toList() : List.of(); - if (!shareStateMap.containsKey(key)) { - // Leader epoch update might be needed - responseData = ReadShareGroupStateResponse.toResponseData( - topicId, - partitionId, - PartitionFactory.UNINITIALIZED_START_OFFSET, - PartitionFactory.DEFAULT_STATE_EPOCH, - List.of() - ); - } else { - // Leader epoch update might be needed - ShareGroupOffset offsetValue = shareStateMap.get(key); - List stateBatches = (offsetValue.stateBatches() != null && !offsetValue.stateBatches().isEmpty()) ? - offsetValue.stateBatches().stream() - .map( - stateBatch -> new ReadShareGroupStateResponseData.StateBatch() - .setFirstOffset(stateBatch.firstOffset()) - .setLastOffset(stateBatch.lastOffset()) - .setDeliveryState(stateBatch.deliveryState()) - .setDeliveryCount(stateBatch.deliveryCount()) - ).toList() : List.of(); - - responseData = ReadShareGroupStateResponse.toResponseData( - topicId, - partitionId, - offsetValue.startOffset(), - offsetValue.stateEpoch(), - stateBatches - ); - } + ReadShareGroupStateResponseData responseData = ReadShareGroupStateResponse.toResponseData( + topicId, + partitionId, + offsetValue.startOffset(), + offsetValue.stateEpoch(), + stateBatches + ); // Optimization in case leaderEpoch update is not required. if (leaderEpoch == -1 || @@ -644,9 +632,9 @@ public class ShareCoordinatorShard implements CoordinatorShard - * If no snapshot has been created for the key => create a new ShareSnapshot record - * else if number of ShareUpdate records for key >= max allowed per snapshot per key => create a new ShareSnapshot record - * else create a new ShareUpdate record + * If number of ShareUpdate records for key >= max allowed per snapshot per key or stateEpoch is highest + * seen so far => create a new ShareSnapshot record else create a new ShareUpdate record. This method assumes + * that share partition key is present in shareStateMap since it should be called on initialized share partitions. * * @param partitionData - Represents the data which should be written into the share state record. * @param key - The {@link SharePartitionKey} object. @@ -658,28 +646,14 @@ public class ShareCoordinatorShard implements CoordinatorShard= updatesPerSnapshotLimit || partitionData.stateEpoch() > shareStateMap.get(key).stateEpoch()) { + if (snapshotUpdateCount.getOrDefault(key, 0) >= updatesPerSnapshotLimit || partitionData.stateEpoch() > shareStateMap.get(key).stateEpoch()) { ShareGroupOffset currentState = shareStateMap.get(key); // shareStateMap will have the entry as containsKey is true int newLeaderEpoch = partitionData.leaderEpoch() == -1 ? currentState.leaderEpoch() : partitionData.leaderEpoch(); int newStateEpoch = partitionData.stateEpoch() == -1 ? currentState.stateEpoch() : partitionData.stateEpoch(); long newStartOffset = partitionData.startOffset() == -1 ? currentState.startOffset() : partitionData.startOffset(); - // Since the number of update records for this share part key exceeds snapshotUpdateRecordsPerSnapshot, - // we should be creating a share snapshot record. + // Since the number of update records for this share part key exceeds snapshotUpdateRecordsPerSnapshot + // or state epoch has incremented, we should be creating a share snapshot record. // The incoming partition data could have overlapping state batches, we must merge them. return ShareCoordinatorRecordHelpers.newShareSnapshotRecord( key.groupId(), key.topicId(), partitionData.partition(), @@ -772,6 +746,11 @@ public class ShareCoordinatorShard implements CoordinatorShard partitionData.leaderEpoch()) { log.error("Write request leader epoch is smaller than last recorded current: {}, requested: {}.", leaderEpochMap.get(mapKey), partitionData.leaderEpoch()); return Optional.of(getWriteErrorCoordinatorResult(Errors.FENCED_LEADER_EPOCH, null, topicId, partitionId)); @@ -814,6 +793,13 @@ public class ShareCoordinatorShard implements CoordinatorShard partitionData.leaderEpoch()) { log.error("Read request leader epoch is smaller than last recorded current: {}, requested: {}.", leaderEpochMap.get(mapKey), partitionData.leaderEpoch()); return Optional.of(ReadShareGroupStateResponse.toErrorResponseData(topicId, partitionId, Errors.FENCED_LEADER_EPOCH, Errors.FENCED_LEADER_EPOCH.message())); diff --git a/share-coordinator/src/test/java/org/apache/kafka/coordinator/share/ShareCoordinatorShardTest.java b/share-coordinator/src/test/java/org/apache/kafka/coordinator/share/ShareCoordinatorShardTest.java index 0f6dde9259d..9aed6583f1d 100644 --- a/share-coordinator/src/test/java/org/apache/kafka/coordinator/share/ShareCoordinatorShardTest.java +++ b/share-coordinator/src/test/java/org/apache/kafka/coordinator/share/ShareCoordinatorShardTest.java @@ -57,6 +57,7 @@ import org.apache.kafka.server.share.persister.PartitionFactory; import org.apache.kafka.server.share.persister.PersisterStateBatch; import org.apache.kafka.timeline.SnapshotRegistry; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.util.HashMap; @@ -86,6 +87,7 @@ class ShareCoordinatorShardTest { private static final Uuid TOPIC_ID = Uuid.randomUuid(); private static final Uuid TOPIC_ID_2 = Uuid.randomUuid(); private static final int PARTITION = 0; + private static final SharePartitionKey SHARE_PARTITION_KEY = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); private static final Time TIME = new MockTime(); public static class ShareCoordinatorShardBuilder { @@ -141,7 +143,6 @@ class ShareCoordinatorShardTest { } private void writeAndReplayRecord(ShareCoordinatorShard shard, int leaderEpoch) { - WriteShareGroupStateRequestData request = new WriteShareGroupStateRequestData() .setGroupId(GROUP_ID) .setTopics(List.of(new WriteShareGroupStateRequestData.WriteStateData() @@ -162,10 +163,15 @@ class ShareCoordinatorShardTest { shard.replay(0L, 0L, (short) 0, result.records().get(0)); } + private ShareCoordinatorShard shard; + + @BeforeEach + public void setUp() { + shard = new ShareCoordinatorShardBuilder().build(); + } + @Test public void testReplayWithShareSnapshot() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - long offset = 0; long producerId = 0; short producerEpoch = 0; @@ -230,10 +236,39 @@ class ShareCoordinatorShardTest { } @Test - public void testWriteStateSuccess() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); + public void testWriteFailsOnUninitializedPartition() { + WriteShareGroupStateRequestData request = new WriteShareGroupStateRequestData() + .setGroupId(GROUP_ID) + .setTopics(List.of(new WriteShareGroupStateRequestData.WriteStateData() + .setTopicId(TOPIC_ID) + .setPartitions(List.of(new WriteShareGroupStateRequestData.PartitionData() + .setPartition(PARTITION) + .setStartOffset(0) + .setStateEpoch(0) + .setLeaderEpoch(0) + .setStateBatches(List.of(new WriteShareGroupStateRequestData.StateBatch() + .setFirstOffset(0) + .setLastOffset(10) + .setDeliveryCount((short) 1) + .setDeliveryState((byte) 0))))))); - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); + CoordinatorResult result = shard.writeState(request); + + WriteShareGroupStateResponseData expectedData = WriteShareGroupStateResponse.toErrorResponseData( + TOPIC_ID, PARTITION, + Errors.INVALID_REQUEST, + ShareCoordinatorShard.WRITE_UNINITIALIZED_SHARE_PARTITION.getMessage() + ); + List expectedRecords = List.of(); + + assertEquals(expectedData, result.response()); + assertEquals(expectedRecords, result.records()); + assertNull(shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + } + + @Test + public void testWriteStateSuccess() { + initSharePartition(shard, SHARE_PARTITION_KEY); WriteShareGroupStateRequestData request = new WriteShareGroupStateRequestData() .setGroupId(GROUP_ID) @@ -255,17 +290,17 @@ class ShareCoordinatorShardTest { shard.replay(0L, 0L, (short) 0, result.records().get(0)); WriteShareGroupStateResponseData expectedData = WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION); - List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( + List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareUpdateRecord( GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), TIME.milliseconds()) )); assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); - assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( + assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareUpdateRecord( GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), TIME.milliseconds()) - ).value().message()), shard.getShareStateMapValue(shareCoordinatorKey)); - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + ).value().message()), shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); verify(shard.getMetricsShard()).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME); } @@ -276,9 +311,10 @@ class ShareCoordinatorShardTest { // a higher state epoch in a request forces snapshot creation, even if number of share updates // have not breached the updates/snapshots limit. - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - int stateEpoch = 0; - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); + int stateEpoch = 1; + int snapshotEpoch = 0; + + initSharePartition(shard, SHARE_PARTITION_KEY); WriteShareGroupStateRequestData request = new WriteShareGroupStateRequestData() .setGroupId(GROUP_ID) @@ -299,19 +335,20 @@ class ShareCoordinatorShardTest { shard.replay(0L, 0L, (short) 0, result.records().get(0)); + snapshotEpoch++; // Since state epoch increased. WriteShareGroupStateResponseData expectedData = WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION); List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( - GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), TIME.milliseconds()) + GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), snapshotEpoch, TIME.milliseconds()) )); - assertEquals(0, shard.getShareStateMapValue(shareCoordinatorKey).snapshotEpoch()); + assertEquals(1, shard.getShareStateMapValue(SHARE_PARTITION_KEY).snapshotEpoch()); assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( - GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), TIME.milliseconds()) - ).value().message()), shard.getShareStateMapValue(shareCoordinatorKey)); - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), snapshotEpoch, TIME.milliseconds()) + ).value().message()), shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); verify(shard.getMetricsShard()).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME); // State epoch stays same so share update. @@ -336,18 +373,18 @@ class ShareCoordinatorShardTest { expectedData = WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION); expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareUpdateRecord( - GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), TIME.milliseconds()) + GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), snapshotEpoch, TIME.milliseconds()) )); // Snapshot epoch did not increase - assertEquals(0, shard.getShareStateMapValue(shareCoordinatorKey).snapshotEpoch()); + assertEquals(1, shard.getShareStateMapValue(SHARE_PARTITION_KEY).snapshotEpoch()); assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( - GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), TIME.milliseconds()) - ).value().message()), shard.getShareStateMapValue(shareCoordinatorKey)); - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), snapshotEpoch, TIME.milliseconds()) + ).value().message()), shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); verify(shard.getMetricsShard(), times(2)).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME); // State epoch incremented so share snapshot. @@ -370,28 +407,27 @@ class ShareCoordinatorShardTest { shard.replay(0L, 0L, (short) 0, result.records().get(0)); + snapshotEpoch++; // Since state epoch increased expectedData = WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION); expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( - GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 1, TIME.milliseconds()) + GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), snapshotEpoch, TIME.milliseconds()) )); // Snapshot epoch increased. - assertEquals(1, shard.getShareStateMapValue(shareCoordinatorKey).snapshotEpoch()); + assertEquals(2, shard.getShareStateMapValue(SHARE_PARTITION_KEY).snapshotEpoch()); assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( - GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 1, TIME.milliseconds()) - ).value().message()), shard.getShareStateMapValue(shareCoordinatorKey)); - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), snapshotEpoch, TIME.milliseconds()) + ).value().message()), shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); verify(shard.getMetricsShard(), times(3)).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME); } @Test public void testSubsequentWriteStateSnapshotEpochUpdatesSuccessfully() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); + initSharePartition(shard, SHARE_PARTITION_KEY); WriteShareGroupStateRequestData request1 = new WriteShareGroupStateRequestData() .setGroupId(GROUP_ID) @@ -428,7 +464,7 @@ class ShareCoordinatorShardTest { shard.replay(0L, 0L, (short) 0, result.records().get(0)); WriteShareGroupStateResponseData expectedData = WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION); - List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( + List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareUpdateRecord( GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request1.topics().get(0).partitions().get(0), TIME.milliseconds()) )); @@ -436,8 +472,8 @@ class ShareCoordinatorShardTest { assertEquals(expectedRecords, result.records()); assertEquals(groupOffset(expectedRecords.get(0).value().message()), - shard.getShareStateMapValue(shareCoordinatorKey)); - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); result = shard.writeState(request2); @@ -454,24 +490,20 @@ class ShareCoordinatorShardTest { assertEquals(expectedRecords, result.records()); ShareGroupOffset incrementalUpdate = groupOffset(expectedRecords.get(0).value().message()); - ShareGroupOffset combinedState = shard.getShareStateMapValue(shareCoordinatorKey); + ShareGroupOffset combinedState = shard.getShareStateMapValue(SHARE_PARTITION_KEY); assertEquals(incrementalUpdate.snapshotEpoch(), combinedState.snapshotEpoch()); assertEquals(incrementalUpdate.leaderEpoch(), combinedState.leaderEpoch()); assertEquals(incrementalUpdate.startOffset(), combinedState.startOffset()); // The batches should have combined to 1 since same state. assertEquals(List.of(new PersisterStateBatch(0, 20, (byte) 0, (short) 1)), combinedState.stateBatches()); - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test public void testWriteStateInvalidRequestData() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - int partition = -1; - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); - WriteShareGroupStateRequestData request = new WriteShareGroupStateRequestData() .setGroupId(GROUP_ID) .setTopics(List.of(new WriteShareGroupStateRequestData.WriteStateData() @@ -496,16 +528,15 @@ class ShareCoordinatorShardTest { assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); - assertNull(shard.getShareStateMapValue(shareCoordinatorKey)); - assertNull(shard.getLeaderMapValue(shareCoordinatorKey)); + assertNull(shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertNull(shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test public void testWriteNullMetadataImage() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - shard.onNewMetadataImage(null, null); + initSharePartition(shard, SHARE_PARTITION_KEY); - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); + shard.onNewMetadataImage(null, null); WriteShareGroupStateRequestData request = new WriteShareGroupStateRequestData() .setGroupId(GROUP_ID) @@ -530,16 +561,12 @@ class ShareCoordinatorShardTest { assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); - - assertNull(shard.getShareStateMapValue(shareCoordinatorKey)); - assertNull(shard.getLeaderMapValue(shareCoordinatorKey)); + assertEquals(-1, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test public void testWriteStateFencedLeaderEpochError() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); + initSharePartition(shard, SHARE_PARTITION_KEY); WriteShareGroupStateRequestData request1 = new WriteShareGroupStateRequestData() .setGroupId(GROUP_ID) @@ -576,7 +603,7 @@ class ShareCoordinatorShardTest { shard.replay(0L, 0L, (short) 0, result.records().get(0)); WriteShareGroupStateResponseData expectedData = WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION); - List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( + List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareUpdateRecord( GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request1.topics().get(0).partitions().get(0), TIME.milliseconds()) )); @@ -584,8 +611,8 @@ class ShareCoordinatorShardTest { assertEquals(expectedRecords, result.records()); assertEquals(groupOffset(expectedRecords.get(0).value().message()), - shard.getShareStateMapValue(shareCoordinatorKey)); - assertEquals(5, shard.getLeaderMapValue(shareCoordinatorKey)); + shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertEquals(5, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); result = shard.writeState(request2); @@ -598,14 +625,12 @@ class ShareCoordinatorShardTest { assertEquals(expectedRecords, result.records()); // No changes to the leaderMap. - assertEquals(5, shard.getLeaderMapValue(shareCoordinatorKey)); + assertEquals(5, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test public void testWriteStateFencedStateEpochError() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); + initSharePartition(shard, SHARE_PARTITION_KEY); WriteShareGroupStateRequestData request1 = new WriteShareGroupStateRequestData() .setGroupId(GROUP_ID) @@ -643,15 +668,15 @@ class ShareCoordinatorShardTest { WriteShareGroupStateResponseData expectedData = WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION); List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( - GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request1.topics().get(0).partitions().get(0), TIME.milliseconds()) + GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request1.topics().get(0).partitions().get(0), 1, TIME.milliseconds()) )); assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); assertEquals(groupOffset(expectedRecords.get(0).value().message()), - shard.getShareStateMapValue(shareCoordinatorKey)); - assertEquals(5, shard.getLeaderMapValue(shareCoordinatorKey)); + shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertEquals(5, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); result = shard.writeState(request2); @@ -664,15 +689,34 @@ class ShareCoordinatorShardTest { assertEquals(expectedRecords, result.records()); // No changes to the stateEpochMap. - assertEquals(1, shard.getStateEpochMapValue(shareCoordinatorKey)); + assertEquals(1, shard.getStateEpochMapValue(SHARE_PARTITION_KEY)); + } + + @Test + public void testReadFailsOnUninitializedPartition() { + ReadShareGroupStateRequestData request = new ReadShareGroupStateRequestData() + .setGroupId(GROUP_ID) + .setTopics(List.of(new ReadShareGroupStateRequestData.ReadStateData() + .setTopicId(TOPIC_ID) + .setPartitions(List.of(new ReadShareGroupStateRequestData.PartitionData() + .setPartition(PARTITION) + .setLeaderEpoch(1))))); + + CoordinatorResult result = shard.readStateAndMaybeUpdateLeaderEpoch(request); + + assertEquals(ReadShareGroupStateResponse.toErrorResponseData( + TOPIC_ID, + PARTITION, + Errors.INVALID_REQUEST, + ShareCoordinatorShard.READ_UNINITIALIZED_SHARE_PARTITION.getMessage() + ), result.response()); + + assertNull(shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test public void testReadStateSuccess() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - SharePartitionKey coordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); - + initSharePartition(shard, SHARE_PARTITION_KEY); writeAndReplayDefaultRecord(shard); ReadShareGroupStateRequestData request = new ReadShareGroupStateRequestData() @@ -698,15 +742,12 @@ class ShareCoordinatorShardTest { ) ), result.response()); - assertEquals(0, shard.getLeaderMapValue(coordinatorKey)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test public void testReadStateSummarySuccess() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - SharePartitionKey coordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); - + initSharePartition(shard, SHARE_PARTITION_KEY); writeAndReplayDefaultRecord(shard); ReadShareGroupStateSummaryRequestData request = new ReadShareGroupStateSummaryRequestData() @@ -727,19 +768,15 @@ class ShareCoordinatorShardTest { 0 ), result.response()); - assertEquals(0, shard.getLeaderMapValue(coordinatorKey)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test public void testReadStateInvalidRequestData() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - int partition = -1; - + initSharePartition(shard, SHARE_PARTITION_KEY); writeAndReplayDefaultRecord(shard); - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); - + int partition = -1; ReadShareGroupStateRequestData request = new ReadShareGroupStateRequestData() .setGroupId(GROUP_ID) .setTopics(List.of(new ReadShareGroupStateRequestData.ReadStateData() @@ -756,19 +793,15 @@ class ShareCoordinatorShardTest { assertEquals(expectedData, result.response()); // Leader epoch should not be changed because the request failed. - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test public void testReadStateSummaryInvalidRequestData() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - int partition = -1; - + initSharePartition(shard, SHARE_PARTITION_KEY); writeAndReplayDefaultRecord(shard); - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); - + int partition = -1; ReadShareGroupStateSummaryRequestData request = new ReadShareGroupStateSummaryRequestData() .setGroupId(GROUP_ID) .setTopics(List.of(new ReadShareGroupStateSummaryRequestData.ReadStateSummaryData() @@ -785,19 +818,16 @@ class ShareCoordinatorShardTest { assertEquals(expectedData, result.response()); // Leader epoch should not be changed because the request failed. - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test public void testReadNullMetadataImage() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - + initSharePartition(shard, SHARE_PARTITION_KEY); writeAndReplayDefaultRecord(shard); shard.onNewMetadataImage(null, null); - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); - ReadShareGroupStateRequestData request = new ReadShareGroupStateRequestData() .setGroupId(GROUP_ID) .setTopics(List.of(new ReadShareGroupStateRequestData.ReadStateData() @@ -814,19 +844,16 @@ class ShareCoordinatorShardTest { assertEquals(expectedData, result.response()); // Leader epoch should not be changed because the request failed. - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test public void testReadStateFencedLeaderEpochError() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); + initSharePartition(shard, SHARE_PARTITION_KEY); int leaderEpoch = 5; - writeAndReplayRecord(shard, leaderEpoch); // leaderEpoch in the leaderMap will be 5. - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); - ReadShareGroupStateRequestData request = new ReadShareGroupStateRequestData() .setGroupId(GROUP_ID) .setTopics(List.of(new ReadShareGroupStateRequestData.ReadStateData() @@ -845,7 +872,7 @@ class ShareCoordinatorShardTest { assertEquals(expectedData, result.response()); - assertEquals(leaderEpoch, shard.getLeaderMapValue(shareCoordinatorKey)); + assertEquals(leaderEpoch, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); } @Test @@ -875,11 +902,11 @@ class ShareCoordinatorShardTest { // -Share leader acks batch 3 and sends the new startOffset and the state of batch 3 to share coordinator. // -Share coordinator writes the snapshot with startOffset 110 and batch 3. // -batch2 should NOT be lost - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder() + shard = new ShareCoordinatorShardBuilder() .setConfigOverrides(Map.of(ShareCoordinatorConfig.SNAPSHOT_UPDATE_RECORDS_PER_SNAPSHOT_CONFIG, "0")) .build(); - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); + initSharePartition(shard, SHARE_PARTITION_KEY); // Set initial state. WriteShareGroupStateRequestData request = new WriteShareGroupStateRequestData() @@ -916,16 +943,16 @@ class ShareCoordinatorShardTest { WriteShareGroupStateResponseData expectedData = WriteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION); List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( - GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), TIME.milliseconds()) + GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 1, TIME.milliseconds()) )); assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( - GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), TIME.milliseconds()) - ).value().message()), shard.getShareStateMapValue(shareCoordinatorKey)); - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + GROUP_ID, TOPIC_ID, PARTITION, ShareGroupOffset.fromRequest(request.topics().get(0).partitions().get(0), 1, TIME.milliseconds()) + ).value().message()), shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); verify(shard.getMetricsShard()).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME); // Acknowledge b1. @@ -977,7 +1004,7 @@ class ShareCoordinatorShardTest { .setStartOffset(110) .setLeaderEpoch(0) .setStateEpoch(0) - .setSnapshotEpoch(2) // since 2nd share snapshot + .setSnapshotEpoch(3) // since 2nd share snapshot .setStateBatches(List.of( new PersisterStateBatch(110, 119, (byte) 1, (short) 2), // b2 not lost new PersisterStateBatch(120, 129, (byte) 2, (short) 1) @@ -994,15 +1021,15 @@ class ShareCoordinatorShardTest { assertEquals(groupOffset(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( GROUP_ID, TOPIC_ID, PARTITION, offsetFinal - ).value().message()), shard.getShareStateMapValue(shareCoordinatorKey)); - assertEquals(0, shard.getLeaderMapValue(shareCoordinatorKey)); + ).value().message()), shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertEquals(0, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); verify(shard.getMetricsShard(), times(3)).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME); } @Test public void testLastRedundantOffset() { ShareCoordinatorOffsetsManager manager = mock(ShareCoordinatorOffsetsManager.class); - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder() + shard = new ShareCoordinatorShardBuilder() .setOffsetsManager(manager) .build(); @@ -1012,9 +1039,7 @@ class ShareCoordinatorShardTest { @Test public void testReadStateLeaderEpochUpdateSuccess() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); + initSharePartition(shard, SHARE_PARTITION_KEY); ReadShareGroupStateRequestData request = new ReadShareGroupStateRequestData() .setGroupId(GROUP_ID) @@ -1034,7 +1059,7 @@ class ShareCoordinatorShardTest { PartitionFactory.UNINITIALIZED_START_OFFSET, PartitionFactory.DEFAULT_STATE_EPOCH, List.of()); - List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareSnapshotRecord( + List expectedRecords = List.of(ShareCoordinatorRecordHelpers.newShareUpdateRecord( GROUP_ID, TOPIC_ID, PARTITION, new ShareGroupOffset.Builder() .setStartOffset(PartitionFactory.UNINITIALIZED_START_OFFSET) .setLeaderEpoch(2) @@ -1049,14 +1074,14 @@ class ShareCoordinatorShardTest { assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); - assertEquals(groupOffset(expectedRecords.get(0).value().message()), shard.getShareStateMapValue(shareCoordinatorKey)); - assertEquals(2, shard.getLeaderMapValue(shareCoordinatorKey)); + assertEquals(groupOffset(expectedRecords.get(0).value().message()), shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertEquals(2, shard.getLeaderMapValue(SHARE_PARTITION_KEY)); verify(shard.getMetricsShard()).record(ShareCoordinatorMetrics.SHARE_COORDINATOR_WRITE_SENSOR_NAME); } @Test public void testReadStateLeaderEpochUpdateNoUpdate() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); + initSharePartition(shard, SHARE_PARTITION_KEY); ReadShareGroupStateRequestData request1 = new ReadShareGroupStateRequestData() .setGroupId(GROUP_ID) @@ -1103,10 +1128,6 @@ class ShareCoordinatorShardTest { @Test public void testDeleteStateSuccess() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); - DeleteShareGroupStateRequestData request = new DeleteShareGroupStateRequestData() .setGroupId(GROUP_ID) .setTopics(List.of(new DeleteShareGroupStateRequestData.DeleteStateData() @@ -1135,9 +1156,9 @@ class ShareCoordinatorShardTest { .build() ); shard.replay(0L, 0L, (short) 0, record); - assertNotNull(shard.getShareStateMapValue(shareCoordinatorKey)); - assertNotNull(shard.getLeaderMapValue(shareCoordinatorKey)); - assertNotNull(shard.getStateEpochMapValue(shareCoordinatorKey)); + assertNotNull(shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertNotNull(shard.getLeaderMapValue(SHARE_PARTITION_KEY)); + assertNotNull(shard.getStateEpochMapValue(SHARE_PARTITION_KEY)); CoordinatorResult result = shard.deleteState(request); @@ -1153,17 +1174,13 @@ class ShareCoordinatorShardTest { assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); - assertNull(shard.getShareStateMapValue(shareCoordinatorKey)); - assertNull(shard.getLeaderMapValue(shareCoordinatorKey)); - assertNull(shard.getStateEpochMapValue(shareCoordinatorKey)); + assertNull(shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertNull(shard.getLeaderMapValue(SHARE_PARTITION_KEY)); + assertNull(shard.getStateEpochMapValue(SHARE_PARTITION_KEY)); } @Test public void testDeleteStateUnintializedRecord() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); - DeleteShareGroupStateRequestData request = new DeleteShareGroupStateRequestData() .setGroupId(GROUP_ID) .setTopics(List.of(new DeleteShareGroupStateRequestData.DeleteStateData() @@ -1172,10 +1189,10 @@ class ShareCoordinatorShardTest { .setPartition(PARTITION))))); CoordinatorResult result = shard.deleteState(request); - - assertNull(shard.getShareStateMapValue(shareCoordinatorKey)); - assertNull(shard.getLeaderMapValue(shareCoordinatorKey)); - assertNull(shard.getStateEpochMapValue(shareCoordinatorKey)); + + assertNull(shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertNull(shard.getLeaderMapValue(SHARE_PARTITION_KEY)); + assertNull(shard.getStateEpochMapValue(SHARE_PARTITION_KEY)); DeleteShareGroupStateResponseData expectedData = DeleteShareGroupStateResponse.toResponseData(TOPIC_ID, PARTITION); @@ -1185,8 +1202,6 @@ class ShareCoordinatorShardTest { @Test public void testDeleteStateInvalidRequestData() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - // invalid partition int partition = -1; @@ -1210,7 +1225,6 @@ class ShareCoordinatorShardTest { @Test public void testDeleteNullMetadataImage() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); shard.onNewMetadataImage(null, null); DeleteShareGroupStateRequestData request = new DeleteShareGroupStateRequestData() @@ -1232,7 +1246,6 @@ class ShareCoordinatorShardTest { @Test public void testDeleteTopicIdNonExistentInMetadataImage() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); MetadataImage image = mock(MetadataImage.class); shard.onNewMetadataImage(image, null); @@ -1264,7 +1277,6 @@ class ShareCoordinatorShardTest { @Test public void testDeletePartitionIdNonExistentInMetadataImage() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); MetadataImage image = mock(MetadataImage.class); shard.onNewMetadataImage(image, null); @@ -1302,10 +1314,6 @@ class ShareCoordinatorShardTest { @Test public void testInitializeStateSuccess() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - - SharePartitionKey shareCoordinatorKey = SharePartitionKey.getInstance(GROUP_ID, TOPIC_ID, PARTITION); - InitializeShareGroupStateRequestData request = new InitializeShareGroupStateRequestData() .setGroupId(GROUP_ID) .setTopics(List.of(new InitializeShareGroupStateRequestData.InitializeStateData() @@ -1316,8 +1324,8 @@ class ShareCoordinatorShardTest { .setStateEpoch(5))) )); - assertNull(shard.getShareStateMapValue(shareCoordinatorKey)); - assertNull(shard.getStateEpochMapValue(shareCoordinatorKey)); + assertNull(shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertNull(shard.getStateEpochMapValue(SHARE_PARTITION_KEY)); CoordinatorResult result = shard.initializeState(request); result.records().forEach(record -> shard.replay(0L, 0L, (short) 0, record)); @@ -1331,14 +1339,12 @@ class ShareCoordinatorShardTest { assertEquals(expectedData, result.response()); assertEquals(expectedRecords, result.records()); - assertNotNull(shard.getShareStateMapValue(shareCoordinatorKey)); - assertNotNull(shard.getStateEpochMapValue(shareCoordinatorKey)); + assertNotNull(shard.getShareStateMapValue(SHARE_PARTITION_KEY)); + assertNotNull(shard.getStateEpochMapValue(SHARE_PARTITION_KEY)); } @Test public void testInitializeStateInvalidRequestData() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - // invalid partition int partition = -1; @@ -1391,7 +1397,6 @@ class ShareCoordinatorShardTest { @Test public void testInitializeNullMetadataImage() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); shard.onNewMetadataImage(null, null); InitializeShareGroupStateRequestData request = new InitializeShareGroupStateRequestData() @@ -1415,7 +1420,6 @@ class ShareCoordinatorShardTest { @Test public void testInitializeTopicIdNonExistentInMetadataImage() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); MetadataImage image = mock(MetadataImage.class); shard.onNewMetadataImage(image, null); @@ -1445,7 +1449,6 @@ class ShareCoordinatorShardTest { @Test public void testInitializePartitionIdNonExistentInMetadataImage() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); MetadataImage image = mock(MetadataImage.class); shard.onNewMetadataImage(image, null); @@ -1479,7 +1482,6 @@ class ShareCoordinatorShardTest { @Test public void testSnapshotColdPartitionsNoEligiblePartitions() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); MetadataImage image = mock(MetadataImage.class); shard.onNewMetadataImage(image, null); int offset = 0; @@ -1546,7 +1548,6 @@ class ShareCoordinatorShardTest { @Test public void testSnapshotColdPartitionsSnapshotUpdateNotConsidered() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); MetadataImage image = mock(MetadataImage.class); shard.onNewMetadataImage(image, null); int offset = 0; @@ -1644,7 +1645,6 @@ class ShareCoordinatorShardTest { @Test public void testSnapshotColdPartitionsDoesNotPerpetuallySnapshot() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); MetadataImage image = mock(MetadataImage.class); shard.onNewMetadataImage(image, null); int offset = 0; @@ -1719,7 +1719,6 @@ class ShareCoordinatorShardTest { @Test public void testSnapshotColdPartitionsPartialEligiblePartitions() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); MetadataImage image = mock(MetadataImage.class); shard.onNewMetadataImage(image, null); int offset = 0; @@ -1824,8 +1823,6 @@ class ShareCoordinatorShardTest { @Test public void testOnTopicsDeletedEmptyTopicIds() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); - CoordinatorResult expectedResult = new CoordinatorResult<>(List.of()); assertEquals(expectedResult, shard.maybeCleanupShareState(Set.of())); @@ -1836,7 +1833,6 @@ class ShareCoordinatorShardTest { @Test public void testOnTopicsDeletedTopicIds() { - ShareCoordinatorShard shard = new ShareCoordinatorShardBuilder().build(); MetadataImage image = mock(MetadataImage.class); shard.onNewMetadataImage(image, null); @@ -1909,4 +1905,20 @@ class ShareCoordinatorShardTest { } return ShareGroupOffset.fromRecord((ShareUpdateValue) record); } + + private void initSharePartition(ShareCoordinatorShard shard, SharePartitionKey key) { + shard.replay(0L, 0L, (short) 0, CoordinatorRecord.record( + new ShareSnapshotKey() + .setGroupId(key.groupId()) + .setTopicId(key.topicId()) + .setPartition(key.partition()), + new ApiMessageAndVersion( + new ShareSnapshotValue() + .setStateEpoch(0) + .setLeaderEpoch(-1) + .setStartOffset(-1), + (short) 0 + ) + )); + } }