KAFKA-19694: Trigger StreamsRebalanceListener in Consumer.close (#20511)

In the consumer, we invoke the consumer rebalance onPartitionRevoked or
onPartitionLost callbacks, when the consumer closes. The point is that
the application may want to commit, or wipe the state if we are closing
unsuccessfully.

In the StreamsRebalanceListener, we did not implement this behavior,
which means when closing the consumer we may lose some progress, and in
the worst case also miss that we have to wipe our local state state
since we got fenced.

In this PR we implement StreamsRebalanceListenerInvoker, very similarly
to ConsumerRebalanceListenerInvoker and invoke it in Consumer.close.

Reviewers: Lianet Magrans <lmagrans@confluent.io>, Matthias J. Sax
 <matthias@confluent.io>, TengYao Chi <frankvicky@apache.org>,
 Uladzislau Blok <123193120+UladzislauBlok@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Lucas Brutschy 2025-09-16 16:32:47 +02:00 committed by GitHub
parent daa7aae0c1
commit 2c347380b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 579 additions and 103 deletions

View File

@ -187,25 +187,6 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
*/
private class BackgroundEventProcessor implements EventProcessor<BackgroundEvent> {
private Optional<StreamsRebalanceListener> streamsRebalanceListener = Optional.empty();
private final Optional<StreamsRebalanceData> streamsRebalanceData;
public BackgroundEventProcessor() {
this.streamsRebalanceData = Optional.empty();
}
public BackgroundEventProcessor(final Optional<StreamsRebalanceData> streamsRebalanceData) {
this.streamsRebalanceData = streamsRebalanceData;
}
private void setStreamsRebalanceListener(final StreamsRebalanceListener streamsRebalanceListener) {
if (streamsRebalanceData.isEmpty()) {
throw new IllegalStateException("Background event processor was not created to be used with Streams " +
"rebalance protocol events");
}
this.streamsRebalanceListener = Optional.of(streamsRebalanceListener);
}
@Override
public void process(final BackgroundEvent event) {
switch (event.type()) {
@ -278,44 +259,26 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
private StreamsOnTasksRevokedCallbackCompletedEvent invokeOnTasksRevokedCallback(final Set<StreamsRebalanceData.TaskId> activeTasksToRevoke,
final CompletableFuture<Void> future) {
final Optional<Exception> exceptionFromCallback = streamsRebalanceListener().onTasksRevoked(activeTasksToRevoke);
final Optional<Exception> exceptionFromCallback = Optional.ofNullable(streamsRebalanceListenerInvoker().invokeTasksRevoked(activeTasksToRevoke));
final Optional<KafkaException> error = exceptionFromCallback.map(e -> ConsumerUtils.maybeWrapAsKafkaException(e, "Task revocation callback throws an error"));
return new StreamsOnTasksRevokedCallbackCompletedEvent(future, error);
}
private StreamsOnTasksAssignedCallbackCompletedEvent invokeOnTasksAssignedCallback(final StreamsRebalanceData.Assignment assignment,
final CompletableFuture<Void> future) {
final Optional<KafkaException> error;
final Optional<Exception> exceptionFromCallback = streamsRebalanceListener().onTasksAssigned(assignment);
if (exceptionFromCallback.isPresent()) {
error = Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(), "Task assignment callback throws an error"));
} else {
error = Optional.empty();
streamsRebalanceData().setReconciledAssignment(assignment);
}
final Optional<Exception> exceptionFromCallback = Optional.ofNullable(streamsRebalanceListenerInvoker().invokeTasksAssigned(assignment));
final Optional<KafkaException> error = exceptionFromCallback.map(e -> ConsumerUtils.maybeWrapAsKafkaException(e, "Task assignment callback throws an error"));
return new StreamsOnTasksAssignedCallbackCompletedEvent(future, error);
}
private StreamsOnAllTasksLostCallbackCompletedEvent invokeOnAllTasksLostCallback(final CompletableFuture<Void> future) {
final Optional<KafkaException> error;
final Optional<Exception> exceptionFromCallback = streamsRebalanceListener().onAllTasksLost();
if (exceptionFromCallback.isPresent()) {
error = Optional.of(ConsumerUtils.maybeWrapAsKafkaException(exceptionFromCallback.get(), "All tasks lost callback throws an error"));
} else {
error = Optional.empty();
streamsRebalanceData().setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
}
final Optional<Exception> exceptionFromCallback = Optional.ofNullable(streamsRebalanceListenerInvoker().invokeAllTasksLost());
final Optional<KafkaException> error = exceptionFromCallback.map(e -> ConsumerUtils.maybeWrapAsKafkaException(e, "All tasks lost callback throws an error"));
return new StreamsOnAllTasksLostCallbackCompletedEvent(future, error);
}
private StreamsRebalanceData streamsRebalanceData() {
return streamsRebalanceData.orElseThrow(
() -> new IllegalStateException("Background event processor was not created to be used with Streams " +
"rebalance protocol events"));
}
private StreamsRebalanceListener streamsRebalanceListener() {
return streamsRebalanceListener.orElseThrow(
private StreamsRebalanceListenerInvoker streamsRebalanceListenerInvoker() {
return streamsRebalanceListenerInvoker.orElseThrow(
() -> new IllegalStateException("Background event processor was not created to be used with Streams " +
"rebalance protocol events"));
}
@ -367,6 +330,7 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
private final WakeupTrigger wakeupTrigger = new WakeupTrigger();
private final OffsetCommitCallbackInvoker offsetCommitCallbackInvoker;
private final ConsumerRebalanceListenerInvoker rebalanceListenerInvoker;
private final Optional<StreamsRebalanceListenerInvoker> streamsRebalanceListenerInvoker;
// Last triggered async commit future. Used to wait until all previous async commits are completed.
// We only need to keep track of the last one, since they are guaranteed to complete in order.
private CompletableFuture<Map<TopicPartition, OffsetAndMetadata>> lastPendingAsyncCommit = null;
@ -517,7 +481,9 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
time,
new RebalanceCallbackMetricsManager(metrics)
);
this.backgroundEventProcessor = new BackgroundEventProcessor(streamsRebalanceData);
this.streamsRebalanceListenerInvoker = streamsRebalanceData.map(s ->
new StreamsRebalanceListenerInvoker(logContext, s));
this.backgroundEventProcessor = new BackgroundEventProcessor();
this.backgroundEventReaper = backgroundEventReaperFactory.build(logContext);
// The FetchCollector is only used on the application thread.
@ -577,6 +543,7 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
this.time = time;
this.backgroundEventQueue = backgroundEventQueue;
this.rebalanceListenerInvoker = rebalanceListenerInvoker;
this.streamsRebalanceListenerInvoker = Optional.empty();
this.backgroundEventProcessor = new BackgroundEventProcessor();
this.backgroundEventReaper = backgroundEventReaper;
this.metrics = metrics;
@ -699,6 +666,7 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
networkClientDelegateSupplier,
requestManagersSupplier,
asyncConsumerMetrics);
this.streamsRebalanceListenerInvoker = Optional.empty();
this.backgroundEventProcessor = new BackgroundEventProcessor();
this.backgroundEventReaper = new CompletableEventReaper(logContext);
}
@ -1477,7 +1445,7 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
() -> autoCommitOnClose(closeTimer), firstException);
swallow(log, Level.ERROR, "Failed to stop finding coordinator",
this::stopFindCoordinatorOnClose, firstException);
swallow(log, Level.ERROR, "Failed to release group assignment",
swallow(log, Level.ERROR, "Failed to run rebalance callbacks",
this::runRebalanceCallbacksOnClose, firstException);
swallow(log, Level.ERROR, "Failed to leave group while closing consumer",
() -> leaveGroupOnClose(closeTimer, membershipOperation), firstException);
@ -1527,26 +1495,39 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
}
private void runRebalanceCallbacksOnClose() {
if (groupMetadata.get().isEmpty() || rebalanceListenerInvoker == null)
if (groupMetadata.get().isEmpty())
return;
int memberEpoch = groupMetadata.get().get().generationId();
Set<TopicPartition> assignedPartitions = groupAssignmentSnapshot.get();
Exception error = null;
if (assignedPartitions.isEmpty())
// Nothing to revoke.
return;
if (streamsRebalanceListenerInvoker != null && streamsRebalanceListenerInvoker.isPresent()) {
SortedSet<TopicPartition> droppedPartitions = new TreeSet<>(TOPIC_PARTITION_COMPARATOR);
droppedPartitions.addAll(assignedPartitions);
if (memberEpoch > 0) {
error = streamsRebalanceListenerInvoker.get().invokeAllTasksRevoked();
} else {
error = streamsRebalanceListenerInvoker.get().invokeAllTasksLost();
}
final Exception error;
} else if (rebalanceListenerInvoker != null) {
if (memberEpoch > 0)
error = rebalanceListenerInvoker.invokePartitionsRevoked(droppedPartitions);
else
error = rebalanceListenerInvoker.invokePartitionsLost(droppedPartitions);
Set<TopicPartition> assignedPartitions = groupAssignmentSnapshot.get();
if (assignedPartitions.isEmpty())
// Nothing to revoke.
return;
SortedSet<TopicPartition> droppedPartitions = new TreeSet<>(TOPIC_PARTITION_COMPARATOR);
droppedPartitions.addAll(assignedPartitions);
if (memberEpoch > 0) {
error = rebalanceListenerInvoker.invokePartitionsRevoked(droppedPartitions);
} else {
error = rebalanceListenerInvoker.invokePartitionsLost(droppedPartitions);
}
}
if (error != null)
throw ConsumerUtils.maybeWrapAsKafkaException(error);
@ -1963,8 +1944,12 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
}
public void subscribe(Collection<String> topics, StreamsRebalanceListener streamsRebalanceListener) {
streamsRebalanceListenerInvoker
.orElseThrow(() -> new IllegalStateException("Consumer was not created to be used with Streams rebalance protocol events"))
.setRebalanceListener(streamsRebalanceListener);
subscribeInternal(topics, Optional.empty());
backgroundEventProcessor.setStreamsRebalanceListener(streamsRebalanceListener);
}
@Override

