diff --git a/checkstyle/checkstyle.xml b/checkstyle/checkstyle.xml index 91045adc608..7f912dc428a 100644 --- a/checkstyle/checkstyle.xml +++ b/checkstyle/checkstyle.xml @@ -120,6 +120,7 @@ + diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java index de048d19940..f62427d5e12 100644 --- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java +++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java @@ -92,6 +92,7 @@ import java.util.Properties; import java.util.Set; import java.util.TreeMap; import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -463,9 +464,8 @@ public class KafkaStreams implements AutoCloseable { closeToError(); } final StreamThread deadThread = (StreamThread) Thread.currentThread(); - threads.remove(deadThread); - addStreamThread(); deadThread.shutdown(); + addStreamThread(); if (throwable instanceof RuntimeException) { throw (RuntimeException) throwable; } else if (throwable instanceof Error) { @@ -970,7 +970,7 @@ public class KafkaStreams implements AutoCloseable { final StreamThread streamThread; synchronized (changeThreadCount) { threadIdx = getNextThreadIndex(); - cacheSizePerThread = getCacheSizePerThread(threads.size() + 1); + cacheSizePerThread = getCacheSizePerThread(getNumLiveStreamThreads() + 1); resizeThreadCache(cacheSizePerThread); // Creating thread should hold the lock in order to avoid duplicate thread index. // If the duplicate index happen, the metadata of thread may be duplicate too. @@ -984,7 +984,7 @@ public class KafkaStreams implements AutoCloseable { } else { streamThread.shutdown(); threads.remove(streamThread); - resizeThreadCache(getCacheSizePerThread(threads.size())); + resizeThreadCache(getCacheSizePerThread(getNumLiveStreamThreads())); } } } @@ -1038,7 +1038,7 @@ public class KafkaStreams implements AutoCloseable { // make a copy of threads to avoid holding lock for (final StreamThread streamThread : new ArrayList<>(threads)) { final boolean callingThreadIsNotCurrentStreamThread = !streamThread.getName().equals(Thread.currentThread().getName()); - if (streamThread.isAlive() && (callingThreadIsNotCurrentStreamThread || threads.size() == 1)) { + if (streamThread.isAlive() && (callingThreadIsNotCurrentStreamThread || getNumLiveStreamThreads() == 1)) { log.info("Removing StreamThread " + streamThread.getName()); final Optional groupInstanceID = streamThread.getGroupInstanceID(); streamThread.requestLeaveGroupDuringShutdown(); @@ -1047,10 +1047,15 @@ public class KafkaStreams implements AutoCloseable { if (!streamThread.waitOnThreadState(StreamThread.State.DEAD, timeoutMs - begin)) { log.warn("Thread " + streamThread.getName() + " did not shutdown in the allotted time"); timeout = true; + // Don't remove from threads until shutdown is complete. We will trim it from the + // list once it reaches DEAD, and if for some reason it's hanging indefinitely in the + // shutdown then we should just consider this thread.id to be burned + } else { + threads.remove(streamThread); } } - threads.remove(streamThread); - final long cacheSizePerThread = getCacheSizePerThread(threads.size()); + + final long cacheSizePerThread = getCacheSizePerThread(getNumLiveStreamThreads()); resizeThreadCache(cacheSizePerThread); if (groupInstanceID.isPresent() && callingThreadIsNotCurrentStreamThread) { final MemberToRemove memberToRemove = new MemberToRemove(groupInstanceID.get()); @@ -1093,17 +1098,51 @@ public class KafkaStreams implements AutoCloseable { return Optional.empty(); } - private int getNextThreadIndex() { - final HashSet names = new HashSet<>(); - processStreamThread(thread -> names.add(thread.getName())); - final String baseName = clientId + "-StreamThread-"; - for (int i = 1; i <= threads.size(); i++) { - final String name = baseName + i; - if (!names.contains(name)) { - return i; - } + // Returns the number of threads that are not in the DEAD state -- use this over threads.size() + private int getNumLiveStreamThreads() { + final AtomicInteger numLiveThreads = new AtomicInteger(0); + synchronized (threads) { + processStreamThread(thread -> { + if (thread.state() == StreamThread.State.DEAD) { + threads.remove(thread); + } else { + numLiveThreads.incrementAndGet(); + } + }); + return numLiveThreads.get(); + } + } + + private int getNextThreadIndex() { + final HashSet allLiveThreadNames = new HashSet<>(); + final AtomicInteger maxThreadId = new AtomicInteger(1); + synchronized (threads) { + processStreamThread(thread -> { + // trim any DEAD threads from the list so we can reuse the thread.id + // this is only safe to do once the thread has fully completed shutdown + if (thread.state() == StreamThread.State.DEAD) { + threads.remove(thread); + } else { + allLiveThreadNames.add(thread.getName()); + // Assume threads are always named with the "-StreamThread-" suffix + final int threadId = Integer.parseInt(thread.getName().substring(thread.getName().lastIndexOf("-") + 1)); + if (threadId > maxThreadId.get()) { + maxThreadId.set(threadId); + } + } + }); + + final String baseName = clientId + "-StreamThread-"; + for (int i = 1; i <= maxThreadId.get(); i++) { + final String name = baseName + i; + if (!allLiveThreadNames.contains(name)) { + return i; + } + } + // It's safe to use threads.size() rather than getNumLiveStreamThreads() to infer the number of threads + // here since we trimmed any DEAD threads earlier in this method while holding the lock + return threads.size() + 1; } - return threads.size() + 1; } private long getCacheSizePerThread(final int numStreamThreads) { diff --git a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java index a4cd8bf22c2..b3dd559d88d 100644 --- a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java @@ -232,8 +232,8 @@ public class KafkaStreamsTest { EasyMock.expect(StreamThread.processingMode(anyObject(StreamsConfig.class))).andReturn(StreamThread.ProcessingMode.AT_LEAST_ONCE).anyTimes(); EasyMock.expect(streamThreadOne.getId()).andReturn(0L).anyTimes(); EasyMock.expect(streamThreadTwo.getId()).andReturn(1L).anyTimes(); - prepareStreamThread(streamThreadOne, true); - prepareStreamThread(streamThreadTwo, false); + prepareStreamThread(streamThreadOne, 1, true); + prepareStreamThread(streamThreadTwo, 2, false); // setup global threads final AtomicReference globalThreadState = new AtomicReference<>(GlobalStreamThread.State.CREATED); @@ -293,7 +293,7 @@ public class KafkaStreamsTest { ); } - private void prepareStreamThread(final StreamThread thread, final boolean terminable) throws Exception { + private void prepareStreamThread(final StreamThread thread, final int threadId, final boolean terminable) throws Exception { final AtomicReference state = new AtomicReference<>(StreamThread.State.CREATED); EasyMock.expect(thread.state()).andAnswer(state::get).anyTimes(); @@ -321,7 +321,7 @@ public class KafkaStreamsTest { }).anyTimes(); EasyMock.expect(thread.getGroupInstanceID()).andStubReturn(Optional.empty()); EasyMock.expect(thread.threadMetadata()).andReturn(new ThreadMetadata( - "newThead", + "processId-StreamThread-" + threadId, "DEAD", "", "", @@ -337,7 +337,7 @@ public class KafkaStreamsTest { EasyMock.expectLastCall().anyTimes(); thread.requestLeaveGroupDuringShutdown(); EasyMock.expectLastCall().anyTimes(); - EasyMock.expect(thread.getName()).andStubReturn("newThread"); + EasyMock.expect(thread.getName()).andStubReturn("processId-StreamThread-" + threadId); thread.shutdown(); EasyMock.expectLastCall().andAnswer(() -> { supplier.consumer.close(); @@ -564,7 +564,7 @@ public class KafkaStreamsTest { streams.start(); final int oldSize = streams.threads.size(); TestUtils.waitForCondition(() -> streams.state() == KafkaStreams.State.RUNNING, 15L, "wait until running"); - assertThat(streams.addStreamThread(), equalTo(Optional.of("newThread"))); + assertThat(streams.addStreamThread(), equalTo(Optional.of("processId-StreamThread-" + 2))); assertThat(streams.threads.size(), equalTo(oldSize + 1)); } @@ -613,7 +613,7 @@ public class KafkaStreamsTest { final int oldSize = streams.threads.size(); TestUtils.waitForCondition(() -> streams.state() == KafkaStreams.State.RUNNING, 15L, "Kafka Streams client did not reach state RUNNING"); - assertThat(streams.removeStreamThread(), equalTo(Optional.of("newThread"))); + assertThat(streams.removeStreamThread(), equalTo(Optional.of("processId-StreamThread-" + 1))); assertThat(streams.threads.size(), equalTo(oldSize - 1)); }