KAFKA-19110: Add missing unit test for Streams-consumer integration (#19457)

- Construct an `AsyncKafkaConsumer` and verify that the `RequestManagers.supplier()`
call receives the Streams-specific data structures.
- Verify that `RequestManagers` constructs the Streams request managers
correctly
- Test `StreamsGroupHeartbeatRequestManager#resetPollTimer()`
- Test `StreamsOnTasksRevokedCallbackCompletedEvent`,
`StreamsOnTasksAssignedCallbackCompletedEvent`, and
`StreamsOnAllTasksLostCallbackCompletedEvent` in
`ApplicationEventProcessor`
- Test `DefaultStreamsRebalanceListener`
- Test `StreamThread`.
  - Test `handleStreamsRebalanceData`.
  - Test `StreamsRebalanceData`.

Reviewers: Lucas Brutschy <lbrutschy@confluent.io>, Bill Bejeck <bill@confluent.io>
Signed-off-by: PoAn Yang <payang@apache.org>
Author: PoAn Yang <payang@apache.org> (committed via GitHub)
Date: 2025-04-24 16:38:22 +08:00
Commit: 3fae785ea0 (parent: 8b4560e3f0)
8 changed files with 787 additions and 6 deletions

StreamsMembershipManager.java

@@ -52,6 +52,8 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static java.util.Collections.unmodifiableList;
/**
 * Tracks the state of a single member in relationship to a group:
 * <p/>
@@ -1305,4 +1307,9 @@ public class StreamsMembershipManager implements RequestManager {
future.complete(null);
}
}
// visible for testing
List<MemberStateListener> stateListeners() {
return unmodifiableList(stateUpdatesListeners);
}
}
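
A note on the new accessor: wrapping the listeners in `unmodifiableList` lets tests assert on registration without being able to mutate the manager's internal state. A minimal sketch of the intended test-side usage, assuming `manager` and `listener` are already set up in the test fixture:

```java
// Sketch: assertions a test can make against the new accessor
// (`manager` and `listener` are assumed to exist in the fixture).
List<MemberStateListener> listeners = manager.stateListeners();
assertTrue(listeners.contains(listener));            // the listener was registered
assertThrows(UnsupportedOperationException.class,
    () -> listeners.add(listener));                  // the returned view is read-only
```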

AsyncKafkaConsumerTest.java

@@ -117,6 +117,7 @@ import java.util.Properties;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
@@ -205,6 +206,13 @@
}
private AsyncKafkaConsumer<String, String> newConsumer(Properties props) {
return newConsumerWithStreamRebalanceData(props, null);
}
private AsyncKafkaConsumer<String, String> newConsumerWithStreamRebalanceData(
Properties props,
StreamsRebalanceData streamsRebalanceData
) {
// disable auto-commit by default, so we don't need to handle SyncCommitEvent for each case
if (!props.containsKey(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG)) {
props.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false);
@@ -220,7 +228,7 @@
(a, b, c, d, e, f, g) -> fetchCollector,
(a, b, c, d) -> metadata,
backgroundEventQueue,
-Optional.empty()
+Optional.ofNullable(streamsRebalanceData)
);
}
@@ -1371,6 +1379,51 @@
assertEquals(groupMetadataAfterUnsubscribe, consumer.groupMetadata());
}
private Optional<StreamsRebalanceData> captureStreamRebalanceData(final MockedStatic<RequestManagers> requestManagers) {
ArgumentCaptor<Optional<StreamsRebalanceData>> streamRebalanceData = ArgumentCaptor.forClass(Optional.class);
requestManagers.verify(() -> RequestManagers.supplier(
any(),
any(),
any(),
any(),
any(),
any(),
any(),
any(),
any(),
any(),
any(),
any(),
any(),
any(),
any(),
streamRebalanceData.capture()
));
return streamRebalanceData.getValue();
}
@Test
public void testEmptyStreamRebalanceData() {
final String groupId = "consumerGroupA";
try (final MockedStatic<RequestManagers> requestManagers = mockStatic(RequestManagers.class)) {
consumer = newConsumer(requiredConsumerConfigAndGroupId(groupId));
final Optional<StreamsRebalanceData> groupMetadataUpdateListener = captureStreamRebalanceData(requestManagers);
assertTrue(groupMetadataUpdateListener.isEmpty());
}
}
@Test
public void testStreamRebalanceData() {
final String groupId = "consumerGroupA";
try (final MockedStatic<RequestManagers> requestManagers = mockStatic(RequestManagers.class)) {
StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of());
consumer = newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), streamsRebalanceData);
final Optional<StreamsRebalanceData> groupMetadataUpdateListener = captureStreamRebalanceData(requestManagers);
assertTrue(groupMetadataUpdateListener.isPresent());
assertEquals(streamsRebalanceData, groupMetadataUpdateListener.get());
}
}
/**
 * Tests that the consumer correctly invokes the callbacks for {@link ConsumerRebalanceListener} that was
 * specified. We don't go through the full effort to emulate heartbeats and correct group management here. We're
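
For orientation: the two new tests above rely on `AsyncKafkaConsumer` passing its `Optional<StreamsRebalanceData>` straight through to `RequestManagers.supplier(...)`, whose sixteenth argument the static mock captures. A hedged sketch of that wiring (the parameter names are assumptions, not the constructor's actual code):

```java
// Sketch of the pass-through the tests capture (parameter names assumed):
Supplier<RequestManagers> requestManagersSupplier = RequestManagers.supplier(
    time, logContext, backgroundEventHandler, metadata, subscriptions, fetchBuffer,
    config, groupRebalanceConfig, apiVersions, fetchMetricsManager,
    networkClientDelegateSupplier, clientTelemetryReporter /* Optional; name assumed */,
    metrics, offsetCommitCallbackInvoker, memberStateListener,
    streamsRebalanceData  // the Optional<StreamsRebalanceData> asserted on above
);
```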

RequestManagersTest.java

@@ -26,10 +26,13 @@ import org.apache.kafka.common.utils.MockTime;
import org.junit.jupiter.api.Test;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import java.util.UUID;
import static org.apache.kafka.test.TestUtils.requiredConsumerConfig;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
@@ -65,8 +68,53 @@ public class RequestManagersTest {
listener,
Optional.empty()
).get();
-requestManagers.consumerMembershipManager.ifPresent(
-    membershipManager -> assertTrue(membershipManager.stateListeners().contains(listener))
-);
+assertTrue(requestManagers.consumerMembershipManager.isPresent());
+assertTrue(requestManagers.streamsMembershipManager.isEmpty());
assertTrue(requestManagers.streamsGroupHeartbeatRequestManager.isEmpty());
assertEquals(2, requestManagers.consumerMembershipManager.get().stateListeners().size());
assertTrue(requestManagers.consumerMembershipManager.get().stateListeners().stream()
.anyMatch(m -> m instanceof CommitRequestManager));
assertTrue(requestManagers.consumerMembershipManager.get().stateListeners().contains(listener));
}
@Test
public void testStreamMemberStateListenerRegistered() {
final MemberStateListener listener = (memberEpoch, memberId) -> { };
final Properties properties = requiredConsumerConfig();
properties.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "consumerGroup");
final ConsumerConfig config = new ConsumerConfig(properties);
final GroupRebalanceConfig groupRebalanceConfig = new GroupRebalanceConfig(
config,
GroupRebalanceConfig.ProtocolType.CONSUMER
);
final RequestManagers requestManagers = RequestManagers.supplier(
new MockTime(),
new LogContext(),
mock(BackgroundEventHandler.class),
mock(ConsumerMetadata.class),
mock(SubscriptionState.class),
mock(FetchBuffer.class),
config,
groupRebalanceConfig,
mock(ApiVersions.class),
mock(FetchMetricsManager.class),
() -> mock(NetworkClientDelegate.class),
Optional.empty(),
new Metrics(),
mock(OffsetCommitCallbackInvoker.class),
listener,
Optional.of(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()))
).get();
assertTrue(requestManagers.streamsMembershipManager.isPresent());
assertTrue(requestManagers.streamsGroupHeartbeatRequestManager.isPresent());
assertTrue(requestManagers.consumerMembershipManager.isEmpty());
assertEquals(2, requestManagers.streamsMembershipManager.get().stateListeners().size());
assertTrue(requestManagers.streamsMembershipManager.get().stateListeners().stream()
.anyMatch(m -> m instanceof CommitRequestManager));
assertTrue(requestManagers.streamsMembershipManager.get().stateListeners().contains(listener));
}
}
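
The assertions in the two tests pin down the selection logic: a present `StreamsRebalanceData` yields the Streams membership and heartbeat managers, an empty one yields the consumer membership manager, and in both cases the `CommitRequestManager` plus the caller-supplied listener end up registered as state listeners (hence the expected size of 2). A sketch of the inferred branching inside `RequestManagers.supplier` (not the literal implementation):

```java
// Inferred manager selection (a sketch; constructor arguments elided):
if (streamsRebalanceData.isPresent()) {
    streamsMembershipManager = Optional.of(new StreamsMembershipManager(/* ... */));
    streamsGroupHeartbeatRequestManager = Optional.of(new StreamsGroupHeartbeatRequestManager(/* ... */));
    consumerMembershipManager = Optional.empty();
} else {
    consumerMembershipManager = Optional.of(new ConsumerMembershipManager(/* ... */));
    streamsMembershipManager = Optional.empty();
    streamsGroupHeartbeatRequestManager = Optional.empty();
}
// whichever membership manager exists registers the CommitRequestManager
// and the caller's MemberStateListener as state listeners
```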

StreamsGroupHeartbeatRequestManagerTest.java

@@ -1520,6 +1520,36 @@ class StreamsGroupHeartbeatRequestManagerTest {
}
}
@Test
public void testResetPollTimer() {
try (final MockedConstruction<Timer> pollTimerMockedConstruction = mockConstruction(Timer.class)) {
final StreamsGroupHeartbeatRequestManager heartbeatRequestManager = createStreamsGroupHeartbeatRequestManager();
final Timer pollTimer = pollTimerMockedConstruction.constructed().get(1);
heartbeatRequestManager.resetPollTimer(time.milliseconds());
verify(pollTimer).update(time.milliseconds());
verify(pollTimer).isExpired();
verify(pollTimer).reset(DEFAULT_MAX_POLL_INTERVAL_MS);
}
}
@Test
public void testResetPollTimerWhenExpired() {
try (final MockedConstruction<Timer> pollTimerMockedConstruction = mockConstruction(Timer.class)) {
final StreamsGroupHeartbeatRequestManager heartbeatRequestManager = createStreamsGroupHeartbeatRequestManager();
final Timer pollTimer = pollTimerMockedConstruction.constructed().get(1);
when(pollTimer.isExpired()).thenReturn(true);
heartbeatRequestManager.resetPollTimer(time.milliseconds());
verify(pollTimer).update(time.milliseconds());
verify(pollTimer).isExpired();
verify(pollTimer).isExpiredBy();
verify(membershipManager).memberId();
verify(membershipManager).maybeRejoinStaleMember();
verify(pollTimer).reset(DEFAULT_MAX_POLL_INTERVAL_MS);
}
}
private static ConsumerConfig config() {
Properties prop = new Properties();
prop.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);
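
Both tests use `mockConstruction(Timer.class)` and grab `constructed().get(1)`, i.e. the second `Timer` the manager creates is its poll timer. Taken together, the verified interactions imply the following shape for `resetPollTimer` (a reconstruction from the mocks, not the source):

```java
// Reconstructed from the mock verifications above (a sketch):
public void resetPollTimer(final long pollMs) {
    pollTimer.update(pollMs);
    if (pollTimer.isExpired()) {
        // poll() exceeded max.poll.interval.ms; log how far past the deadline
        // the member is (log text assumed) and rejoin the stale member
        logger.warn("Poll timer expired by {} ms; member {} will rejoin",
            pollTimer.isExpiredBy(), membershipManager.memberId());
        membershipManager.maybeRejoinStaleMember();
    }
    pollTimer.reset(maxPollIntervalMs);
}
```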

ApplicationEventProcessorTest.java

@@ -31,15 +31,19 @@ import org.apache.kafka.clients.consumer.internals.MockRebalanceListener;
import org.apache.kafka.clients.consumer.internals.NetworkClientDelegate;
import org.apache.kafka.clients.consumer.internals.OffsetsRequestManager;
import org.apache.kafka.clients.consumer.internals.RequestManagers;
import org.apache.kafka.clients.consumer.internals.StreamsGroupHeartbeatRequestManager;
import org.apache.kafka.clients.consumer.internals.StreamsMembershipManager;
import org.apache.kafka.clients.consumer.internals.SubscriptionState;
import org.apache.kafka.clients.consumer.internals.TopicMetadataRequestManager;
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.utils.LogCaptureAppender;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.logging.log4j.Level;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
@@ -86,6 +90,8 @@
private final OffsetsRequestManager offsetsRequestManager = mock(OffsetsRequestManager.class);
private SubscriptionState subscriptionState = mock(SubscriptionState.class);
private final ConsumerMetadata metadata = mock(ConsumerMetadata.class);
private final StreamsGroupHeartbeatRequestManager streamsGroupHeartbeatRequestManager = mock(StreamsGroupHeartbeatRequestManager.class);
private final StreamsMembershipManager streamsMembershipManager = mock(StreamsMembershipManager.class);
private ApplicationEventProcessor processor;
private void setupProcessor(boolean withGroupId) {
@@ -109,6 +115,27 @@
);
}
private void setupStreamProcessor(boolean withGroupId) {
RequestManagers requestManagers = new RequestManagers(
new LogContext(),
offsetsRequestManager,
mock(TopicMetadataRequestManager.class),
mock(FetchRequestManager.class),
withGroupId ? Optional.of(mock(CoordinatorRequestManager.class)) : Optional.empty(),
withGroupId ? Optional.of(commitRequestManager) : Optional.empty(),
withGroupId ? Optional.of(heartbeatRequestManager) : Optional.empty(),
Optional.empty(),
withGroupId ? Optional.of(streamsGroupHeartbeatRequestManager) : Optional.empty(),
withGroupId ? Optional.of(streamsMembershipManager) : Optional.empty()
);
processor = new ApplicationEventProcessor(
new LogContext(),
requestManagers,
metadata,
subscriptionState
);
}
@Test
public void testPrepClosingCommitEvents() {
setupProcessor(true);
@@ -556,6 +583,78 @@
assertFutureThrows(IllegalStateException.class, event.future());
}
@Test
public void testStreamsOnTasksRevokedCallbackCompletedEvent() {
setupStreamProcessor(true);
StreamsOnTasksRevokedCallbackCompletedEvent event =
new StreamsOnTasksRevokedCallbackCompletedEvent(new CompletableFuture<>(), Optional.empty());
processor.process(event);
verify(streamsMembershipManager).onTasksRevokedCallbackCompleted(event);
}
@Test
public void testStreamsOnTasksRevokedCallbackCompletedEventWithoutStreamsMembershipManager() {
setupStreamProcessor(false);
StreamsOnTasksRevokedCallbackCompletedEvent event =
new StreamsOnTasksRevokedCallbackCompletedEvent(new CompletableFuture<>(), Optional.empty());
try (final LogCaptureAppender logAppender = LogCaptureAppender.createAndRegister()) {
logAppender.setClassLogger(ApplicationEventProcessor.class, Level.WARN);
processor.process(event);
assertTrue(logAppender.getMessages().stream().anyMatch(e ->
e.contains("An internal error occurred; the Streams membership manager was not present, so the notification " +
"of the onTasksRevoked callback execution could not be sent")));
verify(streamsMembershipManager, never()).onTasksRevokedCallbackCompleted(event);
}
}
@Test
public void testStreamsOnTasksAssignedCallbackCompletedEvent() {
setupStreamProcessor(true);
StreamsOnTasksAssignedCallbackCompletedEvent event =
new StreamsOnTasksAssignedCallbackCompletedEvent(new CompletableFuture<>(), Optional.empty());
processor.process(event);
verify(streamsMembershipManager).onTasksAssignedCallbackCompleted(event);
}
@Test
public void testStreamsOnTasksAssignedCallbackCompletedEventWithoutStreamsMembershipManager() {
setupStreamProcessor(false);
StreamsOnTasksAssignedCallbackCompletedEvent event =
new StreamsOnTasksAssignedCallbackCompletedEvent(new CompletableFuture<>(), Optional.empty());
try (final LogCaptureAppender logAppender = LogCaptureAppender.createAndRegister()) {
logAppender.setClassLogger(ApplicationEventProcessor.class, Level.WARN);
processor.process(event);
assertTrue(logAppender.getMessages().stream().anyMatch(e ->
e.contains("An internal error occurred; the Streams membership manager was not present, so the notification " +
"of the onTasksAssigned callback execution could not be sent")));
verify(streamsMembershipManager, never()).onTasksAssignedCallbackCompleted(event);
}
}
@Test
public void testStreamsOnAllTasksLostCallbackCompletedEvent() {
setupStreamProcessor(true);
StreamsOnAllTasksLostCallbackCompletedEvent event =
new StreamsOnAllTasksLostCallbackCompletedEvent(new CompletableFuture<>(), Optional.empty());
processor.process(event);
verify(streamsMembershipManager).onAllTasksLostCallbackCompleted(event);
}
@Test
public void testStreamsOnAllTasksLostCallbackCompletedEventWithoutStreamsMembershipManager() {
setupStreamProcessor(false);
StreamsOnAllTasksLostCallbackCompletedEvent event =
new StreamsOnAllTasksLostCallbackCompletedEvent(new CompletableFuture<>(), Optional.empty());
try (final LogCaptureAppender logAppender = LogCaptureAppender.createAndRegister()) {
logAppender.setClassLogger(ApplicationEventProcessor.class, Level.WARN);
processor.process(event);
assertTrue(logAppender.getMessages().stream().anyMatch(e ->
e.contains("An internal error occurred; the Streams membership manager was not present, so the notification " +
"of the onAllTasksLost callback execution could not be sent")));
verify(streamsMembershipManager, never()).onAllTasksLostCallbackCompleted(event);
}
}
private List<NetworkClientDelegate.UnsentRequest> mockCommitResults() {
return Collections.singletonList(mock(NetworkClientDelegate.UnsentRequest.class));
}
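
The three paired tests all exercise the same dispatch pattern: delegate to the Streams membership manager when it is present, otherwise log a warning and drop the event. The reconstructed shape of one handler (a sketch; only the revoked variant shown, the other two differ only in the callback name):

```java
// Sketch of the handler the tests verify (revoked variant):
private void process(final StreamsOnTasksRevokedCallbackCompletedEvent event) {
    if (requestManagers.streamsMembershipManager.isEmpty()) {
        log.warn("An internal error occurred; the Streams membership manager was not present, " +
            "so the notification of the onTasksRevoked callback execution could not be sent");
        return;
    }
    requestManagers.streamsMembershipManager.get().onTasksRevokedCallbackCompleted(event);
}
```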

StreamThread.java

@@ -534,7 +534,7 @@ public class StreamThread extends Thread implements ProcessingThread {
final Map<String, Object> consumerConfigs) {
if (config.getString(StreamsConfig.GROUP_PROTOCOL_CONFIG).equalsIgnoreCase(GroupProtocol.STREAMS.name)) {
if (topologyMetadata.hasNamedTopologies()) {
-throw new IllegalStateException("Named topologies and the CONSUMER protocol cannot be used at the same time.");
+throw new IllegalStateException("Named topologies and the STREAMS protocol cannot be used at the same time.");
}
log.info("Streams rebalance protocol enabled");
@@ -2022,4 +2022,8 @@
Admin adminClient() {
return adminClient;
}
Optional<StreamsRebalanceData> streamsRebalanceData() {
return streamsRebalanceData;
}
}

DefaultStreamsRebalanceListenerTest.java (new file)

@@ -0,0 +1,208 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals;
import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.streams.processor.TaskId;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.InOrder;
import org.slf4j.LoggerFactory;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class DefaultStreamsRebalanceListenerTest {
private final TaskManager taskManager = mock(TaskManager.class);
private final StreamThread streamThread = mock(StreamThread.class);
private DefaultStreamsRebalanceListener defaultStreamsRebalanceListener = new DefaultStreamsRebalanceListener(
LoggerFactory.getLogger(DefaultStreamsRebalanceListener.class),
new MockTime(),
mock(StreamsRebalanceData.class),
streamThread,
taskManager
);
private void createRebalanceListenerWithRebalanceData(final StreamsRebalanceData streamsRebalanceData) {
defaultStreamsRebalanceListener = new DefaultStreamsRebalanceListener(
LoggerFactory.getLogger(DefaultStreamsRebalanceListener.class),
new MockTime(),
streamsRebalanceData,
streamThread,
taskManager
);
}
@ParameterizedTest
@EnumSource(StreamThread.State.class)
void testOnTasksRevoked(final StreamThread.State state) {
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(
UUID.randomUUID(),
Optional.empty(),
Map.of(
"1",
new StreamsRebalanceData.Subtopology(
Set.of("source1"),
Set.of(),
Map.of("repartition1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
Map.of(),
Set.of()
)
),
Map.of()
));
when(streamThread.state()).thenReturn(state);
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksRevoked(
Set.of(new StreamsRebalanceData.TaskId("1", 0))
);
assertTrue(result.isEmpty());
final InOrder inOrder = inOrder(taskManager, streamThread);
inOrder.verify(taskManager).handleRevocation(
Set.of(new TopicPartition("source1", 0), new TopicPartition("repartition1", 0))
);
inOrder.verify(streamThread).state();
if (state != StreamThread.State.PENDING_SHUTDOWN) {
inOrder.verify(streamThread).setState(StreamThread.State.PARTITIONS_REVOKED);
} else {
inOrder.verify(streamThread, never()).setState(StreamThread.State.PARTITIONS_REVOKED);
}
}
@Test
void testOnTasksRevokedWithException() {
final Exception exception = new RuntimeException("sample exception");
doThrow(exception).when(taskManager).handleRevocation(any());
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksRevoked(Set.of());
assertTrue(result.isPresent());
verify(taskManager).handleRevocation(any());
verify(streamThread, never()).setState(any());
}
@Test
void testOnTasksAssigned() {
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(
UUID.randomUUID(),
Optional.empty(),
Map.of(
"1",
new StreamsRebalanceData.Subtopology(
Set.of("source1"),
Set.of(),
Map.of("repartition1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
Map.of(),
Set.of()
),
"2",
new StreamsRebalanceData.Subtopology(
Set.of("source2"),
Set.of(),
Map.of("repartition2", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
Map.of(),
Set.of()
),
"3",
new StreamsRebalanceData.Subtopology(
Set.of("source3"),
Set.of(),
Map.of("repartition3", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
Map.of(),
Set.of()
)
),
Map.of()
));
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(
new StreamsRebalanceData.Assignment(
Set.of(new StreamsRebalanceData.TaskId("1", 0)),
Set.of(new StreamsRebalanceData.TaskId("2", 0)),
Set.of(new StreamsRebalanceData.TaskId("3", 0))
)
);
assertTrue(result.isEmpty());
final InOrder inOrder = inOrder(taskManager, streamThread);
inOrder.verify(taskManager).handleAssignment(
Map.of(new TaskId(1, 0), Set.of(new TopicPartition("source1", 0), new TopicPartition("repartition1", 0))),
Map.of(
new TaskId(2, 0), Set.of(new TopicPartition("source2", 0), new TopicPartition("repartition2", 0)),
new TaskId(3, 0), Set.of(new TopicPartition("source3", 0), new TopicPartition("repartition3", 0))
)
);
inOrder.verify(streamThread).setState(StreamThread.State.PARTITIONS_ASSIGNED);
inOrder.verify(taskManager).handleRebalanceComplete();
}
@Test
void testOnTasksAssignedWithException() {
final Exception exception = new RuntimeException("sample exception");
doThrow(exception).when(taskManager).handleAssignment(any(), any());
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of()));
assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty());
assertTrue(result.isPresent());
assertEquals(exception, result.get());
verify(taskManager).handleLostAll();
verify(streamThread, never()).setState(StreamThread.State.PARTITIONS_ASSIGNED);
verify(taskManager, never()).handleRebalanceComplete();
}
@Test
void testOnAllTasksLost() {
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty());
verify(taskManager).handleLostAll();
}
@Test
void testOnAllTasksLostWithException() {
final Exception exception = new RuntimeException("sample exception");
doThrow(exception).when(taskManager).handleLostAll();
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
final Optional<Exception> result = defaultStreamsRebalanceListener.onAllTasksLost();
assertTrue(result.isPresent());
assertEquals(exception, result.get());
verify(taskManager).handleLostAll();
}
}
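
From the verified call order, `onTasksRevoked` resolves each Streams task ID to its source and repartition-source partitions, hands those to the task manager, and only moves the thread to `PARTITIONS_REVOKED` when it is not already in `PENDING_SHUTDOWN`; any exception is caught and returned rather than thrown. A sketch of that inferred flow (`partitionsForTasks` is a hypothetical helper):

```java
// Inferred flow of onTasksRevoked (a sketch, not the source):
public Optional<Exception> onTasksRevoked(final Set<StreamsRebalanceData.TaskId> tasks) {
    try {
        // resolve each task to its source + repartition-source partitions
        taskManager.handleRevocation(partitionsForTasks(tasks)); // hypothetical helper
        if (streamThread.state() != StreamThread.State.PENDING_SHUTDOWN) {
            streamThread.setState(StreamThread.State.PARTITIONS_REVOKED);
        }
    } catch (final Exception e) {
        return Optional.of(e);
    }
    return Optional.empty();
}
```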

StreamThreadTest.java

@@ -25,8 +25,10 @@ import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.InvalidOffsetException;
import org.apache.kafka.clients.consumer.MockConsumer;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.clients.consumer.internals.AsyncKafkaConsumer;
import org.apache.kafka.clients.consumer.internals.AutoOffsetResetStrategy;
import org.apache.kafka.clients.consumer.internals.MockRebalanceListener;
import org.apache.kafka.clients.consumer.internals.StreamsRebalanceData;
import org.apache.kafka.clients.producer.MockProducer;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.common.Cluster;
@@ -37,6 +39,7 @@ import org.apache.kafka.common.Node;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.config.TopicConfig;
import org.apache.kafka.common.errors.InvalidPidMappingException;
import org.apache.kafka.common.errors.ProducerFencedException;
import org.apache.kafka.common.errors.TimeoutException;
@@ -55,9 +58,11 @@ import org.apache.kafka.common.utils.LogCaptureAppender;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.streams.GroupProtocol;
import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.StreamsConfig.InternalConfig;
import org.apache.kafka.streams.ThreadMetadata;
import org.apache.kafka.streams.TopologyConfig;
import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
import org.apache.kafka.streams.errors.StreamsException;
import org.apache.kafka.streams.errors.TaskCorruptedException;
@@ -77,9 +82,11 @@ import org.apache.kafka.streams.processor.api.ProcessorContext;
import org.apache.kafka.streams.processor.api.ProcessorSupplier;
import org.apache.kafka.streams.processor.api.Record;
import org.apache.kafka.streams.processor.internals.StreamThread.State;
import org.apache.kafka.streams.processor.internals.assignment.AssignorError;
import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer;
import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
import org.apache.kafka.streams.processor.internals.tasks.DefaultTaskManager;
import org.apache.kafka.streams.state.HostInfo;
import org.apache.kafka.streams.state.KeyValueStore;
import org.apache.kafka.streams.state.StoreBuilder;
import org.apache.kafka.streams.state.Stores;
@@ -96,6 +103,7 @@ import org.apache.kafka.test.TestUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
@@ -154,6 +162,7 @@ import static org.hamcrest.Matchers.isA;
import static org.hamcrest.core.IsInstanceOf.instanceOf;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
@@ -3573,6 +3582,329 @@ public class StreamThreadTest {
);
}
@Test
public void testNamedTopologyWithStreamsProtocol() {
final Properties props = configProps(false, false, false);
props.setProperty(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.toString());
final StreamsConfig config = new StreamsConfig(props);
final InternalTopologyBuilder topologyBuilder = new InternalTopologyBuilder(
new TopologyConfig(
"my-topology",
config,
new Properties())
);
final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(
metrics,
APPLICATION_ID,
PROCESS_ID.toString(),
mockTime
);
final TopologyMetadata topologyMetadata = new TopologyMetadata(topologyBuilder, config);
topologyMetadata.buildAndRewriteTopology();
stateDirectory = new StateDirectory(config, mockTime, true, false);
final StreamsMetadataState streamsMetadataState = new StreamsMetadataState(
new TopologyMetadata(internalTopologyBuilder, config),
StreamsMetadataState.UNKNOWN_HOST,
new LogContext(String.format("stream-client [%s] ", CLIENT_ID))
);
final IllegalStateException exception = assertThrows(IllegalStateException.class, () ->
StreamThread.create(
topologyMetadata,
config,
clientSupplier,
clientSupplier.getAdmin(config.getAdminConfigs(CLIENT_ID)),
PROCESS_ID,
CLIENT_ID,
streamsMetrics,
mockTime,
streamsMetadataState,
0,
stateDirectory,
new MockStateRestoreListener(),
new MockStandbyUpdateListener(),
threadIdx,
null,
HANDLER
)
);
assertEquals("Named topologies and the STREAMS protocol cannot be used at the same time.", exception.getMessage());
}
@Test
public void testStreamsRebalanceDataWithClassicProtocol() {
final Properties props = configProps(false, false, false);
props.setProperty(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.CLASSIC.toString());
thread = createStreamThread(CLIENT_ID, new StreamsConfig(props));
assertTrue(thread.streamsRebalanceData().isEmpty());
}
@Test
public void testStreamsRebalanceDataWithExtraCopartition() {
final Properties props = configProps(false, false, false);
props.setProperty(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.toString());
internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1);
final StreamsConfig config = new StreamsConfig(props);
final InternalTopologyBuilder topologyBuilder = mock(InternalTopologyBuilder.class);
when(topologyBuilder.subtopologyToTopicsInfo()).thenReturn(Map.of());
when(topologyBuilder.copartitionGroups()).thenReturn(Set.of(Set.of("source1")));
final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(
metrics,
APPLICATION_ID,
PROCESS_ID.toString(),
mockTime
);
final TopologyMetadata topologyMetadata = new TopologyMetadata(topologyBuilder, config);
topologyMetadata.buildAndRewriteTopology();
stateDirectory = new StateDirectory(config, mockTime, true, false);
final StreamsMetadataState streamsMetadataState = new StreamsMetadataState(
new TopologyMetadata(internalTopologyBuilder, config),
StreamsMetadataState.UNKNOWN_HOST,
new LogContext(String.format("stream-client [%s] ", CLIENT_ID))
);
final IllegalStateException exception = assertThrows(IllegalStateException.class, () ->
StreamThread.create(
topologyMetadata,
config,
clientSupplier,
clientSupplier.getAdmin(config.getAdminConfigs(CLIENT_ID)),
PROCESS_ID,
CLIENT_ID,
streamsMetrics,
mockTime,
streamsMetadataState,
0,
stateDirectory,
new MockStateRestoreListener(),
new MockStandbyUpdateListener(),
threadIdx,
null,
HANDLER
)
);
assertEquals("Not all copartition groups were converted to broker topology", exception.getMessage());
}
@Test
public void testStreamsRebalanceDataWithStreamsProtocol() {
final Properties props = configProps(false, false, false);
props.setProperty(StreamsConfig.GROUP_PROTOCOL_CONFIG, GroupProtocol.STREAMS.toString());
props.setProperty(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:1234");
props.setProperty(StreamsConfig.REPLICATION_FACTOR_CONFIG, "1");
internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1);
final StreamsConfig config = new StreamsConfig(props);
final InternalTopologyBuilder topologyBuilder = mock(InternalTopologyBuilder.class);
when(topologyBuilder.subtopologyToTopicsInfo()).thenReturn(Map.of(
new TopologyMetadata.Subtopology(1, "subTopology1"),
new InternalTopologyBuilder.TopicsInfo(
Set.of("repartitionSource1"),
Set.of("source1"),
Map.of(
"repartitionSource1",
new RepartitionTopicConfig("repartitionSource1", Map.of(), 1, false)
),
Map.of(
"stateChangeTopic1",
new RepartitionTopicConfig("stateChangeTopic1", Map.of(), 1, false)
)
)
));
when(topologyBuilder.copartitionGroups()).thenReturn(Set.of(Set.of("source1")));
final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(
metrics,
APPLICATION_ID,
PROCESS_ID.toString(),
mockTime
);
final TopologyMetadata topologyMetadata = new TopologyMetadata(topologyBuilder, config);
topologyMetadata.buildAndRewriteTopology();
stateDirectory = new StateDirectory(config, mockTime, true, false);
final StreamsMetadataState streamsMetadataState = new StreamsMetadataState(
new TopologyMetadata(internalTopologyBuilder, config),
StreamsMetadataState.UNKNOWN_HOST,
new LogContext(String.format("stream-client [%s] ", CLIENT_ID))
);
thread = StreamThread.create(
topologyMetadata,
config,
clientSupplier,
clientSupplier.getAdmin(config.getAdminConfigs(CLIENT_ID)),
PROCESS_ID,
CLIENT_ID,
streamsMetrics,
mockTime,
streamsMetadataState,
0,
stateDirectory,
new MockStateRestoreListener(),
new MockStandbyUpdateListener(),
threadIdx,
null,
HANDLER
);
assertInstanceOf(AsyncKafkaConsumer.class, thread.mainConsumer());
assertTrue(thread.streamsRebalanceData().isPresent());
assertEquals(PROCESS_ID, thread.streamsRebalanceData().get().processId());
assertTrue(thread.streamsRebalanceData().get().endpoint().isPresent());
assertEquals(new StreamsRebalanceData.HostInfo("localhost", 1234),
thread.streamsRebalanceData().get().endpoint().get());
final Map<String, String> topicConfigs = Map.of(
TopicConfig.SEGMENT_BYTES_CONFIG, "52428800",
TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG, TimestampType.CREATE_TIME.name,
TopicConfig.RETENTION_MS_CONFIG, "-1",
TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE
);
assertEquals(1, thread.streamsRebalanceData().get().subtopologies().size());
final StreamsRebalanceData.Subtopology subtopology = thread.streamsRebalanceData().get().subtopologies().get("1");
assertEquals(Set.of("source1"), subtopology.sourceTopics());
assertEquals(Set.of("repartitionSource1"), subtopology.repartitionSinkTopics());
assertEquals(1, subtopology.repartitionSourceTopics().size());
assertEquals(Optional.of(1), subtopology.repartitionSourceTopics().get("repartitionSource1").numPartitions());
assertEquals(Optional.of((short) 1), subtopology.repartitionSourceTopics().get("repartitionSource1").replicationFactor());
assertEquals(topicConfigs, subtopology.repartitionSourceTopics().get("repartitionSource1").topicConfigs());
assertEquals(1, subtopology.stateChangelogTopics().size());
assertEquals(Optional.of(1), subtopology.stateChangelogTopics().get("stateChangeTopic1").numPartitions());
assertEquals(Optional.of((short) 1), subtopology.stateChangelogTopics().get("stateChangeTopic1").replicationFactor());
assertEquals(topicConfigs, subtopology.stateChangelogTopics().get("stateChangeTopic1").topicConfigs());
assertEquals(1, subtopology.copartitionGroups().size());
assertEquals(Set.of("source1"), subtopology.copartitionGroups().stream().findFirst().get());
}
@Test
public void testStreamsProtocolRunOnceWithoutProcessingThreads() {
final ConsumerGroupMetadata consumerGroupMetadata = Mockito.mock(ConsumerGroupMetadata.class);
when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
when(mainConsumer.poll(Mockito.any(Duration.class))).thenReturn(new ConsumerRecords<>(Map.of(), Map.of()));
when(mainConsumer.groupMetadata()).thenReturn(consumerGroupMetadata);
final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(
UUID.randomUUID(),
Optional.empty(),
Map.of(),
Map.of()
);
final AtomicInteger assignmentErrorCode = new AtomicInteger(0);
final Properties props = configProps(false, false, false);
final StreamsConfig config = new StreamsConfig(props);
thread = new StreamThread(
new MockTime(1),
config,
null,
mainConsumer,
consumer,
changelogReader,
null,
mock(TaskManager.class),
null,
new StreamsMetricsImpl(metrics, CLIENT_ID, PROCESS_ID.toString(), mockTime),
new TopologyMetadata(internalTopologyBuilder, config),
PROCESS_ID,
CLIENT_ID,
new LogContext(""),
assignmentErrorCode,
new AtomicLong(Long.MAX_VALUE),
new LinkedList<>(),
null,
HANDLER,
null,
Optional.of(streamsRebalanceData),
null
).updateThreadMetadata(adminClientId(CLIENT_ID));
thread.setState(State.STARTING);
thread.runOnceWithoutProcessingThreads();
assertEquals(0, assignmentErrorCode.get());
streamsRebalanceData.requestShutdown();
thread.runOnceWithoutProcessingThreads();
assertEquals(AssignorError.SHUTDOWN_REQUESTED.code(), assignmentErrorCode.get());
}
@Test
public void testStreamsProtocolRunOnceWithProcessingThreads() {
final ConsumerGroupMetadata consumerGroupMetadata = Mockito.mock(ConsumerGroupMetadata.class);
when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
when(mainConsumer.poll(Mockito.any(Duration.class))).thenReturn(new ConsumerRecords<>(Map.of(), Map.of()));
when(mainConsumer.groupMetadata()).thenReturn(consumerGroupMetadata);
final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(
UUID.randomUUID(),
Optional.empty(),
Map.of(),
Map.of()
);
final AtomicInteger assignmentErrorCode = new AtomicInteger(0);
final Properties props = configProps(false, false, false);
final StreamsConfig config = new StreamsConfig(props);
thread = new StreamThread(
new MockTime(1),
config,
null,
mainConsumer,
consumer,
changelogReader,
null,
mock(TaskManager.class),
null,
new StreamsMetricsImpl(metrics, CLIENT_ID, PROCESS_ID.toString(), mockTime),
new TopologyMetadata(internalTopologyBuilder, config),
PROCESS_ID,
CLIENT_ID,
new LogContext(""),
assignmentErrorCode,
new AtomicLong(Long.MAX_VALUE),
new LinkedList<>(),
null,
HANDLER,
null,
Optional.of(streamsRebalanceData),
null
).updateThreadMetadata(adminClientId(CLIENT_ID));
thread.setState(State.STARTING);
thread.runOnceWithProcessingThreads();
assertEquals(0, assignmentErrorCode.get());
streamsRebalanceData.requestShutdown();
thread.runOnceWithProcessingThreads();
assertEquals(AssignorError.SHUTDOWN_REQUESTED.code(), assignmentErrorCode.get());
}
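
Both run-once tests treat `StreamsRebalanceData` as a shared mailbox between the background consumer machinery and the stream thread: after `requestShutdown()`, the next loop iteration surfaces the request as `AssignorError.SHUTDOWN_REQUESTED`. A sketch of the inferred per-iteration check (the `shutdownRequested()` accessor name is an assumption):

```java
// Inferred check inside runOnce* (a sketch; accessor name assumed):
streamsRebalanceData.ifPresent(data -> {
    if (data.shutdownRequested()) {
        assignmentErrorCode.set(AssignorError.SHUTDOWN_REQUESTED.code());
    }
});
```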
@Test
public void testGetTopicPartitionInfo() {
assertEquals(
Map.of(
t1p1, new PartitionInfo(t1p1.topic(), t1p1.partition(), null, new Node[0], new Node[0]),
t1p2, new PartitionInfo(t1p2.topic(), t1p2.partition(), null, new Node[0], new Node[0]),
t2p1, new PartitionInfo(t2p1.topic(), t2p1.partition(), null, new Node[0], new Node[0])
),
StreamThread.getTopicPartitionInfo(
Map.of(
new HostInfo("localhost", 9092), Set.of(t1p1, t2p1),
new HostInfo("localhost", 9094), Set.of(t1p2)
)
)
);
}
private StreamThread setUpThread(final Properties streamsConfigProps) {
final StreamsConfig config = new StreamsConfig(streamsConfigProps);
final ConsumerGroupMetadata consumerGroupMetadata = Mockito.mock(ConsumerGroupMetadata.class);