View File

@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.clients.consumer.internals;
import org.apache.kafka.common.errors.InterruptException;
import org.apache.kafka.common.errors.WakeupException;
import org.apache.kafka.common.utils.LogContext;
import org.slf4j.Logger;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
/**
* This class encapsulates the invocation of the callback methods defined in the {@link StreamsRebalanceListener}
* interface. When streams group task assignment changes, these methods are invoked. This class wraps those
* callback calls with some logging and error handling.
*/
public class StreamsRebalanceListenerInvoker {
private final Logger log;
private final StreamsRebalanceData streamsRebalanceData;
private Optional<StreamsRebalanceListener> listener;
StreamsRebalanceListenerInvoker(LogContext logContext, StreamsRebalanceData streamsRebalanceData) {
this.log = logContext.logger(getClass());
this.listener = Optional.empty();
this.streamsRebalanceData = streamsRebalanceData;
}
public void setRebalanceListener(StreamsRebalanceListener streamsRebalanceListener) {
Objects.requireNonNull(streamsRebalanceListener, "StreamsRebalanceListener cannot be null");
this.listener = Optional.of(streamsRebalanceListener);
}
public Exception invokeAllTasksRevoked() {
if (listener.isEmpty()) {
throw new IllegalStateException("StreamsRebalanceListener is not defined");
}
return invokeTasksRevoked(streamsRebalanceData.reconciledAssignment().activeTasks());
}
public Exception invokeTasksAssigned(final StreamsRebalanceData.Assignment assignment) {
if (listener.isEmpty()) {
throw new IllegalStateException("StreamsRebalanceListener is not defined");
}
log.info("Invoking tasks assigned callback for new assignment: {}", assignment);
try {
listener.get().onTasksAssigned(assignment);
} catch (WakeupException | InterruptException e) {
throw e;
} catch (Exception e) {
log.error(
"Streams rebalance listener failed on invocation of onTasksAssigned for tasks {}",
assignment,
e
);
return e;
}
return null;
}
public Exception invokeTasksRevoked(final Set<StreamsRebalanceData.TaskId> tasks) {
if (listener.isEmpty()) {
throw new IllegalStateException("StreamsRebalanceListener is not defined");
}
log.info("Invoking task revoked callback for revoked active tasks {}", tasks);
try {
listener.get().onTasksRevoked(tasks);
} catch (WakeupException | InterruptException e) {
throw e;
} catch (Exception e) {
log.error(
"Streams rebalance listener failed on invocation of onTasksRevoked for tasks {}",
tasks,
e
);
return e;
}
return null;
}
public Exception invokeAllTasksLost() {
if (listener.isEmpty()) {
throw new IllegalStateException("StreamsRebalanceListener is not defined");
}
log.info("Invoking tasks lost callback for all tasks");
try {
listener.get().onAllTasksLost();
} catch (WakeupException | InterruptException e) {
throw e;
} catch (Exception e) {
log.error(
"Streams rebalance listener failed on invocation of onTasksLost.",
e
);
return e;
}
return null;
}
}

