Refactor CompositePollEvent to use Blocker for state management

Replaces CompletableFuture-based state handling in CompositePollEvent with a new Blocker class for improved synchronization and exception handling. Updates AsyncKafkaConsumer, WakeupTrigger, ApplicationEventProcessor, and related tests to use Blocker, simplifying event completion and error propagation.
Kirk True 2025-09-17 20:49:53 -07:00
parent 40f6754810
commit 3e0b920399
6 changed files with 260 additions and 27 deletions
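
In rough terms, the hand-off introduced by this commit works as in the sketch below: the application thread blocks in Blocker.await() against a timer while a background thread publishes either a value via complete() or an error via completeExceptionally(). This is an illustrative, minimal sketch only; the executor, class name, and timeout are placeholders and not part of the patch.

    import org.apache.kafka.clients.consumer.internals.Blocker;
    import org.apache.kafka.clients.consumer.internals.events.CompositePollEvent;
    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.common.utils.Timer;

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    public class BlockerHandOffSketch {

        public static void main(String[] args) {
            Blocker<CompositePollEvent.State> blocker = new Blocker<>();
            ExecutorService background = Executors.newSingleThreadExecutor();

            // Stand-in for the consumer network thread: publish a terminal state to the blocker.
            background.submit(() -> blocker.complete(CompositePollEvent.State.COMPLETE));

            // Stand-in for the application thread: block until a value, an error, or the timeout arrives.
            Timer timer = Time.SYSTEM.timer(30_000);
            CompositePollEvent.State state = blocker.await(timer);
            System.out.println("Blocker completed with state: " + state);

            background.shutdown();
        }
    }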

AsyncKafkaConsumer.java

@@ -882,7 +882,7 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
             // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.
             wakeupTrigger.maybeTriggerWakeup();
             prepareFetch(timer);
-            final Fetch<K, V> fetch = collectFetch();
+            final Fetch<K, V> fetch = pollForFetches(timer);
             if (!fetch.isEmpty()) {
                 // before returning the fetched records, we can send off the next round of fetches
                 // and avoid block waiting for their responses to enable pipelining while the user
@@ -914,31 +914,42 @@ public class AsyncKafkaConsumer<K, V> implements ConsumerDelegate<K, V> {
         long deadlineMs = calculateDeadlineMs(timer);
         ApplicationEvent.Type nextEventType = ApplicationEvent.Type.POLL;

+        log.debug("prepareFetch - timer: {}", timer.remainingMs());
+        Timer blockerTimer = time.timer(defaultApiTimeoutMs.toMillis());
+
         while (true) {
             CompositePollEvent event = new CompositePollEvent(deadlineMs, pollTimeMs, nextEventType);
+            applicationEventHandler.add(event);
             CompositePollEvent.State state;

+            wakeupTrigger.setFetchAction(event);
+
             try {
-                state = applicationEventHandler.addAndGet(event);
+                state = event.blocker().await(blockerTimer);
             } catch (TimeoutException e) {
                 // Timeouts are OK, there's just no data to return on this pass.
-                break;
+                return;
+            } catch (InterruptException e) {
+                log.trace("Interrupt during composite poll", e);
+                throw e;
+            } finally {
+                timer.update(blockerTimer.currentTimeMs());
+                wakeupTrigger.clearTask();
             }

-            if (state == CompositePollEvent.State.OFFSET_COMMIT_CALLBACKS_REQUIRED) {
+            if (state == null || state == CompositePollEvent.State.COMPLETE) {
+                break;
+            } else if (state == CompositePollEvent.State.OFFSET_COMMIT_CALLBACKS_REQUIRED) {
                 offsetCommitCallbackInvoker.executeCallbacks();
                 nextEventType = ApplicationEvent.Type.UPDATE_SUBSCRIPTION_METADATA;
             } else if (state == CompositePollEvent.State.BACKGROUND_EVENT_PROCESSING_REQUIRED) {
                 processBackgroundEvents();
                 nextEventType = ApplicationEvent.Type.CHECK_AND_UPDATE_POSITIONS;
-            } else if (state == CompositePollEvent.State.COMPLETE) {
-                break;
             } else {
                 throw new IllegalStateException("Unexpected state: " + state);
             }
         }

         timer.update();
     }

     /**

Blocker.java (new file)

@@ -0,0 +1,157 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.kafka.clients.consumer.internals;

import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.errors.InterruptException;
import org.apache.kafka.common.errors.TimeoutException;
import org.apache.kafka.common.utils.Timer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

public class Blocker<T> {

    private final Logger log;
    private final Lock lock;
    private final Condition condition;
    private final AtomicBoolean wokenup = new AtomicBoolean(false);
    private T value;
    private KafkaException error;

    public Blocker() {
        this.log = LoggerFactory.getLogger(getClass());
        this.lock = new ReentrantLock();
        this.condition = lock.newCondition();
    }

    public boolean complete(T value) {
        Objects.requireNonNull(value);

        try {
            lock.lock();

            if (isSet())
                return false;

            log.debug("Setting value to {}", value);
            this.value = value;
            wokenup.set(true);
            condition.signalAll();
            return true;
        } finally {
            lock.unlock();
        }
    }

    public boolean completeExceptionally(KafkaException error) {
        Objects.requireNonNull(error);

        try {
            lock.lock();

            if (isSet())
                return false;

            log.debug("Setting exception to {}", String.valueOf(error));
            this.error = error;
            wokenup.set(true);
            condition.signalAll();
            return true;
        } finally {
            lock.unlock();
        }
    }

    private boolean isSet() {
        return error != null || value != null;
    }
    /**
     * Allows the caller to await completion of this blocker. The method will block, returning only
     * under one of the following conditions:
     *
     * <ol>
     *     <li>The value or error was already set</li>
     *     <li>The value or error was set during the wait</li>
     *     <li>The remaining time on the {@link Timer timer} elapsed</li>
     *     <li>The thread was interrupted</li>
     * </ol>
     *
     * @param timer Timer that provides time to wait
     * @return The value passed to {@link #complete(Object)}, if it was set before the timer expired
     */
    public T await(Timer timer) {
        try {
            lock.lock();

            log.debug("At start of method, error: {}, value: {}", error, value);

            if (error != null)
                throw error;
            else if (value != null)
                return value;

            while (!wokenup.compareAndSet(true, false)) {
                // Update the timer before we head into the loop in case it took a while to get the lock.
                timer.update();

                if (timer.isExpired()) {
                    // If the thread was interrupted before we start waiting, it still counts as
                    // interrupted from the point of view of the KafkaConsumer.poll(Duration) contract.
                    // We only need to check this when we are not going to wait because waiting
                    // already checks whether the thread is interrupted.
                    if (Thread.interrupted())
                        throw error = new InterruptException("Interrupted waiting for completion");

                    break;
                }

                if (!condition.await(timer.remainingMs(), TimeUnit.MILLISECONDS)) {
                    break;
                }
            }

            log.debug("At end of method, error: {}, value: {}", error, value);

            if (error != null)
                throw error;
            else if (value != null)
                return value;

            throw error = new TimeoutException("Timed out waiting for completion");
        } catch (InterruptedException e) {
            throw new InterruptException("Interrupted waiting for completion", e);
        } finally {
            lock.unlock();
            timer.update();
        }
    }

    @Override
    public String toString() {
        return "Blocker{" +
                "value=" + value +
                ", error=" + error +
                '}';
    }
}
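
One property of the class above worth calling out: if neither complete() nor completeExceptionally() is invoked before the caller's timer runs out, await() throws a TimeoutException, which the new prepareFetch() loop treats as "no data on this pass" rather than a failure. A minimal sketch of that path (the 10 ms timeout is an arbitrary placeholder):

    import org.apache.kafka.clients.consumer.internals.Blocker;
    import org.apache.kafka.common.errors.TimeoutException;
    import org.apache.kafka.common.utils.Time;

    public class BlockerTimeoutSketch {

        public static void main(String[] args) {
            Blocker<String> blocker = new Blocker<>();

            try {
                // Nothing ever completes this blocker, so the call blocks until the timer expires.
                blocker.await(Time.SYSTEM.timer(10));
            } catch (TimeoutException e) {
                System.out.println("await() timed out as expected: " + e.getMessage());
            }
        }
    }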

WakeupTrigger.java

@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.clients.consumer.internals;

+import org.apache.kafka.clients.consumer.internals.events.CompositePollEvent;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.errors.WakeupException;
@@ -49,6 +50,10 @@ public class WakeupTrigger {
                 // will be ignored. If it was already completed, we then need to return a new WakeupFuture so that the
                 // next call to setActiveTask will throw the WakeupException.
                 return wasTriggered ? null : new WakeupFuture();
+            } else if (task instanceof CompositePollEventAction) {
+                CompositePollEventAction compositePollEventAction = (CompositePollEventAction) task;
+                compositePollEventAction.event().blocker().completeExceptionally(new WakeupException());
+                return new WakeupFuture();
             } else if (task instanceof FetchAction) {
                 FetchAction fetchAction = (FetchAction) task;
                 fetchAction.fetchBuffer().wakeup();
@@ -89,6 +94,25 @@ public class WakeupTrigger {
         return currentTask;
     }

+    public void setFetchAction(final CompositePollEvent event) {
+        final AtomicBoolean throwWakeupException = new AtomicBoolean(false);
+        pendingTask.getAndUpdate(task -> {
+            if (task == null) {
+                return new CompositePollEventAction(event);
+            } else if (task instanceof WakeupFuture) {
+                throwWakeupException.set(true);
+                return null;
+            } else if (task instanceof DisabledWakeups) {
+                return task;
+            }
+            // last active state is still active
+            throw new IllegalStateException("Last active task is still active");
+        });
+        if (throwWakeupException.get()) {
+            throw new WakeupException();
+        }
+    }
+
     public void setFetchAction(final FetchBuffer fetchBuffer) {
         final AtomicBoolean throwWakeupException = new AtomicBoolean(false);
         pendingTask.getAndUpdate(task -> {
@@ -135,7 +159,7 @@ public class WakeupTrigger {
         pendingTask.getAndUpdate(task -> {
             if (task == null) {
                 return null;
-            } else if (task instanceof ActiveFuture || task instanceof FetchAction || task instanceof ShareFetchAction) {
+            } else if (task instanceof ActiveFuture || task instanceof CompositePollEventAction || task instanceof FetchAction || task instanceof ShareFetchAction) {
                 return null;
             }
             return task;
@@ -182,6 +206,19 @@ public class WakeupTrigger {
     static class WakeupFuture implements Wakeupable { }

+    static class CompositePollEventAction implements Wakeupable {
+
+        private final CompositePollEvent event;
+
+        public CompositePollEventAction(CompositePollEvent event) {
+            this.event = event;
+        }
+
+        public CompositePollEvent event() {
+            return event;
+        }
+    }
+
     static class FetchAction implements Wakeupable {

         private final FetchBuffer fetchBuffer;
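
Taken together with the Blocker, the new CompositePollEventAction gives wakeup() a way to interrupt a blocked poll: once the in-flight CompositePollEvent is registered via setFetchAction(event), a concurrent wakeup() completes the event's blocker exceptionally and the waiting thread surfaces a WakeupException. A rough sketch of that flow, assuming WakeupTrigger is driven stand-alone (the event's constructor arguments here are placeholder values):

    import org.apache.kafka.clients.consumer.internals.WakeupTrigger;
    import org.apache.kafka.clients.consumer.internals.events.ApplicationEvent;
    import org.apache.kafka.clients.consumer.internals.events.CompositePollEvent;
    import org.apache.kafka.common.errors.WakeupException;
    import org.apache.kafka.common.utils.Time;

    public class WakeupPathSketch {

        public static void main(String[] args) {
            WakeupTrigger wakeupTrigger = new WakeupTrigger();
            CompositePollEvent event = new CompositePollEvent(
                    Long.MAX_VALUE, Time.SYSTEM.milliseconds(), ApplicationEvent.Type.POLL);

            wakeupTrigger.setFetchAction(event);    // register the in-flight poll
            wakeupTrigger.wakeup();                 // e.g. KafkaConsumer.wakeup() from another thread

            try {
                event.blocker().await(Time.SYSTEM.timer(1_000));
            } catch (WakeupException e) {
                System.out.println("poll() would rethrow this WakeupException");
            } finally {
                wakeupTrigger.clearTask();          // mirrors the finally block in prepareFetch()
            }
        }
    }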

ApplicationEventProcessor.java

@@ -22,6 +22,7 @@ import org.apache.kafka.clients.consumer.internals.CachedSupplier;
 import org.apache.kafka.clients.consumer.internals.CommitRequestManager;
 import org.apache.kafka.clients.consumer.internals.ConsumerMetadata;
 import org.apache.kafka.clients.consumer.internals.ConsumerNetworkThread;
+import org.apache.kafka.clients.consumer.internals.ConsumerUtils;
 import org.apache.kafka.clients.consumer.internals.NetworkClientDelegate;
 import org.apache.kafka.clients.consumer.internals.OffsetAndTimestampInternal;
 import org.apache.kafka.clients.consumer.internals.OffsetCommitCallbackInvoker;
@@ -48,6 +49,7 @@ import java.util.Optional;
 import java.util.OptionalLong;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
 import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
@@ -292,14 +294,12 @@ public class ApplicationEventProcessor implements EventProcessor<ApplicationEvent> {
         if (nextEventType == ApplicationEvent.Type.CHECK_AND_UPDATE_POSITIONS) {
             // This is a bit tricky... The CompositePollEvent should be "paused" from being reaped while the code
             // for new CheckAndUpdatePositionsEvent is in flight.
-            applicationEventReaper.pause(event);
             CompletableFuture<Boolean> updatePositionsFuture = processCheckAndUpdatePositionsEvent(event.deadlineMs());
             applicationEventReaper.add(new CompositePollPsuedoEvent<>(updatePositionsFuture, event.deadlineMs()));

             updatePositionsFuture.whenComplete((__, updatePositionsError) -> {
                 // Make sure to resume the CompositePollEvent *before* checking for failure so that it is assured
                 // to be resumed.
-                applicationEventReaper.resume(event);

                 if (maybeFailCompositePoll(event, updatePositionsError))
                     return;
@@ -309,15 +309,15 @@ public class ApplicationEventProcessor implements EventProcessor<ApplicationEvent> {
                     if (maybeFailCompositePoll(event, fetchError))
                         return;

-                    log.trace("Completing CompositePollEvent {}", event);
-                    event.future().complete(CompositePollEvent.State.COMPLETE);
+                    event.blocker().complete(CompositePollEvent.State.COMPLETE);
+                    log.trace("Completed CompositePollEvent {}", event);
                 });
             });

             return;
         }

-        event.future().completeExceptionally(new IllegalArgumentException("Unknown next step for composite poll: " + nextEventType));
+        event.blocker().completeExceptionally(new KafkaException("Unknown next step for composite poll: " + nextEventType));
     }

     private boolean maybePauseCompositePoll(CompositePollEvent event, RequiresApplicationThreadExecution test) {
@@ -325,8 +325,8 @@ public class ApplicationEventProcessor implements EventProcessor<ApplicationEvent> {
             return false;

         CompositePollEvent.State targetState = test.targetState();
+        event.blocker().complete(targetState);
         log.trace("Pausing CompositePollEvent {} to process logic for target state {}", event, targetState);
-        event.future().complete(targetState);
         return true;
     }
@@ -342,8 +342,12 @@ public class ApplicationEventProcessor implements EventProcessor<ApplicationEvent> {
             return false;
         }

+        if (t instanceof CompletionException) {
+            t = t.getCause();
+        }
+
+        event.blocker().completeExceptionally(ConsumerUtils.maybeWrapAsKafkaException(t));
         log.trace("Failing CompositePollEvent {}", event, t);
-        event.future().completeExceptionally(t);
         return true;
     }
@@ -353,7 +357,7 @@ public class ApplicationEventProcessor implements EventProcessor<ApplicationEvent> {
         if (exception.isPresent()) {
             Exception e = exception.get();
             log.trace("Failing CompositePollEvent {} with error from NetworkClient", event, e);
-            event.future().completeExceptionally(e);
+            event.blocker().completeExceptionally(ConsumerUtils.maybeWrapAsKafkaException(e));
             return true;
         }
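
The new CompletionException handling in the hunk above reflects how CompletableFuture reports failures: a dependent stage surfaces the original error wrapped in a CompletionException, so the cause is unwrapped before being normalized (via ConsumerUtils.maybeWrapAsKafkaException) and handed to the blocker. A small self-contained sketch of that behavior; the future and exception below are placeholders, and the normalization is hand-rolled here rather than calling the internal helper:

    import org.apache.kafka.common.KafkaException;

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.CompletionException;

    public class CompletionExceptionUnwrapSketch {

        public static void main(String[] args) {
            CompletableFuture<Integer> root = new CompletableFuture<>();
            root.completeExceptionally(new IllegalStateException("positions could not be updated"));

            // whenComplete() on a dependent stage sees the failure wrapped in CompletionException...
            root.thenApply(v -> v + 1).whenComplete((ignored, t) -> {
                if (t instanceof CompletionException)
                    t = t.getCause();

                // ...so unwrap it first, then normalize to a KafkaException before completing the blocker.
                KafkaException error = t instanceof KafkaException
                        ? (KafkaException) t
                        : new KafkaException(t);
                System.out.println("Would complete the blocker exceptionally with: " + error);
            });
        }
    }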

CompositePollEvent.java

@@ -16,7 +16,9 @@
  */
 package org.apache.kafka.clients.consumer.internals.events;

-public class CompositePollEvent extends CompletableApplicationEvent<CompositePollEvent.State> {
+import org.apache.kafka.clients.consumer.internals.Blocker;
+
+public class CompositePollEvent extends ApplicationEvent {

     public enum State {
@@ -25,13 +27,21 @@ public class CompositePollEvent extends CompletableApplicationEvent<CompositePollEvent.State> {
         COMPLETE
     }

+    private final long deadlineMs;
     private final long pollTimeMs;
     private final Type nextEventType;
+    private final Blocker<State> blocker;

     public CompositePollEvent(long deadlineMs, long pollTimeMs, Type nextEventType) {
-        super(Type.COMPOSITE_POLL, deadlineMs);
+        super(Type.COMPOSITE_POLL);
+        this.deadlineMs = deadlineMs;
         this.pollTimeMs = pollTimeMs;
         this.nextEventType = nextEventType;
+        this.blocker = new Blocker<>();
+    }
+
+    public long deadlineMs() {
+        return deadlineMs;
     }

     public long pollTimeMs() {
@@ -42,8 +52,12 @@ public class CompositePollEvent extends CompletableApplicationEvent<CompositePollEvent.State> {
         return nextEventType;
     }

+    public Blocker<State> blocker() {
+        return blocker;
+    }
+
     @Override
     protected String toStringBase() {
-        return super.toStringBase() + ", pollTimeMs=" + pollTimeMs + ", nextEventType=" + nextEventType;
+        return super.toStringBase() + ", deadlineMs=" + deadlineMs + ", pollTimeMs=" + pollTimeMs + ", nextEventType=" + nextEventType + ", blocker=" + blocker;
     }
 }

AsyncKafkaConsumerTest.java

@@ -102,6 +102,8 @@ import org.mockito.ArgumentCaptor;
 import org.mockito.ArgumentMatchers;
 import org.mockito.MockedStatic;
 import org.mockito.Mockito;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 import java.time.Duration;
 import java.util.Arrays;
@@ -169,6 +171,7 @@ import static org.mockito.Mockito.when;
 @SuppressWarnings("unchecked")
 public class AsyncKafkaConsumerTest {

+    private static final Logger log = LoggerFactory.getLogger(AsyncKafkaConsumerTest.class);
     private AsyncKafkaConsumer<String, String> consumer = null;
     private Time time = new MockTime(0);
     private final Metrics metrics = new Metrics();
@@ -1677,7 +1680,7 @@ public class AsyncKafkaConsumerTest {
         markReconcileAndAutoCommitCompleteForPollEvent();
         markResultForCompositePollEvent(CompositePollEvent.State.COMPLETE);
         consumer.poll(Duration.ofMillis(100));
-        verify(applicationEventHandler).addAndGet(any(CompositePollEvent.class));
+        verify(applicationEventHandler).add(any(CompositePollEvent.class));
     }

     private Properties requiredConsumerConfigAndGroupId(final String groupId) {
@@ -2255,26 +2258,33 @@ public class AsyncKafkaConsumerTest {
     private void markResultForCompositePollEvent(CompositePollEvent.State state) {
         doAnswer(invocation -> {
+            CompositePollEvent event = invocation.getArgument(0);
+            log.error("Am I invoked: {}", event);
+
             if (Thread.currentThread().isInterrupted())
-                throw new InterruptException("Test interrupt");
+                event.blocker().completeExceptionally(new InterruptException("Test interrupt"));

+            event.blocker().complete(state);
             return state;
-        }).when(applicationEventHandler).addAndGet(ArgumentMatchers.isA(CompositePollEvent.class));
+        }).when(applicationEventHandler).add(ArgumentMatchers.isA(CompositePollEvent.class));
     }

     private void markResultForCompositePollEvent(Collection<CompositePollEvent.State> states) {
         LinkedList<CompositePollEvent.State> statesQueue = new LinkedList<>(states);

         doAnswer(invocation -> {
+            CompositePollEvent event = invocation.getArgument(0);
+            log.error("Am I invoked: {}", event);
+
             CompositePollEvent.State state = statesQueue.poll();

             if (state == null)
-                throw new IllegalStateException("The array of " + CompositePollEvent.State.class.getSimpleName() + " did not provide enough values");
+                event.blocker().completeExceptionally(new KafkaException("The array of " + CompositePollEvent.State.class.getSimpleName() + " did not provide enough values"));

             if (Thread.currentThread().isInterrupted())
-                throw new InterruptException("Test interrupt");
+                event.blocker().completeExceptionally(new InterruptException("Test interrupt"));

+            event.blocker().complete(state);
             return state;
-        }).when(applicationEventHandler).addAndGet(ArgumentMatchers.isA(CompositePollEvent.class));
+        }).when(applicationEventHandler).add(ArgumentMatchers.isA(CompositePollEvent.class));
     }
 }