KAFKA-15555: Ensure wakeups are handled correctly in poll() (#14746)

We need to be careful when aborting a long poll with wakeup() since the
consumer might never return records if the poll is interrupted after the
consumer position has been updated but the records have not been returned
to the caller of poll().

This PR avoids wake-ups during this critical period.

Reviewers: Philip Nee <pnee@confluent.io>, Kirk True <ktrue@confluent.io>, Lucas Brutschy <lbrutschy@confluent.io>
This commit is contained in:
Bruno Cadonna 2023-11-23 10:53:17 +01:00 committed by GitHub
parent 55017a4f68
commit 75572f904b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 271 additions and 11 deletions

View File

@ -1423,6 +1423,7 @@ project(':clients') {
testImplementation libs.junitJupiter
testImplementation libs.log4j
testImplementation libs.mockitoCore
testImplementation libs.mockitoJunitJupiter // supports MockitoExtension
testRuntimeOnly libs.slf4jlog4j
testRuntimeOnly libs.jacksonDatabind

View File

@ -407,6 +407,21 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
*
* @param timeout timeout of the poll loop
* @return ConsumerRecords. It can be empty if the timeout expires.
*
* @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this
* function is called
* @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while
* this function is called
* @throws org.apache.kafka.common.errors.RecordTooLargeException if the fetched record is larger than the maximum
* allowable size
* @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors
* @throws java.lang.IllegalStateException if the consumer is not subscribed to any topics or manually assigned any
* partitions to consume from or an unexpected error occurred
* @throws org.apache.kafka.clients.consumer.OffsetOutOfRangeException if the fetch position of the consumer is
* out of range and no offset reset policy is configured.
* @throws org.apache.kafka.common.errors.TopicAuthorizationException if the consumer is not authorized to read
* from a partition
* @throws org.apache.kafka.common.errors.SerializationException if the fetched records cannot be deserialized
*/
@Override
public ConsumerRecords<K, V> poll(final Duration timeout) {
@ -414,6 +429,7 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
acquireAndEnsureOpen();
try {
wakeupTrigger.setFetchAction(fetchBuffer);
kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());
if (subscriptions.hasNoSubscriptionOrUserAssignment()) {
@ -421,9 +437,14 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
}
do {
// We must not allow wake-ups between polling for fetches and returning the records.
// If the polled fetches are not empty the consumed position has already been updated in the polling
// of the fetches. A wakeup between returned fetches and returning records would lead to never
// returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.
wakeupTrigger.maybeTriggerWakeup();
updateAssignmentMetadataIfNeeded(timer);
final Fetch<K, V> fetch = pollForFetches(timer);
if (!fetch.isEmpty()) {
if (fetch.records().isEmpty()) {
log.trace("Returning empty records from `poll()` "
@ -438,6 +459,7 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
return ConsumerRecords.empty();
} finally {
kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());
wakeupTrigger.clearTask();
release();
}
}
@ -636,7 +658,7 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
try {
return applicationEventHandler.addAndGet(event, time.timer(timeout));
} finally {
wakeupTrigger.clearActiveTask();
wakeupTrigger.clearTask();
}
} finally {
release();
@ -922,7 +944,7 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
offsets.forEach(this::updateLastSeenEpochIfNewer);
ConsumerUtils.getResult(commitFuture, time.timer(timeout));
} finally {
wakeupTrigger.clearActiveTask();
wakeupTrigger.clearTask();
kafkaConsumerMetrics.recordCommitSync(time.nanoseconds() - commitStart);
release();
}

View File

@ -29,6 +29,7 @@ import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@ -52,6 +53,8 @@ public class FetchBuffer implements AutoCloseable {
private final Condition notEmptyCondition;
private final IdempotentCloser idempotentCloser = new IdempotentCloser();
private final AtomicBoolean wokenup = new AtomicBoolean(false);
private CompletedFetch nextInLineFetch;
public FetchBuffer(final LogContext logContext) {
@ -166,7 +169,7 @@ public class FetchBuffer implements AutoCloseable {
try {
lock.lock();
while (isEmpty()) {
while (isEmpty() && !wokenup.compareAndSet(true, false)) {
// Update the timer before we head into the loop in case it took a while to get the lock.
timer.update();
@ -185,6 +188,16 @@ public class FetchBuffer implements AutoCloseable {
}
}
/**
 * Requests that a thread waiting on this buffer stops waiting, even though no data has arrived.
 *
 * <p>The {@code wokenup} flag is raised first and then every thread waiting on
 * {@code notEmptyCondition} is signalled. The waiting loop consumes the flag with
 * {@code compareAndSet(true, false)}, so a single call interrupts at most one wait.
 */
void wakeup() {
    // Raise the flag before signalling so that a waiter re-checking its loop condition sees it.
    wokenup.set(true);
    // Acquire the lock outside the try block: if lock() itself fails, the finally clause must not
    // call unlock() on a lock this thread never obtained.
    lock.lock();
    try {
        notEmptyCondition.signalAll();
    } finally {
        lock.unlock();
    }
}
/**
* Updates the buffer to retain only the fetch data that corresponds to the given partitions. Any previously
* {@link CompletedFetch fetched data} is removed if its partition is not in the given set of partitions.

View File

@ -21,6 +21,7 @@ import org.apache.kafka.common.errors.WakeupException;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
/**
@ -44,6 +45,10 @@ public class WakeupTrigger {
ActiveFuture active = (ActiveFuture) task;
active.future().completeExceptionally(new WakeupException());
return null;
} else if (task instanceof FetchAction) {
FetchAction fetchAction = (FetchAction) task;
fetchAction.fetchBuffer().wakeup();
return new WakeupFuture();
} else {
return task;
}
@ -75,17 +80,51 @@ public class WakeupTrigger {
return currentTask;
}
public void clearActiveTask() {
/**
 * Registers a {@link FetchAction} wrapping the given buffer as the pending task.
 *
 * <p>If a {@link WakeupFuture} is pending instead, it is cleared and a {@link WakeupException} is
 * thrown. Any other pending task is treated as a programming error.
 *
 * @param fetchBuffer the buffer stored in the new {@link FetchAction}
 * @throws WakeupException if a {@link WakeupFuture} was pending when this method was called
 * @throws IllegalStateException if some other task is still registered
 */
public void setFetchAction(final FetchBuffer fetchBuffer) {
    final AtomicBoolean wakeupPending = new AtomicBoolean(false);
    pendingTask.getAndUpdate(current -> {
        if (current instanceof WakeupFuture) {
            // Consume the pending wake-up; the exception is raised once the update has completed.
            wakeupPending.set(true);
            return null;
        }
        if (current != null) {
            // last active state is still active
            throw new IllegalStateException("Last active task is still active");
        }
        return new FetchAction(fetchBuffer);
    });
    if (wakeupPending.get()) {
        throw new WakeupException();
    }
}
public void clearTask() {
pendingTask.getAndUpdate(task -> {
if (task == null) {
return null;
} else if (task instanceof ActiveFuture) {
} else if (task instanceof ActiveFuture || task instanceof FetchAction) {
return null;
}
return task;
});
}
/**
 * Throws a {@link WakeupException} if a {@link WakeupFuture} is currently pending, clearing it in
 * the process. Any other pending task (or no task at all) is left untouched.
 *
 * @throws WakeupException if a {@link WakeupFuture} was pending
 */
public void maybeTriggerWakeup() {
    final AtomicBoolean wakeupPending = new AtomicBoolean(false);
    pendingTask.getAndUpdate(current -> {
        if (!(current instanceof WakeupFuture)) {
            // Nothing to trigger: keep whatever is registered (possibly null).
            return current;
        }
        wakeupPending.set(true);
        return null;
    });
    if (wakeupPending.get()) {
        throw new WakeupException();
    }
}
/**
 * Returns the task currently held in {@code pendingTask}, or {@code null} if none is registered.
 */
Wakeupable getPendingTask() {
    final Wakeupable current = pendingTask.get();
    return current;
}
@ -105,4 +144,17 @@ public class WakeupTrigger {
}
static class WakeupFuture implements Wakeupable { }
/**
 * A {@link Wakeupable} that carries the {@link FetchBuffer} handed to its constructor.
 */
static class FetchAction implements Wakeupable {

    private final FetchBuffer fetchBuffer;

    /**
     * @param fetchBuffer the buffer later exposed through {@link #fetchBuffer()}
     */
    public FetchAction(FetchBuffer fetchBuffer) {
        this.fetchBuffer = fetchBuffer;
    }

    /**
     * @return the buffer supplied at construction time
     */
    public FetchBuffer fetchBuffer() {
        return this.fetchBuffer;
    }
}
}

View File

@ -16,6 +16,7 @@
*/
package org.apache.kafka.clients.consumer.internals;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
import org.apache.kafka.clients.consumer.OffsetCommitCallback;
@ -51,6 +52,7 @@ import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.MockedConstruction;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import java.time.Duration;
@ -69,8 +71,11 @@ import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static java.util.Arrays.asList;
import static java.util.Collections.singleton;
import static java.util.Collections.singletonList;
import static org.apache.kafka.common.utils.Utils.mkEntry;
import static org.apache.kafka.common.utils.Utils.mkMap;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@ -80,6 +85,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mockConstruction;
@ -90,6 +96,7 @@ import static org.mockito.Mockito.when;
public class AsyncKafkaConsumerTest {
private AsyncKafkaConsumer<?, ?> consumer;
private FetchCollector<?, ?> fetchCollector;
private ConsumerTestBuilder.AsyncKafkaConsumerTestBuilder testBuilder;
private ApplicationEventHandler applicationEventHandler;
@ -103,6 +110,7 @@ public class AsyncKafkaConsumerTest {
testBuilder = new ConsumerTestBuilder.AsyncKafkaConsumerTestBuilder(groupInfo);
applicationEventHandler = testBuilder.applicationEventHandler;
consumer = testBuilder.consumer;
fetchCollector = testBuilder.fetchCollector;
}
@AfterEach
@ -216,6 +224,82 @@ public class AsyncKafkaConsumerTest {
}
}
/**
 * A wake-up requested before {@code poll()} must make that poll throw a {@link WakeupException};
 * the following poll must complete normally.
 */
@Test
public void testWakeupBeforeCallingPoll() {
    final TopicPartition topicPartition = new TopicPartition("foo", 3);
    // The collector never yields data in this test.
    doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));
    final Map<TopicPartition, OffsetAndMetadata> committedOffsets =
        mkMap(mkEntry(topicPartition, new OffsetAndMetadata(1)));
    doReturn(committedOffsets)
        .when(applicationEventHandler)
        .addAndGet(any(OffsetFetchApplicationEvent.class), any(Timer.class));
    consumer.assign(singleton(topicPartition));

    consumer.wakeup();

    assertThrows(WakeupException.class, () -> consumer.poll(Duration.ZERO));
    assertDoesNotThrow(() -> consumer.poll(Duration.ZERO));
}
/**
 * A wake-up issued while an empty fetch is being collected must surface as a
 * {@link WakeupException} from the same {@code poll()}; the following poll must complete normally.
 */
@Test
public void testWakeupAfterEmptyFetch() {
    final TopicPartition topicPartition = new TopicPartition("foo", 3);
    // Trigger the wake-up from inside the collector, i.e. while poll() is in progress.
    doAnswer(ignored -> {
        consumer.wakeup();
        return Fetch.empty();
    }).when(fetchCollector).collectFetch(any(FetchBuffer.class));
    final Map<TopicPartition, OffsetAndMetadata> committedOffsets =
        mkMap(mkEntry(topicPartition, new OffsetAndMetadata(1)));
    doReturn(committedOffsets)
        .when(applicationEventHandler)
        .addAndGet(any(OffsetFetchApplicationEvent.class), any(Timer.class));
    consumer.assign(singleton(topicPartition));

    assertThrows(WakeupException.class, () -> consumer.poll(Duration.ofMinutes(1)));
    assertDoesNotThrow(() -> consumer.poll(Duration.ZERO));
}
/**
 * A wake-up issued while a non-empty fetch is being collected must not prevent the records from
 * being returned; the wake-up is delivered by the next {@code poll()} instead.
 */
@Test
public void testWakeupAfterNonEmptyFetch() {
    final String topic = "foo";
    final int partitionId = 3;
    final TopicPartition topicPartition = new TopicPartition(topic, partitionId);
    final List<ConsumerRecord<String, String>> fetchedRecords = asList(
        new ConsumerRecord<>(topic, partitionId, 2, "key1", "value1"),
        new ConsumerRecord<>(topic, partitionId, 3, "key2", "value2")
    );
    // Trigger the wake-up from inside the collector, right as the non-empty fetch is produced.
    doAnswer(ignored -> {
        consumer.wakeup();
        return Fetch.forPartition(topicPartition, fetchedRecords, true);
    }).when(fetchCollector).collectFetch(any(FetchBuffer.class));
    final Map<TopicPartition, OffsetAndMetadata> committedOffsets =
        mkMap(mkEntry(topicPartition, new OffsetAndMetadata(1)));
    doReturn(committedOffsets)
        .when(applicationEventHandler)
        .addAndGet(any(OffsetFetchApplicationEvent.class), any(Timer.class));
    consumer.assign(singleton(topicPartition));

    // since wakeup() is called when the non-empty fetch is returned the wakeup should be ignored
    assertDoesNotThrow(() -> consumer.poll(Duration.ofMinutes(1)));
    // the previously ignored wake-up should not be ignored in the next call
    assertThrows(WakeupException.class, () -> consumer.poll(Duration.ZERO));
}
/**
 * Two consecutive polls that both return records must succeed; the second poll must not throw.
 */
@Test
public void testClearWakeupTriggerAfterPoll() {
    final String topic = "foo";
    final int partitionId = 3;
    final TopicPartition topicPartition = new TopicPartition(topic, partitionId);
    final List<ConsumerRecord<String, String>> fetchedRecords = asList(
        new ConsumerRecord<>(topic, partitionId, 2, "key1", "value1"),
        new ConsumerRecord<>(topic, partitionId, 3, "key2", "value2")
    );
    doReturn(Fetch.forPartition(topicPartition, fetchedRecords, true))
        .when(fetchCollector)
        .collectFetch(any(FetchBuffer.class));
    final Map<TopicPartition, OffsetAndMetadata> committedOffsets =
        mkMap(mkEntry(topicPartition, new OffsetAndMetadata(1)));
    doReturn(committedOffsets)
        .when(applicationEventHandler)
        .addAndGet(any(OffsetFetchApplicationEvent.class), any(Timer.class));
    consumer.assign(singleton(topicPartition));

    consumer.poll(Duration.ZERO);

    assertDoesNotThrow(() -> consumer.poll(Duration.ZERO));
}
@Test
public void testEnsureCallbackExecutedByApplicationThread() {
final String currentThread = Thread.currentThread().getName();

View File

@ -312,6 +312,8 @@ public class ConsumerTestBuilder implements Closeable {
final AsyncKafkaConsumer<String, String> consumer;
final FetchCollector<String, String> fetchCollector;
public AsyncKafkaConsumerTestBuilder(Optional<GroupInformation> groupInfo) {
super(groupInfo);
String clientId = config.getString(CommonClientConfigs.CLIENT_ID_CONFIG);
@ -320,13 +322,13 @@ public class ConsumerTestBuilder implements Closeable {
config.originals(Collections.singletonMap(ConsumerConfig.CLIENT_ID_CONFIG, clientId))
);
Deserializers<String, String> deserializers = new Deserializers<>(new StringDeserializer(), new StringDeserializer());
FetchCollector<String, String> fetchCollector = new FetchCollector<>(logContext,
this.fetchCollector = spy(new FetchCollector<>(logContext,
metadata,
subscriptions,
fetchConfig,
deserializers,
metricsManager,
time);
time));
this.consumer = spy(new AsyncKafkaConsumer<>(
logContext,
clientId,

View File

@ -26,9 +26,11 @@ import org.apache.kafka.common.utils.BufferSupplier;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Timer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.time.Duration;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Properties;
@ -171,6 +173,20 @@ public class FetchBufferTest {
}
}
/**
 * A thread blocked in {@code awaitNotEmpty} must terminate after {@code wakeup()} is called,
 * well before its one-minute timer would expire.
 */
@Test
public void testWakeup() throws Exception {
    try (FetchBuffer buffer = new FetchBuffer(logContext)) {
        // The timer is created on the waiting thread, right before it starts to wait.
        final Runnable blockUntilWoken = () -> buffer.awaitNotEmpty(time.timer(Duration.ofMinutes(1)));
        final Thread waiter = new Thread(blockUntilWoken);
        waiter.start();

        buffer.wakeup();

        final long joinTimeoutMs = Duration.ofSeconds(30).toMillis();
        waiter.join(joinTimeoutMs);
        assertFalse(waiter.isAlive());
    }
}
private CompletedFetch completedFetch(TopicPartition tp) {
FetchResponseData.PartitionData partitionData = new FetchResponseData.PartitionData();
FetchMetricsAggregator metricsAggregator = new FetchMetricsAggregator(metricsManager, allPartitions);

View File

@ -19,17 +19,26 @@ package org.apache.kafka.clients.consumer.internals;
import org.apache.kafka.common.errors.WakeupException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@MockitoSettings(strictness = Strictness.STRICT_STUBS)
public class WakeupTriggerTest {
private static long defaultTimeoutMs = 1000;
private final static long DEFAULT_TIMEOUT_MS = 1000;
private WakeupTrigger wakeupTrigger;
@BeforeEach
@ -59,14 +68,75 @@ public class WakeupTriggerTest {
public void testUnsetActiveFuture() {
CompletableFuture<Void> task = new CompletableFuture<>();
wakeupTrigger.setActiveTask(task);
wakeupTrigger.clearActiveTask();
wakeupTrigger.clearTask();
assertNull(wakeupTrigger.getPendingTask());
}
/**
 * After {@code setFetchAction} the pending task must be a {@code FetchAction} holding the buffer
 * that was passed in.
 */
@Test
public void testSettingFetchAction() {
    try (final FetchBuffer buffer = mock(FetchBuffer.class)) {
        wakeupTrigger.setFetchAction(buffer);

        final WakeupTrigger.Wakeupable pending = wakeupTrigger.getPendingTask();
        // assertInstanceOf returns the value already cast to the expected type.
        final WakeupTrigger.FetchAction action = assertInstanceOf(WakeupTrigger.FetchAction.class, pending);
        assertEquals(buffer, action.fetchBuffer());
    }
}
/**
 * {@code clearTask} must remove a previously registered fetch action.
 */
@Test
public void testUnsetFetchAction() {
    try (final FetchBuffer buffer = mock(FetchBuffer.class)) {
        wakeupTrigger.setFetchAction(buffer);

        wakeupTrigger.clearTask();

        final WakeupTrigger.Wakeupable pending = wakeupTrigger.getPendingTask();
        assertNull(pending);
    }
}
/**
 * Waking up while a fetch action is registered must wake the fetch buffer and leave a
 * {@code WakeupFuture} as the pending task.
 */
@Test
public void testWakeupFromFetchAction() {
    try (final FetchBuffer buffer = mock(FetchBuffer.class)) {
        wakeupTrigger.setFetchAction(buffer);

        wakeupTrigger.wakeup();

        verify(buffer).wakeup();
        assertInstanceOf(WakeupTrigger.WakeupFuture.class, wakeupTrigger.getPendingTask());
    }
}
/**
 * With a wake-up pending and no task registered, {@code maybeTriggerWakeup} must throw.
 */
@Test
public void testManualTriggerWhenWakeupCalled() {
    wakeupTrigger.wakeup();

    assertThrows(WakeupException.class, wakeupTrigger::maybeTriggerWakeup);
}
/**
 * Without any prior wake-up, {@code maybeTriggerWakeup} must be a no-op.
 */
@Test
public void testManualTriggerWhenWakeupNotCalled() {
    assertDoesNotThrow(() -> {
        wakeupTrigger.maybeTriggerWakeup();
    });
}
/**
 * With an active task registered, {@code maybeTriggerWakeup} must not throw.
 */
@Test
public void testManualTriggerWhenWakeupCalledAndActiveTaskSet() {
    // NOTE(review): despite the method name, wakeup() is never invoked in this test - confirm intent.
    final CompletableFuture<Void> activeTask = new CompletableFuture<>();
    wakeupTrigger.setActiveTask(activeTask);

    assertDoesNotThrow(() -> {
        wakeupTrigger.maybeTriggerWakeup();
    });
}
/**
 * With a fetch action registered, {@code maybeTriggerWakeup} must not throw.
 */
@Test
public void testManualTriggerWhenWakeupCalledAndFetchActionSet() {
    // NOTE(review): despite the method name, wakeup() is never invoked in this test - confirm intent.
    try (final FetchBuffer buffer = mock(FetchBuffer.class)) {
        wakeupTrigger.setFetchAction(buffer);

        assertDoesNotThrow(() -> {
            wakeupTrigger.maybeTriggerWakeup();
        });
    }
}
private void assertWakeupExceptionIsThrown(final CompletableFuture<?> future) {
assertTrue(future.isCompletedExceptionally());
try {
future.get(defaultTimeoutMs, TimeUnit.MILLISECONDS);
future.get(DEFAULT_TIMEOUT_MS, TimeUnit.MILLISECONDS);
} catch (ExecutionException e) {
assertTrue(e.getCause() instanceof WakeupException);
return;