View File

@ -2210,6 +2210,73 @@ public class AsyncKafkaConsumerTest {
}).when(applicationEventHandler).add(ArgumentMatchers.isA(CommitEvent.class));
}
@Test
public void testCloseInvokesStreamsRebalanceListenerOnTasksRevokedWhenMemberEpochPositive() {
final String groupId = "streamsGroup";
final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of());
try (final MockedStatic<RequestManagers> requestManagers = mockStatic(RequestManagers.class)) {
consumer = newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), streamsRebalanceData);
StreamsRebalanceListener mockStreamsListener = mock(StreamsRebalanceListener.class);
when(mockStreamsListener.onTasksRevoked(any())).thenReturn(Optional.empty());
consumer.subscribe(singletonList("topic"), mockStreamsListener);
final MemberStateListener groupMetadataUpdateListener = captureGroupMetadataUpdateListener(requestManagers);
final int memberEpoch = 42;
final String memberId = "memberId";
groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch), memberId);
consumer.close(CloseOptions.timeout(Duration.ZERO));
verify(mockStreamsListener).onTasksRevoked(any());
}
}
@Test
public void testCloseInvokesStreamsRebalanceListenerOnAllTasksLostWhenMemberEpochZeroOrNegative() {
final String groupId = "streamsGroup";
final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of());
try (final MockedStatic<RequestManagers> requestManagers = mockStatic(RequestManagers.class)) {
consumer = newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), streamsRebalanceData);
StreamsRebalanceListener mockStreamsListener = mock(StreamsRebalanceListener.class);
when(mockStreamsListener.onAllTasksLost()).thenReturn(Optional.empty());
consumer.subscribe(singletonList("topic"), mockStreamsListener);
final MemberStateListener groupMetadataUpdateListener = captureGroupMetadataUpdateListener(requestManagers);
final int memberEpoch = 0;
final String memberId = "memberId";
groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch), memberId);
consumer.close(CloseOptions.timeout(Duration.ZERO));
verify(mockStreamsListener).onAllTasksLost();
}
}
@Test
public void testCloseWrapsStreamsRebalanceListenerException() {
final String groupId = "streamsGroup";
final StreamsRebalanceData streamsRebalanceData = new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of());
try (final MockedStatic<RequestManagers> requestManagers = mockStatic(RequestManagers.class)) {
consumer = newConsumerWithStreamRebalanceData(requiredConsumerConfigAndGroupId(groupId), streamsRebalanceData);
StreamsRebalanceListener mockStreamsListener = mock(StreamsRebalanceListener.class);
RuntimeException testException = new RuntimeException("Test streams listener exception");
doThrow(testException).when(mockStreamsListener).onTasksRevoked(any());
consumer.subscribe(singletonList("topic"), mockStreamsListener);
final MemberStateListener groupMetadataUpdateListener = captureGroupMetadataUpdateListener(requestManagers);
final int memberEpoch = 1;
final String memberId = "memberId";
groupMetadataUpdateListener.onMemberEpochUpdated(Optional.of(memberEpoch), memberId);
KafkaException thrownException = assertThrows(KafkaException.class,
() -> consumer.close(CloseOptions.timeout(Duration.ZERO)));
assertInstanceOf(RuntimeException.class, thrownException.getCause());
assertTrue(thrownException.getCause().getMessage().contains("Test streams listener exception"));
verify(mockStreamsListener).onTasksRevoked(any());
}
}
private void markReconcileAndAutoCommitCompleteForPollEvent() {
doAnswer(invocation -> {
PollEvent event = invocation.getArgument(0);

View File

@ -0,0 +1,293 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.clients.consumer.internals;
import org.apache.kafka.common.errors.InterruptException;
import org.apache.kafka.common.errors.WakeupException;
import org.apache.kafka.common.utils.LogContext;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import java.util.Optional;
import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.STRICT_STUBS)
public class StreamsRebalanceListenerInvokerTest {
@Mock
private StreamsRebalanceListener mockListener;
@Mock
private StreamsRebalanceData streamsRebalanceData;
private StreamsRebalanceListenerInvoker invoker;
private final LogContext logContext = new LogContext();
@BeforeEach
public void setup() {
invoker = new StreamsRebalanceListenerInvoker(logContext, streamsRebalanceData);
}
@Test
public void testSetRebalanceListenerWithNull() {
NullPointerException exception = assertThrows(NullPointerException.class,
() -> invoker.setRebalanceListener(null));
assertEquals("StreamsRebalanceListener cannot be null", exception.getMessage());
}
@Test
public void testSetRebalanceListenerOverwritesExisting() {
StreamsRebalanceListener firstListener = org.mockito.Mockito.mock(StreamsRebalanceListener.class);
StreamsRebalanceListener secondListener = org.mockito.Mockito.mock(StreamsRebalanceListener.class);
StreamsRebalanceData.Assignment mockAssignment = createMockAssignment();
when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment);
when(secondListener.onTasksRevoked(any())).thenReturn(Optional.empty());
// Set first listener
invoker.setRebalanceListener(firstListener);
// Overwrite with second listener
invoker.setRebalanceListener(secondListener);
// Should use second listener
invoker.invokeAllTasksRevoked();
verify(firstListener, never()).onTasksRevoked(any());
verify(secondListener).onTasksRevoked(eq(mockAssignment.activeTasks()));
}
@Test
public void testInvokeMethodsWithNoListener() {
IllegalStateException exception1 = assertThrows(IllegalStateException.class,
() -> invoker.invokeAllTasksRevoked());
assertEquals("StreamsRebalanceListener is not defined", exception1.getMessage());
IllegalStateException exception2 = assertThrows(IllegalStateException.class,
() -> invoker.invokeTasksAssigned(createMockAssignment()));
assertEquals("StreamsRebalanceListener is not defined", exception2.getMessage());
IllegalStateException exception3 = assertThrows(IllegalStateException.class,
() -> invoker.invokeTasksRevoked(createMockTasks()));
assertEquals("StreamsRebalanceListener is not defined", exception3.getMessage());
IllegalStateException exception4 = assertThrows(IllegalStateException.class,
() -> invoker.invokeAllTasksLost());
assertEquals("StreamsRebalanceListener is not defined", exception4.getMessage());
}
@Test
public void testInvokeAllTasksRevokedWithListener() {
invoker.setRebalanceListener(mockListener);
StreamsRebalanceData.Assignment mockAssignment = createMockAssignment();
when(streamsRebalanceData.reconciledAssignment()).thenReturn(mockAssignment);
when(mockListener.onTasksRevoked(any())).thenReturn(Optional.empty());
Exception result = invoker.invokeAllTasksRevoked();
assertNull(result);
verify(mockListener).onTasksRevoked(eq(mockAssignment.activeTasks()));
}
@Test
public void testInvokeTasksAssignedWithListener() {
invoker.setRebalanceListener(mockListener);
StreamsRebalanceData.Assignment assignment = createMockAssignment();
when(mockListener.onTasksAssigned(assignment)).thenReturn(Optional.empty());
Exception result = invoker.invokeTasksAssigned(assignment);
assertNull(result);
verify(mockListener).onTasksAssigned(eq(assignment));
}
@Test
public void testInvokeTasksAssignedWithWakeupException() {
invoker.setRebalanceListener(mockListener);
StreamsRebalanceData.Assignment assignment = createMockAssignment();
WakeupException wakeupException = new WakeupException();
doThrow(wakeupException).when(mockListener).onTasksAssigned(assignment);
WakeupException thrownException = assertThrows(WakeupException.class,
() -> invoker.invokeTasksAssigned(assignment));
assertEquals(wakeupException, thrownException);
verify(mockListener).onTasksAssigned(eq(assignment));
}
@Test
public void testInvokeTasksAssignedWithInterruptException() {
invoker.setRebalanceListener(mockListener);
StreamsRebalanceData.Assignment assignment = createMockAssignment();
InterruptException interruptException = new InterruptException("Test interrupt");
doThrow(interruptException).when(mockListener).onTasksAssigned(assignment);
InterruptException thrownException = assertThrows(InterruptException.class,
() -> invoker.invokeTasksAssigned(assignment));
assertEquals(interruptException, thrownException);
verify(mockListener).onTasksAssigned(eq(assignment));
}
@Test
public void testInvokeTasksAssignedWithOtherException() {
invoker.setRebalanceListener(mockListener);
StreamsRebalanceData.Assignment assignment = createMockAssignment();
RuntimeException runtimeException = new RuntimeException("Test exception");
doThrow(runtimeException).when(mockListener).onTasksAssigned(assignment);
Exception result = invoker.invokeTasksAssigned(assignment);
assertEquals(runtimeException, result);
verify(mockListener).onTasksAssigned(eq(assignment));
}
@Test
public void testInvokeTasksRevokedWithListener() {
invoker.setRebalanceListener(mockListener);
Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
when(mockListener.onTasksRevoked(tasks)).thenReturn(Optional.empty());
Exception result = invoker.invokeTasksRevoked(tasks);
assertNull(result);
verify(mockListener).onTasksRevoked(eq(tasks));
}
@Test
public void testInvokeTasksRevokedWithWakeupException() {
invoker.setRebalanceListener(mockListener);
Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
WakeupException wakeupException = new WakeupException();
doThrow(wakeupException).when(mockListener).onTasksRevoked(tasks);
WakeupException thrownException = assertThrows(WakeupException.class,
() -> invoker.invokeTasksRevoked(tasks));
assertEquals(wakeupException, thrownException);
verify(mockListener).onTasksRevoked(eq(tasks));
}
@Test
public void testInvokeTasksRevokedWithInterruptException() {
invoker.setRebalanceListener(mockListener);
Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
InterruptException interruptException = new InterruptException("Test interrupt");
doThrow(interruptException).when(mockListener).onTasksRevoked(tasks);
InterruptException thrownException = assertThrows(InterruptException.class,
() -> invoker.invokeTasksRevoked(tasks));
assertEquals(interruptException, thrownException);
verify(mockListener).onTasksRevoked(eq(tasks));
}
@Test
public void testInvokeTasksRevokedWithOtherException() {
invoker.setRebalanceListener(mockListener);
Set<StreamsRebalanceData.TaskId> tasks = createMockTasks();
RuntimeException runtimeException = new RuntimeException("Test exception");
doThrow(runtimeException).when(mockListener).onTasksRevoked(tasks);
Exception result = invoker.invokeTasksRevoked(tasks);
assertEquals(runtimeException, result);
verify(mockListener).onTasksRevoked(eq(tasks));
}
@Test
public void testInvokeAllTasksLostWithListener() {
invoker.setRebalanceListener(mockListener);
when(mockListener.onAllTasksLost()).thenReturn(Optional.empty());
Exception result = invoker.invokeAllTasksLost();
assertNull(result);
verify(mockListener).onAllTasksLost();
}
@Test
public void testInvokeAllTasksLostWithWakeupException() {
invoker.setRebalanceListener(mockListener);
WakeupException wakeupException = new WakeupException();
doThrow(wakeupException).when(mockListener).onAllTasksLost();
WakeupException thrownException = assertThrows(WakeupException.class,
() -> invoker.invokeAllTasksLost());
assertEquals(wakeupException, thrownException);
verify(mockListener).onAllTasksLost();
}
@Test
public void testInvokeAllTasksLostWithInterruptException() {
invoker.setRebalanceListener(mockListener);
InterruptException interruptException = new InterruptException("Test interrupt");
doThrow(interruptException).when(mockListener).onAllTasksLost();
InterruptException thrownException = assertThrows(InterruptException.class,
() -> invoker.invokeAllTasksLost());
assertEquals(interruptException, thrownException);
verify(mockListener).onAllTasksLost();
}
@Test
public void testInvokeAllTasksLostWithOtherException() {
invoker.setRebalanceListener(mockListener);
RuntimeException runtimeException = new RuntimeException("Test exception");
doThrow(runtimeException).when(mockListener).onAllTasksLost();
Exception result = invoker.invokeAllTasksLost();
assertEquals(runtimeException, result);
verify(mockListener).onAllTasksLost();
}
private StreamsRebalanceData.Assignment createMockAssignment() {
Set<StreamsRebalanceData.TaskId> activeTasks = createMockTasks();
Set<StreamsRebalanceData.TaskId> standbyTasks = Set.of();
Set<StreamsRebalanceData.TaskId> warmupTasks = Set.of();
return new StreamsRebalanceData.Assignment(activeTasks, standbyTasks, warmupTasks);
}
private Set<StreamsRebalanceData.TaskId> createMockTasks() {
return Set.of(
new StreamsRebalanceData.TaskId("subtopology1", 0),
new StreamsRebalanceData.TaskId("subtopology1", 1)
);
}
}

View File

@ -89,6 +89,7 @@ public class DefaultStreamsRebalanceListener implements StreamsRebalanceListener
taskManager.handleAssignment(activeTasksWithPartitions, standbyTasksWithPartitions);
streamThread.setState(StreamThread.State.PARTITIONS_ASSIGNED);
taskManager.handleRebalanceComplete();
streamsRebalanceData.setReconciledAssignment(assignment);
} catch (final Exception exception) {
return Optional.of(exception);
}
@ -99,6 +100,7 @@ public class DefaultStreamsRebalanceListener implements StreamsRebalanceListener
public Optional<Exception> onAllTasksLost() {
try {
taskManager.handleLostAll();
streamsRebalanceData.setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
} catch (final Exception exception) {
return Optional.of(exception);
}

View File

@ -118,49 +118,46 @@ public class DefaultStreamsRebalanceListenerTest {
@Test
void testOnTasksAssigned() {
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(
UUID.randomUUID(),
Optional.empty(),
Map.of(
"1",
new StreamsRebalanceData.Subtopology(
Set.of("source1"),
Set.of(),
Map.of("repartition1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
Map.of(),
Set.of()
),
"2",
new StreamsRebalanceData.Subtopology(
Set.of("source2"),
Set.of(),
Map.of("repartition2", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
Map.of(),
Set.of()
),
"3",
new StreamsRebalanceData.Subtopology(
Set.of("source3"),
Set.of(),
Map.of("repartition3", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
Map.of(),
Set.of()
)
final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
when(streamsRebalanceData.subtopologies()).thenReturn(Map.of(
"1",
new StreamsRebalanceData.Subtopology(
Set.of("source1"),
Set.of(),
Map.of("repartition1", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
Map.of(),
Set.of()
),
Map.of()
));
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(
new StreamsRebalanceData.Assignment(
Set.of(new StreamsRebalanceData.TaskId("1", 0)),
Set.of(new StreamsRebalanceData.TaskId("2", 0)),
Set.of(new StreamsRebalanceData.TaskId("3", 0))
"2",
new StreamsRebalanceData.Subtopology(
Set.of("source2"),
Set.of(),
Map.of("repartition2", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
Map.of(),
Set.of()
),
"3",
new StreamsRebalanceData.Subtopology(
Set.of("source3"),
Set.of(),
Map.of("repartition3", new StreamsRebalanceData.TopicInfo(Optional.of(1), Optional.of((short) 1), Map.of())),
Map.of(),
Set.of()
)
));
createRebalanceListenerWithRebalanceData(streamsRebalanceData);
final StreamsRebalanceData.Assignment assignment = new StreamsRebalanceData.Assignment(
Set.of(new StreamsRebalanceData.TaskId("1", 0)),
Set.of(new StreamsRebalanceData.TaskId("2", 0)),
Set.of(new StreamsRebalanceData.TaskId("3", 0))
);
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(assignment);
assertTrue(result.isEmpty());
final InOrder inOrder = inOrder(taskManager, streamThread);
final InOrder inOrder = inOrder(taskManager, streamThread, streamsRebalanceData);
inOrder.verify(taskManager).handleAssignment(
Map.of(new TaskId(1, 0), Set.of(new TopicPartition("source1", 0), new TopicPartition("repartition1", 0))),
Map.of(
@ -170,6 +167,7 @@ public class DefaultStreamsRebalanceListenerTest {
);
inOrder.verify(streamThread).setState(StreamThread.State.PARTITIONS_ASSIGNED);
inOrder.verify(taskManager).handleRebalanceComplete();
inOrder.verify(streamsRebalanceData).setReconciledAssignment(assignment);
}
@Test
@ -177,21 +175,32 @@ public class DefaultStreamsRebalanceListenerTest {
final Exception exception = new RuntimeException("sample exception");
doThrow(exception).when(taskManager).handleAssignment(any(), any());
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of()));
assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty());
final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
createRebalanceListenerWithRebalanceData(streamsRebalanceData);
final Optional<Exception> result = defaultStreamsRebalanceListener.onTasksAssigned(
new StreamsRebalanceData.Assignment(Set.of(), Set.of(), Set.of())
);
assertTrue(result.isPresent());
assertEquals(exception, result.get());
verify(taskManager).handleLostAll();
verify(taskManager).handleAssignment(any(), any());
verify(streamThread, never()).setState(StreamThread.State.PARTITIONS_ASSIGNED);
verify(taskManager, never()).handleRebalanceComplete();
verify(streamsRebalanceData, never()).setReconciledAssignment(any());
}
@Test
void testOnAllTasksLost() {
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
createRebalanceListenerWithRebalanceData(streamsRebalanceData);
assertTrue(defaultStreamsRebalanceListener.onAllTasksLost().isEmpty());
verify(taskManager).handleLostAll();
final InOrder inOrder = inOrder(taskManager, streamsRebalanceData);
inOrder.verify(taskManager).handleLostAll();
inOrder.verify(streamsRebalanceData).setReconciledAssignment(StreamsRebalanceData.Assignment.EMPTY);
}
@Test
@ -199,10 +208,13 @@ public class DefaultStreamsRebalanceListenerTest {
final Exception exception = new RuntimeException("sample exception");
doThrow(exception).when(taskManager).handleLostAll();
createRebalanceListenerWithRebalanceData(new StreamsRebalanceData(UUID.randomUUID(), Optional.empty(), Map.of(), Map.of()));
final StreamsRebalanceData streamsRebalanceData = mock(StreamsRebalanceData.class);
when(streamsRebalanceData.subtopologies()).thenReturn(Map.of());
createRebalanceListenerWithRebalanceData(streamsRebalanceData);
final Optional<Exception> result = defaultStreamsRebalanceListener.onAllTasksLost();
assertTrue(result.isPresent());
assertEquals(exception, result.get());
verify(taskManager).handleLostAll();
verify(streamsRebalanceData, never()).setReconciledAssignment(any());
}
}