KAFKA-12500: fix memory leak in thread cache (#10355)

We need to exclude threads in PENDING_SHUTDOWN from the live-thread count used to compute the new cache size per thread; otherwise a thread that is being removed or replaced is still counted when the cache is redistributed, and its share is stranded once it finishes shutting down. Also adds some logging to help follow what's happening when a thread is added/removed/replaced.
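
To make the leak concrete, here is a minimal, self-contained sketch of the sizing arithmetic (a simplified model for illustration; the enum, fields, and helper below are assumptions, not the actual KafkaStreams code). With a 10-byte cache and one of two threads sitting in PENDING_SHUTDOWN, the old count resizes the survivor to 10 / 2 = 5 bytes, and the dying thread's 5 bytes are stranded once it finishes shutting down; excluding PENDING_SHUTDOWN hands the survivor the full 10 bytes — exactly the two numbers the new integration tests below assert via captured log lines.

    import java.util.List;

    final class CacheSizingSketch {
        enum State { RUNNING, PENDING_SHUTDOWN, DEAD }

        // Assumed behavior of getCacheSizePerThread: split the total cache
        // evenly across the live threads (consistent with the 5-byte and
        // 10-byte values asserted by the tests in this commit).
        static long cacheSizePerThread(final long totalCacheBytes, final int numLiveThreads) {
            return numLiveThreads == 0 ? totalCacheBytes : totalCacheBytes / numLiveThreads;
        }

        public static void main(final String[] args) {
            final long totalCacheBytes = 10L;
            // One thread keeps running; the other is being removed and is in PENDING_SHUTDOWN.
            final List<State> threads = List.of(State.RUNNING, State.PENDING_SHUTDOWN);

            // Old (buggy) count: only DEAD threads are excluded, so the dying thread
            // still counts and the survivor is resized to 10 / 2 = 5 bytes; the other
            // 5 bytes are never redistributed after the thread dies -- the leak.
            final int buggyCount = (int) threads.stream().filter(s -> s != State.DEAD).count();

            // Fixed count: PENDING_SHUTDOWN is excluded as well, so the survivor
            // gets the full 10 bytes.
            final int fixedCount = (int) threads.stream()
                .filter(s -> s != State.DEAD && s != State.PENDING_SHUTDOWN)
                .count();

            System.out.println(cacheSizePerThread(totalCacheBytes, buggyCount)); // 5
            System.out.println(cacheSizePerThread(totalCacheBytes, fixedCount)); // 10
        }
    }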

Reviewers: Bruno Cadonna <cadonna@confluent.io>, Walker Carlson <wcarlson@confluent.io>, John Roesler <john@confluent.io>
A. Sophie Blee-Goldman 2021-03-19 18:11:07 -07:00 committed by GitHub
parent 7c7e8078e4
commit 13b4ca8795
5 changed files with 143 additions and 46 deletions

streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java

@@ -98,7 +98,6 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
-import java.util.function.Predicate;

 import static org.apache.kafka.streams.StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG;
 import static org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT;
@@ -499,7 +498,7 @@ public class KafkaStreams implements AutoCloseable {
                       "longer in a well-defined state. Attempting to send the shutdown command anyway.", throwable);
         }

-        if (Thread.currentThread().equals(globalStreamThread) && countStreamThread(StreamThread::isRunning) == 0) {
+        if (Thread.currentThread().equals(globalStreamThread) && getNumLiveStreamThreads() == 0) {
             log.error("Exception in global thread caused the application to attempt to shutdown." +
                           " This action will succeed only if there is at least one StreamThread running on this client." +
                           " Currently there are no running threads so will now close the client.");
@@ -838,8 +837,7 @@
         ClientMetrics.addApplicationIdMetric(streamsMetrics, config.getString(StreamsConfig.APPLICATION_ID_CONFIG));
         ClientMetrics.addTopologyDescriptionMetric(streamsMetrics, internalTopologyBuilder.describe().toString());
         ClientMetrics.addStateMetric(streamsMetrics, (metricsConfig, now) -> state);
-        ClientMetrics.addNumAliveStreamThreadMetric(streamsMetrics, (metricsConfig, now) ->
-            Math.toIntExact(countStreamThread(thread -> thread.state().isAlive())));
+        ClientMetrics.addNumAliveStreamThreadMetric(streamsMetrics, (metricsConfig, now) -> getNumLiveStreamThreads());

         streamsMetadataState = new StreamsMetadataState(
             internalTopologyBuilder,
@@ -965,12 +963,13 @@
      */
     public Optional<String> addStreamThread() {
         if (isRunningOrRebalancing()) {
-            final int threadIdx;
-            final long cacheSizePerThread;
             final StreamThread streamThread;
             synchronized (changeThreadCount) {
-                threadIdx = getNextThreadIndex();
-                cacheSizePerThread = getCacheSizePerThread(getNumLiveStreamThreads() + 1);
+                final int threadIdx = getNextThreadIndex();
+                final int numLiveThreads = getNumLiveStreamThreads();
+                final long cacheSizePerThread = getCacheSizePerThread(numLiveThreads + 1);
+                log.info("Adding StreamThread-{}, there will now be {} live threads and the new cache size per thread is {}",
+                         threadIdx, numLiveThreads + 1, cacheSizePerThread);
                 resizeThreadCache(cacheSizePerThread);
                 // Creating thread should hold the lock in order to avoid duplicate thread index.
                 // If the duplicate index happen, the metadata of thread may be duplicate too.
@@ -982,14 +981,19 @@
                 streamThread.start();
                 return Optional.of(streamThread.getName());
             } else {
+                log.warn("Terminating the new thread because the Kafka Streams client is in state {}", state);
                 streamThread.shutdown();
                 threads.remove(streamThread);
-                resizeThreadCache(getCacheSizePerThread(getNumLiveStreamThreads()));
+                final long cacheSizePerThread = getCacheSizePerThread(getNumLiveStreamThreads());
+                log.info("Resizing thread cache due to terminating added thread, new cache size per thread is {}", cacheSizePerThread);
+                resizeThreadCache(cacheSizePerThread);
                 return Optional.empty();
             }
-        }
-        log.warn("Cannot add a stream thread when Kafka Streams client is in state " + state());
-        return Optional.empty();
+        } else {
+            log.warn("Cannot add a stream thread when Kafka Streams client is in state {}", state);
+            return Optional.empty();
+        }
     }

     /**
@@ -1031,7 +1035,8 @@
     }

     private Optional<String> removeStreamThread(final long timeoutMs) throws TimeoutException {
-        boolean timeout = false;
+        final long startMs = time.milliseconds();
+
         if (isRunningOrRebalancing()) {
             synchronized (changeThreadCount) {
                 // make a copy of threads to avoid holding lock
@@ -1043,18 +1048,23 @@
                         streamThread.requestLeaveGroupDuringShutdown();
                         streamThread.shutdown();
                         if (!streamThread.getName().equals(Thread.currentThread().getName())) {
-                            if (!streamThread.waitOnThreadState(StreamThread.State.DEAD, timeoutMs)) {
-                                log.warn("Thread " + streamThread.getName() + " did not shutdown in the allotted time");
-                                timeout = true;
+                            final long remainingTimeMs = timeoutMs - (time.milliseconds() - startMs);
+                            if (remainingTimeMs <= 0 || !streamThread.waitOnThreadState(StreamThread.State.DEAD, remainingTimeMs)) {
+                                log.warn("{} did not shutdown in the allotted time.", streamThread.getName());
+                                // Don't remove from threads until shutdown is complete. We will trim it from the
+                                // list once it reaches DEAD, and if for some reason it's hanging indefinitely in the
+                                // shutdown then we should just consider this thread.id to be burned
+                            } else {
+                                log.info("Successfully removed {} in {}ms", streamThread.getName(), time.milliseconds() - startMs);
                                 threads.remove(streamThread);
                             }
+                        } else {
+                            log.info("{} is the last remaining thread and must remove itself, therefore we cannot wait "
+                                + "for it to complete shutdown as this will result in deadlock.", streamThread.getName());
                         }

                         final long cacheSizePerThread = getCacheSizePerThread(getNumLiveStreamThreads());
+                        log.info("Resizing thread cache due to thread removal, new cache size per thread is {}", cacheSizePerThread);
                         resizeThreadCache(cacheSizePerThread);

                         if (groupInstanceID.isPresent() && callingThreadIsNotCurrentStreamThread) {
                             final MemberToRemove memberToRemove = new MemberToRemove(groupInstanceID.get());
@@ -1065,7 +1075,8 @@
                                 new RemoveMembersFromConsumerGroupOptions(membersToRemove)
                             );
                             try {
-                                removeMembersFromConsumerGroupResult.memberResult(memberToRemove).get(timeoutMs, TimeUnit.MILLISECONDS);
+                                final long remainingTimeMs = timeoutMs - (time.milliseconds() - startMs);
+                                removeMembersFromConsumerGroupResult.memberResult(memberToRemove).get(remainingTimeMs, TimeUnit.MILLISECONDS);
                             } catch (final java.util.concurrent.TimeoutException e) {
                                 log.error("Could not remove static member {} from consumer group {} due to a timeout: {}",
                                     groupInstanceID.get(), config.getString(StreamsConfig.APPLICATION_ID_CONFIG), e);
@@ -1083,7 +1094,8 @@
                             );
                         }
                     }
-                    if (timeout) {
+                    final long remainingTimeMs = timeoutMs - (time.milliseconds() - startMs);
+                    if (remainingTimeMs <= 0) {
                         throw new TimeoutException("Thread " + streamThread.getName() + " did not stop in the allotted time");
                     }
                     return Optional.of(streamThread.getName());
@@ -1097,13 +1109,25 @@
         return Optional.empty();
     }

-    // Returns the number of threads that are not in the DEAD state -- use this over threads.size()
+    /**
+     * Takes a snapshot and counts the number of stream threads which are not in PENDING_SHUTDOWN or DEAD
+     *
+     * note: iteration over SynchronizedList is not thread safe so it must be manually synchronized. However, we may
+     * require other locks when looping threads and it could cause deadlock. Hence, we create a copy to avoid holding
+     * threads lock when looping threads.
+     * @return number of alive stream threads
+     */
     private int getNumLiveStreamThreads() {
         final AtomicInteger numLiveThreads = new AtomicInteger(0);
         synchronized (threads) {
             processStreamThread(thread -> {
                 if (thread.state() == StreamThread.State.DEAD) {
+                    log.debug("Trimming thread {} from the threads list since it's state is {}", thread.getName(), StreamThread.State.DEAD);
                     threads.remove(thread);
+                } else if (thread.state() == StreamThread.State.PENDING_SHUTDOWN) {
+                    log.debug("Skipping thread {} from num live threads computation since it's state is {}",
+                              thread.getName(), StreamThread.State.PENDING_SHUTDOWN);
                 } else {
                     numLiveThreads.incrementAndGet();
                 }
@@ -1617,19 +1641,6 @@
         for (final StreamThread thread : copy) consumer.accept(thread);
     }

-    /**
-     * count the snapshot of threads.
-     * noted: iteration over SynchronizedList is not thread safe so it must be manually synchronized. However, we may
-     * require other locks when looping threads and it could cause deadlock. Hence, we create a copy to avoid holding
-     * threads lock when looping threads.
-     * @param predicate predicate
-     * @return number of matched threads
-     */
-    private long countStreamThread(final Predicate<StreamThread> predicate) {
-        final List<StreamThread> copy = new ArrayList<>(threads);
-        return copy.stream().filter(predicate).count();
-    }
-
     /**
      * Returns runtime information about the local threads of this {@link KafkaStreams} instance.
      *
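
The removeStreamThread() changes above also replace the boolean timeout flag with a running deadline: before every blocking step the remaining budget is recomputed as timeoutMs minus the time already spent, so a slow thread shutdown cannot silently consume the whole timeout of the later consumer-group removal. A minimal sketch of that pattern, using hypothetical names (awaitAll, steps) rather than the Streams code:

    import java.util.List;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.TimeUnit;

    final class DeadlineSketch {
        // Give each blocking step only the time left from the caller's overall
        // budget, instead of passing the full timeout to every step.
        static boolean awaitAll(final List<CompletableFuture<Void>> steps, final long timeoutMs) {
            final long startMs = System.currentTimeMillis();
            for (final CompletableFuture<Void> step : steps) {
                final long remainingTimeMs = timeoutMs - (System.currentTimeMillis() - startMs);
                if (remainingTimeMs <= 0) {
                    return false; // overall budget exhausted, report the timeout
                }
                try {
                    step.get(remainingTimeMs, TimeUnit.MILLISECONDS);
                } catch (final Exception e) {
                    return false; // timed out or failed within this step
                }
            }
            return true;
        }
    }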

streams/src/main/java/org/apache/kafka/streams/state/internals/ThreadCache.java

@@ -76,6 +76,7 @@ public class ThreadCache {
         final boolean shrink = newCacheSizeBytes < maxCacheSizeBytes;
         maxCacheSizeBytes = newCacheSizeBytes;
         if (shrink) {
+            log.debug("Cache size was shrunk to {}", newCacheSizeBytes);
             if (caches.values().isEmpty()) {
                 return;
             }
@@ -85,6 +86,8 @@
                 cache.evict();
                 numEvicts++;
             }
+        } else {
+            log.debug("Cache size was expanded to {}", newCacheSizeBytes);
         }
     }

streams/src/test/java/org/apache/kafka/streams/integration/AdjustStreamThreadCountTest.java

@@ -20,13 +20,22 @@ import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.errors.TimeoutException;
 import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsBuilder;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse;
 import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.kstream.KStream;
+import org.apache.kafka.streams.kstream.Transformer;
+import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.ThreadMetadata;
+import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.test.IntegrationTest;
 import org.apache.kafka.test.TestUtils;
+import java.util.concurrent.atomic.AtomicBoolean;

 import org.junit.After;
 import org.junit.Before;
 import org.junit.ClassRule;
@@ -34,7 +43,6 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
 import org.junit.rules.TestName;
-
 import java.io.IOException;
 import java.time.Duration;
 import java.util.ArrayList;
@@ -53,6 +61,7 @@ import static org.apache.kafka.common.utils.Utils.mkObjectProperties;
 import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.purgeLocalStreamsState;
 import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName;
+import static org.apache.kafka.test.TestUtils.waitForCondition;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.CoreMatchers.not;
@@ -61,6 +70,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;

 @Category(IntegrationTest.class)
 public class AdjustStreamThreadCountTest {
@@ -85,7 +95,7 @@ public class AdjustStreamThreadCountTest {
         inputTopic = "input" + testId;
         IntegrationTestUtils.cleanStateBeforeTest(CLUSTER, inputTopic);

-        builder = new StreamsBuilder();
+        builder = new StreamsBuilder();
         builder.stream(inputTopic);

         properties = mkObjectProperties(
@@ -346,4 +356,86 @@
             assertEquals(oldThreadCount, kafkaStreams.localThreadsMetadata().size());
         }
     }
+
+    @Test
+    public void shouldResizeCacheAfterThreadRemovalTimesOut() throws InterruptedException {
+        final long totalCacheBytes = 10L;
+        final Properties props = new Properties();
+        props.putAll(properties);
+        props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2);
+        props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, totalCacheBytes);
+
+        try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), props)) {
+            addStreamStateChangeListener(kafkaStreams);
+            startStreamsAndWaitForRunning(kafkaStreams);
+
+            try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KafkaStreams.class)) {
+                assertThrows(TimeoutException.class, () -> kafkaStreams.removeStreamThread(Duration.ofSeconds(0)));
+
+                for (final String log : appender.getMessages()) {
+                    // all 10 bytes should be available for remaining thread
+                    if (log.endsWith("Resizing thread cache due to thread removal, new cache size per thread is 10")) {
+                        return;
+                    }
+                }
+            }
+        }
+        fail();
+    }
+
+    @Test
+    public void shouldResizeCacheAfterThreadReplacement() throws InterruptedException {
+        final long totalCacheBytes = 10L;
+        final Properties props = new Properties();
+        props.putAll(properties);
+        props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2);
+        props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, totalCacheBytes);
+
+        final AtomicBoolean injectError = new AtomicBoolean(false);
+
+        final StreamsBuilder builder = new StreamsBuilder();
+        final KStream<String, String> stream = builder.stream(inputTopic);
+        stream.transform(() -> new Transformer<String, String, KeyValue<String, String>>() {
+            @Override
+            public void init(final ProcessorContext context) {
+                context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, timestamp -> {
+                    if (Thread.currentThread().getName().endsWith("StreamThread-1") && injectError.get()) {
+                        injectError.set(false);
+                        throw new RuntimeException("BOOM");
+                    }
+                });
+            }
+
+            @Override
+            public KeyValue<String, String> transform(final String key, final String value) {
+                return new KeyValue<>(key, value);
+            }
+
+            @Override
+            public void close() {
+            }
+        });
+
+        try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), props)) {
+            addStreamStateChangeListener(kafkaStreams);
+            kafkaStreams.setUncaughtExceptionHandler(e -> StreamThreadExceptionResponse.REPLACE_THREAD);
+            startStreamsAndWaitForRunning(kafkaStreams);
+
+            stateTransitionHistory.clear();
+            try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) {
+                injectError.set(true);
+                waitForCondition(() -> !injectError.get(), "StreamThread did not hit and reset the injected error");
+
+                waitForTransitionFromRebalancingToRunning();
+
+                for (final String log : appender.getMessages()) {
+                    // after we replace the thread there should be two remaining threads with 5 bytes each
+                    if (log.endsWith("Adding StreamThread-3, there will now be 2 live threads and the new cache size per thread is 5")) {
+                        return;
+                    }
+                }
+            }
+        }
+        fail();
+    }
 }

streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java

@@ -260,7 +260,7 @@
         final Topology topology = builder.build();
         kafkaStreams = new KafkaStreams(topology, streamsConfiguration);

-        verifyAliveStreamThreadsMetric(0);
+        verifyAliveStreamThreadsMetric();
         verifyStateMetric(State.CREATED);
         verifyTopologyDescriptionMetric(topology.describe().toString());
         verifyApplicationIdMetric();
@@ -271,7 +271,7 @@
             timeout,
             () -> "Kafka Streams application did not reach state RUNNING in " + timeout + " ms");

-        verifyAliveStreamThreadsMetric(NUM_THREADS);
+        verifyAliveStreamThreadsMetric();
         verifyStateMetric(State.RUNNING);
     }
@@ -463,13 +463,13 @@
         checkMetricsDeregistration();
     }

-    private void verifyAliveStreamThreadsMetric(final int numThreads) {
+    private void verifyAliveStreamThreadsMetric() {
         final List<Metric> metricsList = new ArrayList<Metric>(kafkaStreams.metrics().values()).stream()
             .filter(m -> m.metricName().name().equals(ALIVE_STREAM_THREADS) &&
                 m.metricName().group().equals(STREAM_CLIENT_NODE_METRICS))
             .collect(Collectors.toList());
         assertThat(metricsList.size(), is(1));
-        assertThat(metricsList.get(0).metricValue(), is(numThreads));
+        assertThat(metricsList.get(0).metricValue(), is(NUM_THREADS));
     }

     private void verifyStateMetric(final State state) {

streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java

@@ -1352,15 +1352,6 @@
                                  final long totalRestored) {
         }

-        public boolean allStartOffsetsAtZero() {
-            for (final AtomicLong startOffset : changelogToStartOffset.values()) {
-                if (startOffset.get() != 0L) {
-                    return false;
-                }
-            }
-            return true;
-        }
-
         public long totalNumRestored() {
             long totalNumRestored = 0L;
             for (final AtomicLong numRestored : changelogToTotalNumRestored.values()) {