KAFKA-19297: Refactor AsyncKafkaConsumer's use of Java Streams APIs in critical sections (#19917)

Profiling has shown that using the Java Collections Streams API adds
unnecessary overhead compared to a traditional for loop. Minor revisions
replace the stream pipelines in these critical sections with simpler
constructs to improve performance.

Reviewers: Lianet Magrans <lmagrans@confluent.io>, Andrew Schofield <aschofield@confluent.io>
Kirk True authored 2025-06-18 07:00:45 -07:00, committed by GitHub
commit adcf10ca8b (parent 2a06335569)
8 changed files with 179 additions and 96 deletions
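As context for the commit message above, a minimal JMH-style microbenchmark (illustrative only, not part of this commit) shows the shape of the comparison: a stream-based min reduction versus a plain loop over a small list, mirroring the runOnce() change in ConsumerNetworkThread.java below. The class name, list size, and benchmark harness are assumptions; absolute results vary by JVM and workload.

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;

import java.util.ArrayList;
import java.util.List;

@State(Scope.Benchmark)
public class StreamVsLoopBench {

    private List<Long> pollTimeouts;

    @Setup
    public void setup() {
        // The hot-path collections in the consumer are typically this small
        // (roughly one entry per request manager).
        pollTimeouts = new ArrayList<>();
        for (long i = 0; i < 8; i++)
            pollTimeouts.add(i * 100L);
    }

    @Benchmark
    public long streamMin() {
        // Allocates and drives a stream pipeline on every call.
        return pollTimeouts.stream().mapToLong(Long::longValue).min().orElse(Long.MAX_VALUE);
    }

    @Benchmark
    public long loopMin() {
        // Plain fold: no intermediate allocations.
        long min = Long.MAX_VALUE;
        for (Long t : pollTimeouts)
            min = Math.min(min, t);
        return min;
    }
}

For small, per-poll collections like these, the per-call cost of constructing the stream pipeline tends to dominate the work it actually performs, which is the overhead the commit removes.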

AbstractFetch.java

@@ -45,6 +45,7 @@ import org.slf4j.helpers.MessageFormatter;
import java.io.Closeable;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
@@ -54,7 +55,6 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import static org.apache.kafka.clients.consumer.internals.FetchUtils.requestMetadataUpdate;
@@ -223,10 +223,15 @@ public abstract class AbstractFetch implements Closeable {
}
if (!partitionsWithUpdatedLeaderInfo.isEmpty()) {
List<Node> leaderNodes = response.data().nodeEndpoints().stream()
.map(e -> new Node(e.nodeId(), e.host(), e.port(), e.rack()))
.filter(e -> !e.equals(Node.noNode()))
.collect(Collectors.toList());
List<Node> leaderNodes = new ArrayList<>();
for (FetchResponseData.NodeEndpoint e : response.data().nodeEndpoints()) {
Node node = new Node(e.nodeId(), e.host(), e.port(), e.rack());
if (!node.equals(Node.noNode()))
leaderNodes.add(node);
}
Set<TopicPartition> updatedPartitions = metadata.updatePartitionLeadership(partitionsWithUpdatedLeaderInfo, leaderNodes);
updatedPartitions.forEach(
tp -> {
@@ -397,7 +402,7 @@ public abstract class AbstractFetch implements Closeable {
fetchable.put(fetchTarget, sessionHandler.newBuilder());
});
return fetchable.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().build()));
return convert(fetchable);
}
/**
@@ -470,7 +475,21 @@
}
}
return fetchable.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().build()));
return convert(fetchable);
}
/**
* This method converts {@link FetchSessionHandler.Builder} instances to
* {@link FetchSessionHandler.FetchRequestData} instances. It intentionally forgoes use of the Java Collections
* Streams API to reduce overhead in the critical network path.
*/
private Map<Node, FetchSessionHandler.FetchRequestData> convert(Map<Node, FetchSessionHandler.Builder> fetchable) {
Map<Node, FetchSessionHandler.FetchRequestData> map = new HashMap<>(fetchable.size());
for (Map.Entry<Node, FetchSessionHandler.Builder> entry : fetchable.entrySet())
map.put(entry.getKey(), entry.getValue().build());
return map;
}
/**

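A side note on the new convert() method (a sketch, not a change the commit makes): new HashMap<>(fetchable.size()) sets the initial capacity rather than the expected entry count, so with the default 0.75 load factor the map may still resize once while being filled. When the final size is known up front, a common idiom is to divide by the load factor; JDK 19 and later expose the same calculation as HashMap.newHashMap(int).

import java.util.HashMap;

final class Maps {

    // Hypothetical helper, not in the commit: size the backing table so that
    // expectedEntries insertions never trigger a resize (default load factor 0.75).
    static <K, V> HashMap<K, V> withExpectedSize(int expectedEntries) {
        return new HashMap<>((int) Math.ceil(expectedEntries / 0.75d));
    }
}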
AbstractMembershipManager.java

@@ -43,7 +43,6 @@ import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import static java.util.Collections.unmodifiableList;
@@ -1135,7 +1134,7 @@ public abstract class AbstractMembershipManager<R extends AbstractResponse> impl
// Ensure the set of partitions to revoke are still assigned
Set<TopicPartition> revokedPartitions = new HashSet<>(partitionsToRevoke);
revokedPartitions.retainAll(subscriptions.assignedPartitions());
log.info("Revoking previously assigned partitions {}", revokedPartitions.stream().map(TopicPartition::toString).collect(Collectors.joining(", ")));
log.info("Revoking previously assigned partitions {}", revokedPartitions);
signalPartitionsBeingRevoked(revokedPartitions);

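The logging changes in this file and several below share one rationale: the stream form builds the joined string eagerly on every call, even when INFO is disabled, while passing the collection defers toString() until SLF4J actually formats the message. A minimal self-contained sketch of the two forms (logger wiring assumed):

import org.apache.kafka.common.TopicPartition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Set;
import java.util.stream.Collectors;

final class LoggingSketch {

    private static final Logger log = LoggerFactory.getLogger(LoggingSketch.class);

    static void logRevoked(Set<TopicPartition> revokedPartitions) {
        // Eager: the join executes before log.info() is entered, whatever the level.
        log.info("Revoking previously assigned partitions {}",
                revokedPartitions.stream().map(TopicPartition::toString).collect(Collectors.joining(", ")));

        // Deferred: only the collection reference is passed; its toString() runs
        // inside SLF4J, and only when INFO is enabled.
        log.info("Revoking previously assigned partitions {}", revokedPartitions);
    }
}

One visible difference: a Set's own toString() brackets the output ([tp-0, tp-1]) instead of producing a bare comma-joined list, which is why the literal brackets were dropped from message templates such as "The pause flag in partitions [{}] ..." in the hunks that follow.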
ConsumerMembershipManager.java

@@ -44,7 +44,6 @@ import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import static org.apache.kafka.clients.consumer.CloseOptions.GroupMembershipOperation.DEFAULT;
import static org.apache.kafka.clients.consumer.CloseOptions.GroupMembershipOperation.LEAVE_GROUP;
@@ -415,7 +414,7 @@ public class ConsumerMembershipManager extends AbstractMembershipManager<Consume
Set<TopicPartition> revokePausedPartitions = subscriptions.pausedPartitions();
revokePausedPartitions.retainAll(partitionsToRevoke);
if (!revokePausedPartitions.isEmpty()) {
log.info("The pause flag in partitions [{}] will be removed due to revocation.", revokePausedPartitions.stream().map(TopicPartition::toString).collect(Collectors.joining(", ")));
log.info("The pause flag in partitions {} will be removed due to revocation.", revokePausedPartitions);
}
}

ConsumerNetworkThread.java

@@ -35,13 +35,13 @@ import org.slf4j.Logger;
import java.io.Closeable;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.BlockingQueue;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import static org.apache.kafka.clients.consumer.internals.ConsumerUtils.DEFAULT_CLOSE_TIMEOUT_MS;
import static org.apache.kafka.common.utils.Utils.closeQuietly;
@@ -144,6 +144,7 @@ public class ConsumerNetworkThread extends KafkaThread implements Closeable {
* </ol>
*/
void runOnce() {
// The following code avoids use of the Java Collections Streams API to reduce overhead in this loop.
processApplicationEvents();
final long currentTimeMs = time.milliseconds();
@@ -152,19 +153,24 @@
}
lastPollTimeMs = currentTimeMs;
final long pollWaitTimeMs = requestManagers.entries().stream()
.map(rm -> rm.poll(currentTimeMs))
.mapToLong(networkClientDelegate::addAll)
.filter(ms -> ms <= MAX_POLL_TIMEOUT_MS)
.min()
.orElse(MAX_POLL_TIMEOUT_MS);
long pollWaitTimeMs = MAX_POLL_TIMEOUT_MS;
for (RequestManager rm : requestManagers.entries()) {
NetworkClientDelegate.PollResult pollResult = rm.poll(currentTimeMs);
long timeoutMs = networkClientDelegate.addAll(pollResult);
pollWaitTimeMs = Math.min(pollWaitTimeMs, timeoutMs);
}
networkClientDelegate.poll(pollWaitTimeMs, currentTimeMs);
cachedMaximumTimeToWait = requestManagers.entries().stream()
.mapToLong(rm -> rm.maximumTimeToWait(currentTimeMs))
.min()
.orElse(Long.MAX_VALUE);
long maxTimeToWaitMs = Long.MAX_VALUE;
for (RequestManager rm : requestManagers.entries()) {
long waitMs = rm.maximumTimeToWait(currentTimeMs);
maxTimeToWaitMs = Math.min(maxTimeToWaitMs, waitMs);
}
cachedMaximumTimeToWait = maxTimeToWaitMs;
reapExpiredApplicationEvents(currentTimeMs);
List<CompletableEvent<?>> uncompletedEvents = applicationEventReaper.uncompletedEvents();
@@ -235,10 +241,11 @@
static void runAtClose(final Collection<RequestManager> requestManagers,
final NetworkClientDelegate networkClientDelegate,
final long currentTimeMs) {
// These are the optional outgoing requests at the
requestManagers.stream()
.map(rm -> rm.pollOnClose(currentTimeMs))
.forEach(networkClientDelegate::addAll);
// These are the optional outgoing requests at the time of closing the consumer
for (RequestManager rm : requestManagers) {
NetworkClientDelegate.PollResult pollResult = rm.pollOnClose(currentTimeMs);
networkClientDelegate.addAll(pollResult);
}
}
public boolean isRunning() {
@@ -362,12 +369,13 @@
* If there is a metadata error, complete all uncompleted events that require subscription metadata.
*/
private void maybeFailOnMetadataError(List<CompletableEvent<?>> events) {
List<? extends CompletableApplicationEvent<?>> subscriptionMetadataEvent = events.stream()
.filter(e -> e instanceof CompletableApplicationEvent<?>)
.map(e -> (CompletableApplicationEvent<?>) e)
.filter(CompletableApplicationEvent::requireSubscriptionMetadata)
.collect(Collectors.toList());
List<CompletableApplicationEvent<?>> subscriptionMetadataEvent = new ArrayList<>();
for (CompletableEvent<?> ce : events) {
if (ce instanceof CompletableApplicationEvent && ((CompletableApplicationEvent<?>) ce).requireSubscriptionMetadata())
subscriptionMetadataEvent.add((CompletableApplicationEvent<?>) ce);
}
if (subscriptionMetadataEvent.isEmpty())
return;
networkClientDelegate.getAndClearMetadataError().ifPresent(metadataError ->

ConsumerRebalanceListenerInvoker.java

@@ -30,7 +30,6 @@ import org.slf4j.Logger;
import java.util.Optional;
import java.util.Set;
import java.util.SortedSet;
import java.util.stream.Collectors;
/**
* This class encapsulates the invocation of the callback methods defined in the {@link ConsumerRebalanceListener}
@@ -55,7 +54,7 @@ public class ConsumerRebalanceListenerInvoker {
}
public Exception invokePartitionsAssigned(final SortedSet<TopicPartition> assignedPartitions) {
log.info("Adding newly assigned partitions: {}", assignedPartitions.stream().map(TopicPartition::toString).collect(Collectors.joining(", ")));
log.info("Adding newly assigned partitions: {}", assignedPartitions);
Optional<ConsumerRebalanceListener> listener = subscriptions.rebalanceListener();
@@ -67,8 +66,12 @@
} catch (WakeupException | InterruptException e) {
throw e;
} catch (Exception e) {
log.error("User provided listener {} failed on invocation of onPartitionsAssigned for partitions {}",
listener.get().getClass().getName(), assignedPartitions, e);
log.error(
"User provided listener {} failed on invocation of onPartitionsAssigned for partitions {}",
listener.get().getClass().getName(),
assignedPartitions,
e
);
return e;
}
}
@@ -77,11 +80,11 @@
}
public Exception invokePartitionsRevoked(final SortedSet<TopicPartition> revokedPartitions) {
log.info("Revoke previously assigned partitions {}", revokedPartitions.stream().map(TopicPartition::toString).collect(Collectors.joining(", ")));
log.info("Revoke previously assigned partitions {}", revokedPartitions);
Set<TopicPartition> revokePausedPartitions = subscriptions.pausedPartitions();
revokePausedPartitions.retainAll(revokedPartitions);
if (!revokePausedPartitions.isEmpty())
log.info("The pause flag in partitions [{}] will be removed due to revocation.", revokePausedPartitions.stream().map(TopicPartition::toString).collect(Collectors.joining(", ")));
log.info("The pause flag in partitions {} will be removed due to revocation.", revokePausedPartitions);
Optional<ConsumerRebalanceListener> listener = subscriptions.rebalanceListener();
@@ -93,8 +96,12 @@
} catch (WakeupException | InterruptException e) {
throw e;
} catch (Exception e) {
log.error("User provided listener {} failed on invocation of onPartitionsRevoked for partitions {}",
listener.get().getClass().getName(), revokedPartitions, e);
log.error(
"User provided listener {} failed on invocation of onPartitionsRevoked for partitions {}",
listener.get().getClass().getName(),
revokedPartitions,
e
);
return e;
}
}
@@ -103,11 +110,11 @@
}
public Exception invokePartitionsLost(final SortedSet<TopicPartition> lostPartitions) {
log.info("Lost previously assigned partitions {}", lostPartitions.stream().map(TopicPartition::toString).collect(Collectors.joining(", ")));
log.info("Lost previously assigned partitions {}", lostPartitions);
Set<TopicPartition> lostPausedPartitions = subscriptions.pausedPartitions();
lostPausedPartitions.retainAll(lostPartitions);
if (!lostPausedPartitions.isEmpty())
log.info("The pause flag in partitions [{}] will be removed due to partition lost.", lostPartitions.stream().map(TopicPartition::toString).collect(Collectors.joining(", ")));
log.info("The pause flag in partitions {} will be removed due to partition lost.", lostPartitions);
Optional<ConsumerRebalanceListener> listener = subscriptions.rebalanceListener();
@@ -119,8 +126,12 @@
} catch (WakeupException | InterruptException e) {
throw e;
} catch (Exception e) {
log.error("User provided listener {} failed on invocation of onPartitionsLost for partitions {}",
listener.get().getClass().getName(), lostPartitions, e);
log.error(
"User provided listener {} failed on invocation of onPartitionsLost for partitions {}",
listener.get().getClass().getName(),
lostPartitions,
e
);
return e;
}
}

TopicMetadataRequestManager.java

@@ -33,15 +33,16 @@ import org.apache.kafka.common.utils.Time;
import org.slf4j.Logger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import static org.apache.kafka.clients.consumer.internals.NetworkClientDelegate.PollResult.EMPTY;
@@ -84,16 +85,23 @@
@Override
public NetworkClientDelegate.PollResult poll(final long currentTimeMs) {
// Prune any requests which have timed out
List<TopicMetadataRequestState> expiredRequests = inflightRequests.stream()
.filter(TimedRequestState::isExpired)
.collect(Collectors.toList());
expiredRequests.forEach(TopicMetadataRequestState::expire);
Iterator<TopicMetadataRequestState> requestStateIterator = inflightRequests.iterator();
List<NetworkClientDelegate.UnsentRequest> requests = inflightRequests.stream()
.map(req -> req.send(currentTimeMs))
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.toList());
while (requestStateIterator.hasNext()) {
TopicMetadataRequestState requestState = requestStateIterator.next();
if (requestState.isExpired()) {
requestState.expire();
requestStateIterator.remove();
}
}
List<NetworkClientDelegate.UnsentRequest> requests = new ArrayList<>();
for (TopicMetadataRequestState request : inflightRequests) {
Optional<NetworkClientDelegate.UnsentRequest> unsentRequest = request.send(currentTimeMs);
unsentRequest.ifPresent(requests::add);
}
return requests.isEmpty() ? EMPTY : new NetworkClientDelegate.PollResult(0, requests);
}
@@ -181,7 +189,9 @@
}
private void expire() {
completeFutureAndRemoveRequest(
// The request state is removed from inflightRequests via an iterator by the caller of this method,
// so don't remove it from inflightRequests here.
future.completeExceptionally(
new TimeoutException("Timeout expired while fetching topic metadata"));
}

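The poll() rewrite above merges expiry and removal into a single pass over inflightRequests, and the new comment in expire() exists because removal now happens through the iterator. Removing via Iterator.remove() is what makes that pass safe; calling List.remove() inside an enhanced for loop would throw ConcurrentModificationException. A self-contained sketch of the idiom, with a Request interface standing in for TopicMetadataRequestState:

import java.util.Iterator;
import java.util.List;

final class ExpirySweep {

    // Stand-in for TopicMetadataRequestState: both expose an expiry check and
    // an expire() that completes the request's future exceptionally.
    interface Request {
        boolean isExpired();
        void expire();
    }

    static void pruneExpired(List<Request> inflight) {
        Iterator<Request> it = inflight.iterator();
        while (it.hasNext()) {
            Request r = it.next();
            if (r.isExpired()) {
                r.expire();
                it.remove(); // safe structural removal mid-iteration
            }
        }
    }
}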
ApplicationEventProcessor.java

@@ -310,7 +310,7 @@ public class ApplicationEventProcessor implements EventProcessor<ApplicationEven
manager.updateTimerAndMaybeCommit(event.currentTimeMs());
}
log.info("Assigned to partition(s): {}", event.partitions().stream().map(TopicPartition::toString).collect(Collectors.joining(", ")));
log.info("Assigned to partition(s): {}", event.partitions());
try {
if (subscriptions.assignFromUser(new HashSet<>(event.partitions())))
metadata.requestUpdateForNewTopics();

CompletableEventReaper.java

@@ -25,11 +25,10 @@ import org.slf4j.Logger;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Collectors;
/**
* {@code CompletableEventReaper} is responsible for tracking {@link CompletableEvent time-bound events} and removing
@@ -85,26 +84,39 @@
* @return The number of events that were expired
*/
public long reap(long currentTimeMs) {
Consumer<CompletableEvent<?>> expireEvent = event -> {
long pastDueMs = currentTimeMs - event.deadlineMs();
TimeoutException error = new TimeoutException(String.format("%s was %s ms past its expiration of %s", event.getClass().getSimpleName(), pastDueMs, event.deadlineMs()));
int count = 0;
Iterator<CompletableEvent<?>> iterator = tracked.iterator();
while (iterator.hasNext()) {
CompletableEvent<?> event = iterator.next();
if (event.future().isDone()) {
// Remove any events that are already complete.
iterator.remove();
continue;
}
long deadlineMs = event.deadlineMs();
long pastDueMs = currentTimeMs - deadlineMs;
if (pastDueMs < 0)
continue;
TimeoutException error = new TimeoutException(String.format("%s was %s ms past its expiration of %s", event.getClass().getSimpleName(), pastDueMs, deadlineMs));
// Complete (exceptionally) any events that have passed their deadline AND aren't already complete.
if (event.future().completeExceptionally(error)) {
log.debug("Event {} completed exceptionally since its expiration of {} passed {} ms ago", event, event.deadlineMs(), pastDueMs);
log.debug("Event {} completed exceptionally since its expiration of {} passed {} ms ago", event, deadlineMs, pastDueMs);
} else {
log.trace("Event {} not completed exceptionally since it was previously completed", event);
}
};
// First, complete (exceptionally) any events that have passed their deadline AND aren't already complete.
long count = tracked.stream()
.filter(e -> !e.future().isDone())
.filter(e -> currentTimeMs >= e.deadlineMs())
.peek(expireEvent)
.count();
// Second, remove any events that are already complete, just to make sure we don't hold references. This will
// include any events that finished successfully as well as any events we just completed exceptionally above.
tracked.removeIf(e -> e.future().isDone());
count++;
// Remove the event so that we don't hold a reference to it.
iterator.remove();
}
return count;
}
@@ -131,29 +143,12 @@
public long reap(Collection<?> events) {
Objects.requireNonNull(events, "Event queue to reap must be non-null");
Consumer<CompletableEvent<?>> expireEvent = event -> {
TimeoutException error = new TimeoutException(String.format("%s could not be completed before the consumer closed", event.getClass().getSimpleName()));
if (event.future().completeExceptionally(error)) {
log.debug("Event {} completed exceptionally since the consumer is closing", event);
} else {
log.trace("Event {} not completed exceptionally since it was completed prior to the consumer closing", event);
}
};
long trackedExpiredCount = tracked.stream()
.filter(e -> !e.future().isDone())
.peek(expireEvent)
.count();
long trackedExpiredCount = completeEventsExceptionallyOnClose(tracked);
tracked.clear();
long eventExpiredCount = events.stream()
.filter(e -> e instanceof CompletableEvent<?>)
.map(e -> (CompletableEvent<?>) e)
.filter(e -> !e.future().isDone())
.peek(expireEvent)
.count();
long eventExpiredCount = completeEventsExceptionallyOnClose(events);
events.clear();
return trackedExpiredCount + eventExpiredCount;
}
@@ -166,9 +161,51 @@
}
public List<CompletableEvent<?>> uncompletedEvents() {
return tracked.stream()
.filter(e -> !e.future().isDone())
.collect(Collectors.toList());
// The following code does not use the Java Collections Streams API to reduce overhead in the critical
// path of the ConsumerNetworkThread loop.
List<CompletableEvent<?>> events = new ArrayList<>();
for (CompletableEvent<?> event : tracked) {
if (!event.future().isDone())
events.add(event);
}
return events;
}
/**
* For all the {@link CompletableEvent}s in the collection, if they're not already complete, invoke
* {@link CompletableFuture#completeExceptionally(Throwable)}.
*
* @param events Collection of objects, assumed to be subclasses of {@link ApplicationEvent} or
* {@link BackgroundEvent}, but will only perform completion for any
* unfinished {@link CompletableEvent}s
*
* @return Number of events closed
*/
private long completeEventsExceptionallyOnClose(Collection<?> events) {
long count = 0;
for (Object o : events) {
if (!(o instanceof CompletableEvent))
continue;
CompletableEvent<?> event = (CompletableEvent<?>) o;
if (event.future().isDone())
continue;
count++;
TimeoutException error = new TimeoutException(String.format("%s could not be completed before the consumer closed", event.getClass().getSimpleName()));
if (event.future().completeExceptionally(error)) {
log.debug("Event {} completed exceptionally since the consumer is closing", event);
} else {
log.trace("Event {} not completed exceptionally since it was completed prior to the consumer closing", event);
}
}
return count;
}
}
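A closing note on why the single-pass reap(long) above is safe: CompletableFuture.completeExceptionally() returns false when the future is already complete, so an event can never be double-completed even if its isDone() pre-check races with completion on another thread; that boolean return is also what drives the debug-versus-trace logging. A self-contained sketch of the contract:

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeoutException;

public class CompleteOnceDemo {

    public static void main(String[] args) {
        CompletableFuture<Void> future = new CompletableFuture<>();
        future.complete(null); // the event finished normally

        // Returns false: the earlier completion wins and the future's state is unchanged.
        boolean transitioned =
                future.completeExceptionally(new TimeoutException("deadline passed"));
        System.out.println(transitioned); // prints: false
    }
}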