KAFKA-18943: Kafka Streams incorrectly commits TX during task revocation (#19164)

Fixes two issues:
 - do not commit the TX if no revoked task needs to be committed
 - commit revoked tasks after a punctuation was triggered

Reviewers: Lucas Brutschy <lbrutschy@confluent.io>, Anna Sophie Blee-Goldman <sophie@responsive.dev>, Bruno Cadonna <bruno@confluent.io>, Bill Bejeck <bill@confluent.io>
Matthias J. Sax · 2025-03-13 09:37:11 -07:00
parent bf9912a1eb · commit 90ee2d2b34
5 changed files with 502 additions and 26 deletions
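
In short, the TaskManager change below makes handleRevocation() commit only when a revoked task actually made progress. A condensed sketch of the new control flow, assembled from the hunks below (not compilable on its own; all names are the ones visible in the diff):

    boolean revokedTasksNeedCommit = false;
    for (final Task task : activeRunningTaskIterable()) {
        if (remainingRevokedPartitions.containsAll(task.inputPartitions())) {
            revokedActiveTasks.add(task);
            remainingRevokedPartitions.removeAll(task.inputPartitions());
            revokedTasksNeedCommit |= task.commitNeeded();   // progress includes output produced by punctuation
        } else if (task.commitNeeded()) {
            commitNeededActiveTasks.add(task);               // non-revoked tasks with pending progress
        }
    }

    if (revokedTasksNeedCommit) {
        // commit revoked tasks together with all other tasks that need committing,
        // so that under EOS a single transaction covers all of them
        prepareCommitAndAddOffsetsToMap(revokedActiveTasks, consumedOffsetsPerTask);
        prepareCommitAndAddOffsetsToMap(commitNeededActiveTasks, consumedOffsetsPerTask);
        taskExecutor.commitOffsetsOrTransaction(consumedOffsetsPerTask);
    }

Because the guard is based on commitNeeded() rather than on collected input offsets, a revoked task whose only progress came from a wall-clock punctuation (no consumed offsets) is also committed; that is the case covered by the new RebalanceIntegrationTest.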

streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java

@@ -1077,12 +1077,16 @@ public class TaskManager {
         final Set<TaskId> lockedTaskIds = activeRunningTaskIterable().stream().map(Task::id).collect(Collectors.toSet());
         maybeLockTasks(lockedTaskIds);
+        boolean revokedTasksNeedCommit = false;
         for (final Task task : activeRunningTaskIterable()) {
             if (remainingRevokedPartitions.containsAll(task.inputPartitions())) {
                 // when the task input partitions are included in the revoked list,
                 // this is an active task and should be revoked
                 revokedActiveTasks.add(task);
                 remainingRevokedPartitions.removeAll(task.inputPartitions());
+                revokedTasksNeedCommit |= task.commitNeeded();
             } else if (task.commitNeeded()) {
                 commitNeededActiveTasks.add(task);
             }
@@ -1096,11 +1100,9 @@ public class TaskManager {
                 "have been cleaned up by the handleAssignment callback.", remainingRevokedPartitions);
         }
+        if (revokedTasksNeedCommit) {
             prepareCommitAndAddOffsetsToMap(revokedActiveTasks, consumedOffsetsPerTask);
             // if we need to commit any revoking task then we just commit all of those needed committing together
-        final boolean shouldCommitAdditionalTasks = !consumedOffsetsPerTask.isEmpty();
-        if (shouldCommitAdditionalTasks) {
             prepareCommitAndAddOffsetsToMap(commitNeededActiveTasks, consumedOffsetsPerTask);
         }
@@ -1109,10 +1111,12 @@ public class TaskManager {
         // as such we just need to skip those dirty tasks in the checkpoint
         final Set<Task> dirtyTasks = new HashSet<>();
         try {
+            if (revokedTasksNeedCommit) {
                 // in handleRevocation we must call commitOffsetsOrTransaction() directly rather than
                 // commitAndFillInConsumedOffsetsAndMetadataPerTaskMap() to make sure we don't skip the
                 // offset commit because we are in a rebalance
                 taskExecutor.commitOffsetsOrTransaction(consumedOffsetsPerTask);
+            }
         } catch (final TaskCorruptedException e) {
             log.warn("Some tasks were corrupted when trying to commit offsets, these will be cleaned and revived: {}",
                 e.corruptedTasks());
@@ -1145,7 +1149,7 @@
             }
         }
-        if (shouldCommitAdditionalTasks) {
+        if (revokedTasksNeedCommit) {
             for (final Task task : commitNeededActiveTasks) {
                 if (!dirtyTasks.contains(task)) {
                     try {

streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java

@@ -22,10 +22,13 @@ import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.IsolationLevel;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
+import org.apache.kafka.common.serialization.ByteArraySerializer;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.IntegerDeserializer;
 import org.apache.kafka.common.serialization.IntegerSerializer;
@@ -49,8 +52,10 @@ import org.apache.kafka.streams.kstream.TransformerSupplier;
 import org.apache.kafka.streams.processor.ProcessorContext;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.api.ContextualProcessor;
 import org.apache.kafka.streams.processor.api.Processor;
 import org.apache.kafka.streams.processor.api.Record;
+import org.apache.kafka.streams.processor.internals.DefaultKafkaClientSupplier;
 import org.apache.kafka.streams.processor.internals.StreamThread;
 import org.apache.kafka.streams.query.QueryResult;
 import org.apache.kafka.streams.query.RangeQuery;
@@ -68,6 +73,7 @@ import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Timeout;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.CsvSource;
@@ -80,6 +86,7 @@ import java.io.IOException;
 import java.nio.file.Paths;
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -124,7 +131,10 @@ public class EosIntegrationTest {
     public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(
         NUM_BROKERS,
-        Utils.mkProperties(Collections.singletonMap("auto.create.topics.enable", "true"))
+        Utils.mkProperties(mkMap(
+            mkEntry("auto.create.topics.enable", "true"),
+            mkEntry("transaction.max.timeout.ms", "" + Integer.MAX_VALUE)
+        ))
     );

     @BeforeAll
@@ -939,6 +949,7 @@ public class EosIntegrationTest {
                                         final String storeName,
                                         final long startingOffset,
                                         final long endingOffset) {}
+
             @Override
             public void onBatchRestored(final TopicPartition topicPartition,
                                         final String storeName,
@@ -951,6 +962,7 @@ public class EosIntegrationTest {
                     }
                 }
             }
+
             @Override
             public void onRestoreEnd(final TopicPartition topicPartition,
                                      final String storeName,
@@ -960,9 +972,7 @@ public class EosIntegrationTest {
                 ensureCommittedRecordsInTopicPartition(
                     applicationId + "-" + stateStoreName + "-changelog",
                     partitionToVerify,
-                    2000,
-                    IntegerDeserializer.class,
-                    IntegerDeserializer.class
+                    2000
                 );
                 throwException.set(true);
                 final List<KeyValue<Integer, Integer>> recordBatch2 = IntStream.range(endKey - 1000, endKey).mapToObj(i -> KeyValue.pair(i, 0)).collect(Collectors.toList());
@@ -990,6 +1000,129 @@ public class EosIntegrationTest {
         );
     }
+
+    private final AtomicReference<String> transactionalProducerId = new AtomicReference<>();
+
+    private class TestClientSupplier extends DefaultKafkaClientSupplier {
+        @Override
+        public Producer<byte[], byte[]> getProducer(final Map<String, Object> config) {
+            transactionalProducerId.compareAndSet(null, (String) config.get(ProducerConfig.TRANSACTIONAL_ID_CONFIG));
+            return new KafkaProducer<>(config, new ByteArraySerializer(), new ByteArraySerializer());
+        }
+    }
+
+    static final AtomicReference<TaskId> TASK_WITH_DATA = new AtomicReference<>();
+    static final AtomicBoolean DID_REVOKE_IDLE_TASK = new AtomicBoolean(false);
+
+    @Test
+    public void shouldNotCommitActiveTasksWithPendingInputIfRevokedTaskDidNotMakeProgress() throws Exception {
+        final AtomicBoolean requestCommit = new AtomicBoolean(false);
+
+        final StreamsBuilder builder = new StreamsBuilder();
+        builder.<Long, Long>stream(MULTI_PARTITION_INPUT_TOPIC)
+            .process(() -> new ContextualProcessor<Long, Long, Long, Long>() {
+                @Override
+                public void process(final Record<Long, Long> record) {
+                    if (!requestCommit.get()) {
+                        if (TASK_WITH_DATA.get() != null) {
+                            throw new IllegalStateException("Should only process single record using single task");
+                        }
+                        TASK_WITH_DATA.set(context().taskId());
+                    }
+
+                    context().forward(record.withValue(context().recordMetadata().get().offset()));
+
+                    if (requestCommit.get()) {
+                        context().commit();
+                    }
+                }
+            })
+            .to(SINGLE_PARTITION_OUTPUT_TOPIC);
+
+        final Properties properties = new Properties();
+        properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2);
+        properties.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+        properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, Integer.MAX_VALUE);
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), 1);
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.METADATA_MAX_AGE_CONFIG), "1000");
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "earliest");
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG), MAX_POLL_INTERVAL_MS - 1);
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), MAX_POLL_INTERVAL_MS);
+        properties.put(StreamsConfig.producerPrefix(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG), Integer.MAX_VALUE);
+        properties.put(StreamsConfig.TASK_ASSIGNOR_CLASS_CONFIG, TestTaskAssignor.class.getName());
+
+        final Properties config = StreamsTestUtils.getStreamsConfig(
+            applicationId,
+            CLUSTER.bootstrapServers(),
+            Serdes.LongSerde.class.getName(),
+            Serdes.LongSerde.class.getName(),
+            properties
+        );
+
+        try (final KafkaStreams streams = new KafkaStreams(builder.build(), config, new TestClientSupplier())) {
+            startApplicationAndWaitUntilRunning(streams);
+
+            // PHASE 1:
+            // write single input record, and wait for it to get into output topic (uncommitted)
+            // StreamThread-1 now has a task with progress, and one task w/o progress
+            final List<KeyValue<Long, Long>> inputDataTask0 = Collections.singletonList(KeyValue.pair(1L, -1L));
+            IntegrationTestUtils.produceKeyValuesSynchronously(
+                MULTI_PARTITION_INPUT_TOPIC,
+                inputDataTask0,
+                TestUtils.producerConfig(CLUSTER.bootstrapServers(), LongSerializer.class, LongSerializer.class),
+                CLUSTER.time
+            );
+
+            final List<KeyValue<Long, Long>> expectedUncommittedResultTask0 = Collections.singletonList(KeyValue.pair(1L, 0L));
+            final List<KeyValue<Long, Long>> uncommittedRecordsBeforeRebalance = readResult(SINGLE_PARTITION_OUTPUT_TOPIC, expectedUncommittedResultTask0.size(), null);
+            checkResultPerKey(uncommittedRecordsBeforeRebalance, expectedUncommittedResultTask0, "The uncommitted records do not match what expected");
+
+            // PHASE 2:
+            // add second thread, to trigger rebalance
+            // expect idle task to get revoked -- this should not trigger a TX commit
+            streams.addStreamThread();
+            waitForCondition(DID_REVOKE_IDLE_TASK::get, "Idle Task was not revoked as expected.");
+
+            // best-effort sanity check (might pass and not detect issue in slow environments)
+            try {
+                readResult(SINGLE_PARTITION_OUTPUT_TOPIC, 1, "consumer", 10_000L);
+                throw new Exception("Should not be able to read records, as they should have not been committed.");
+            } catch (final AssertionError expected) {
+                // swallow -- we expect to not be able to read uncommitted data, but time-out
+            }
+
+            // PHASE 3:
+            // fence producer to abort pending TX of first input record
+            // expect rebalancing and recovery until both input record are processed
+            requestCommit.set(true);
+
+            // produce into input topic to fence KS producer
+            final List<KeyValue<Long, Long>> inputDataTask0Fencing = Collections.singletonList(KeyValue.pair(4L, -3L));
+            final Properties producerConfigs = new Properties();
+            producerConfigs.setProperty(ProducerConfig.TRANSACTIONAL_ID_CONFIG, transactionalProducerId.get());
+            IntegrationTestUtils.produceKeyValuesSynchronously(
+                MULTI_PARTITION_INPUT_TOPIC,
+                inputDataTask0Fencing,
+                TestUtils.producerConfig(CLUSTER.bootstrapServers(), LongSerializer.class, LongSerializer.class, producerConfigs),
+                CLUSTER.time,
+                true
+            );
+
+            final List<KeyValue<Long, Long>> expectedUncommittedResultAfterError = Arrays.asList(KeyValue.pair(1L, 0L), KeyValue.pair(1L, 0L), KeyValue.pair(4L, 1L));
+            final List<KeyValue<Long, Long>> uncommittedRecordsAfterError = readResult(SINGLE_PARTITION_OUTPUT_TOPIC, expectedUncommittedResultAfterError.size(), null);
+            checkResultPerKey(uncommittedRecordsAfterError, expectedUncommittedResultAfterError, "The committed records do not match what expected");
+        }
+
+        final List<KeyValue<Long, Long>> expectedFinalResult = Arrays.asList(KeyValue.pair(1L, 0L), KeyValue.pair(4L, 1L));
+        final List<KeyValue<Long, Long>> finalResult = readResult(SINGLE_PARTITION_OUTPUT_TOPIC, 2, "committed-only-consumer");
+        checkResultPerKey(finalResult, expectedFinalResult, "The committed records do not match what expected");
+    }

     private void verifyOffsetsAreInCheckpoint(final int partition) throws IOException {
         final String stateStoreDir = stateTmpDir + File.separator + "appDir" + File.separator + applicationId + File.separator + "0_" + partition + File.separator;
@@ -1004,8 +1137,8 @@ public class EosIntegrationTest {
             KafkaConsumer<String, String> consumer = new KafkaConsumer<>(
                 consumerConfig(
                     CLUSTER.bootstrapServers(),
-                    Serdes.ByteArray().deserializer().getClass(),
-                    Serdes.ByteArray().deserializer().getClass()
+                    ByteArrayDeserializer.class,
+                    ByteArrayDeserializer.class
                 )
             )
         ) {
@@ -1202,14 +1335,22 @@ public class EosIntegrationTest {
     private List<KeyValue<Long, Long>> readResult(final String topic,
                                                   final int numberOfRecords,
                                                   final String groupId) throws Exception {
-        return readResult(topic, numberOfRecords, LongDeserializer.class, LongDeserializer.class, groupId);
+        return readResult(topic, numberOfRecords, LongDeserializer.class, LongDeserializer.class, groupId, DEFAULT_TIMEOUT);
+    }
+
+    private List<KeyValue<Long, Long>> readResult(final String topic,
+                                                  final int numberOfRecords,
+                                                  final String groupId,
+                                                  final long timeout) throws Exception {
+        return readResult(topic, numberOfRecords, LongDeserializer.class, LongDeserializer.class, groupId, timeout);
     }

     private <K, V> List<KeyValue<K, V>> readResult(final String topic,
                                                    final int numberOfRecords,
                                                    final Class<? extends Deserializer<K>> keyDeserializer,
                                                    final Class<? extends Deserializer<V>> valueDeserializer,
-                                                   final String groupId) throws Exception {
+                                                   final String groupId,
+                                                   final long timeout) throws Exception {
         if (groupId != null) {
             return IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(
                 TestUtils.consumerConfig(
@@ -1221,7 +1362,8 @@ public class EosIntegrationTest {
                         ConsumerConfig.ISOLATION_LEVEL_CONFIG,
                         IsolationLevel.READ_COMMITTED.toString()))),
                 topic,
-                numberOfRecords
+                numberOfRecords,
+                timeout
             );
         }
@@ -1229,15 +1371,14 @@ public class EosIntegrationTest {
         return IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(
             TestUtils.consumerConfig(CLUSTER.bootstrapServers(), keyDeserializer, valueDeserializer),
             topic,
-            numberOfRecords
+            numberOfRecords,
+            timeout
         );
     }

     private <K, V> void ensureCommittedRecordsInTopicPartition(final String topic,
                                                                final int partition,
-                                                               final int numberOfRecords,
-                                                               final Class<? extends Deserializer<K>> keyDeserializer,
-                                                               final Class<? extends Deserializer<V>> valueDeserializer) throws Exception {
+                                                               final int numberOfRecords) throws Exception {
         final long timeoutMs = 2 * DEFAULT_TIMEOUT;
         final int maxTries = 10;
         final long deadline = System.currentTimeMillis() + timeoutMs;
@@ -1247,8 +1388,8 @@ public class EosIntegrationTest {
                     TestUtils.consumerConfig(
                         CLUSTER.bootstrapServers(),
                         CONSUMER_GROUP_ID,
-                        keyDeserializer,
-                        valueDeserializer,
+                        IntegerDeserializer.class,
+                        IntegerDeserializer.class,
                         Utils.mkProperties(Collections.singletonMap(
                             ConsumerConfig.ISOLATION_LEVEL_CONFIG,
                             IsolationLevel.READ_COMMITTED.toString())
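
Aside, not part of the commit: PHASE 3 of the new EOS test fences the running Streams producer by producing with the same transactional.id from a second producer. A minimal, self-contained sketch of why that aborts the pending transaction (the broker address and id below are placeholders, not values from the test):

    import java.util.Properties;

    import org.apache.kafka.clients.producer.KafkaProducer;
    import org.apache.kafka.clients.producer.ProducerConfig;
    import org.apache.kafka.common.serialization.ByteArraySerializer;

    public class FencingSketch {
        public static void main(final String[] args) {
            final Properties props = new Properties();
            props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder broker
            // re-using the transactional.id of the Streams producer is what fences it
            props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "some-streams-transactional-id"); // placeholder id

            try (KafkaProducer<byte[], byte[]> fencingProducer =
                     new KafkaProducer<>(props, new ByteArraySerializer(), new ByteArraySerializer())) {
                // initTransactions() bumps the producer epoch for this transactional.id on the broker;
                // the older producer with the same id is fenced (ProducerFencedException) and its
                // open, uncommitted transaction is aborted -- which is exactly what PHASE 3 relies on.
                fencingProducer.initTransactions();
            }
        }
    }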

streams/src/test/java/org/apache/kafka/streams/integration/RebalanceIntegrationTest.java

@@ -0,0 +1,221 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.integration;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.common.serialization.LongDeserializer;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
import org.apache.kafka.streams.processor.Cancellable;
import org.apache.kafka.streams.processor.PunctuationType;
import org.apache.kafka.streams.processor.api.Processor;
import org.apache.kafka.streams.processor.api.ProcessorContext;
import org.apache.kafka.streams.processor.api.Record;
import org.apache.kafka.test.StreamsTestUtils;
import org.apache.kafka.test.TestUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import static org.apache.kafka.common.utils.Utils.mkEntry;
import static org.apache.kafka.common.utils.Utils.mkMap;
import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;

@Tag("integration")
@Timeout(600)
public class RebalanceIntegrationTest {
    private static final Logger LOG = LoggerFactory.getLogger(RebalanceIntegrationTest.class);
    private static final int NUM_BROKERS = 3;
    private static final int MAX_POLL_INTERVAL_MS = 30_000;

    public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(
        NUM_BROKERS,
        Utils.mkProperties(mkMap(
            mkEntry("auto.create.topics.enable", "true"),
            mkEntry("transaction.max.timeout.ms", "" + Integer.MAX_VALUE)
        ))
    );

    @BeforeAll
    public static void startCluster() throws IOException {
        CLUSTER.start();
    }

    @AfterAll
    public static void closeCluster() {
        CLUSTER.stop();
    }

    private String applicationId;

    private static final int NUM_TOPIC_PARTITIONS = 2;
    private static final String MULTI_PARTITION_INPUT_TOPIC = "multiPartitionInputTopic";
    private static final String SINGLE_PARTITION_OUTPUT_TOPIC = "singlePartitionOutputTopic";

    private static final AtomicInteger TEST_NUMBER = new AtomicInteger(0);

    @BeforeEach
    public void createTopics() throws Exception {
        applicationId = "appId-" + TEST_NUMBER.getAndIncrement();

        CLUSTER.deleteTopics(MULTI_PARTITION_INPUT_TOPIC, SINGLE_PARTITION_OUTPUT_TOPIC);
        CLUSTER.createTopics(SINGLE_PARTITION_OUTPUT_TOPIC);
        CLUSTER.createTopic(MULTI_PARTITION_INPUT_TOPIC, NUM_TOPIC_PARTITIONS, 1);
    }

    private void checkResultPerKey(final List<KeyValue<Long, Long>> result,
                                   final List<KeyValue<Long, Long>> expectedResult) {
        final Set<Long> allKeys = new HashSet<>();
        addAllKeys(allKeys, result);
        addAllKeys(allKeys, expectedResult);

        for (final Long key : allKeys) {
            assertThat("The records do not match what expected", getAllRecordPerKey(key, result), equalTo(getAllRecordPerKey(key, expectedResult)));
        }
    }

    private void addAllKeys(final Set<Long> allKeys, final List<KeyValue<Long, Long>> records) {
        for (final KeyValue<Long, Long> record : records) {
            allKeys.add(record.key);
        }
    }

    private List<KeyValue<Long, Long>> getAllRecordPerKey(final Long key, final List<KeyValue<Long, Long>> records) {
        final List<KeyValue<Long, Long>> recordsPerKey = new ArrayList<>(records.size());

        for (final KeyValue<Long, Long> record : records) {
            if (record.key.equals(key)) {
                recordsPerKey.add(record);
            }
        }

        return recordsPerKey;
    }

    @Test
    public void shouldCommitAllTasksIfRevokedTaskTriggerPunctuation() throws Exception {
        final AtomicBoolean requestCommit = new AtomicBoolean(false);

        final StreamsBuilder builder = new StreamsBuilder();
        builder.<Long, Long>stream(MULTI_PARTITION_INPUT_TOPIC)
            .process(() -> new Processor<Long, Long, Long, Long>() {
                ProcessorContext<Long, Long> context;

                @Override
                public void init(final ProcessorContext<Long, Long> context) {
                    this.context = context;

                    final AtomicReference<Cancellable> cancellable = new AtomicReference<>();
                    cancellable.set(context.schedule(
                        Duration.ofSeconds(1),
                        PunctuationType.WALL_CLOCK_TIME,
                        time -> {
                            context.forward(new Record<>(
                                (context.taskId().partition() + 1) * 100L,
                                -(context.taskId().partition() + 1L),
                                context.currentSystemTimeMs()));
                            cancellable.get().cancel();
                        }
                    ));
                }

                @Override
                public void process(final Record<Long, Long> record) {
                    context.forward(record.withValue(context.recordMetadata().get().offset()));

                    if (requestCommit.get()) {
                        context.commit();
                    }
                }
            })
            .to(SINGLE_PARTITION_OUTPUT_TOPIC);

        final Properties properties = new Properties();
        properties.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
        properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, Integer.MAX_VALUE);
        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), 1);
        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.METADATA_MAX_AGE_CONFIG), "1000");
        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "earliest");
        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG), MAX_POLL_INTERVAL_MS - 1);
        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), MAX_POLL_INTERVAL_MS);
        properties.put(StreamsConfig.producerPrefix(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG), Integer.MAX_VALUE);
        properties.put(StreamsConfig.TASK_ASSIGNOR_CLASS_CONFIG, TestTaskAssignor.class.getName());

        final Properties config = StreamsTestUtils.getStreamsConfig(
            applicationId,
            CLUSTER.bootstrapServers(),
            Serdes.LongSerde.class.getName(),
            Serdes.LongSerde.class.getName(),
            properties
        );

        try (final KafkaStreams streams = new KafkaStreams(builder.build(), config)) {
            startApplicationAndWaitUntilRunning(streams);

            // PHASE 1:
            // produce single output record via punctuation (uncommitted) [this happens for both tasks]
            // StreamThread-1 now has a task with progress, and one task w/o progress
            final List<KeyValue<Long, Long>> expectedUncommittedResultBeforeRebalance = Arrays.asList(KeyValue.pair(100L, -1L), KeyValue.pair(200L, -2L));
            final List<KeyValue<Long, Long>> uncommittedRecordsBeforeRebalance = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(
                TestUtils.consumerConfig(CLUSTER.bootstrapServers(), LongDeserializer.class, LongDeserializer.class),
                SINGLE_PARTITION_OUTPUT_TOPIC,
                expectedUncommittedResultBeforeRebalance.size()
            );
            checkResultPerKey(uncommittedRecordsBeforeRebalance, expectedUncommittedResultBeforeRebalance);

            // PHASE 2:
            // add second thread, to trigger rebalance
            // both task should get committed
            streams.addStreamThread();

            final List<KeyValue<Long, Long>> expectedUncommittedResultAfterRebalance = Arrays.asList(KeyValue.pair(100L, -1L), KeyValue.pair(200L, -2L));
            final List<KeyValue<Long, Long>> uncommittedRecordsAfterRebalance = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(
                TestUtils.consumerConfig(CLUSTER.bootstrapServers(), LongDeserializer.class, LongDeserializer.class),
                SINGLE_PARTITION_OUTPUT_TOPIC,
                expectedUncommittedResultAfterRebalance.size()
            );
            checkResultPerKey(uncommittedRecordsAfterRebalance, expectedUncommittedResultAfterRebalance);
        }
    }
}

streams/src/test/java/org/apache/kafka/streams/integration/TestTaskAssignor.java

@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.integration;
import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.assignment.assignors.StickyTaskAssignor;
public class TestTaskAssignor extends StickyTaskAssignor {
    @Override
    public void onAssignmentComputed(final ConsumerPartitionAssignor.GroupAssignment assignment,
                                     final ConsumerPartitionAssignor.GroupSubscription subscription,
                                     final AssignmentError error) {
        if (assignment.groupAssignment().size() == 1) {
            return;
        }

        for (final String threadName : assignment.groupAssignment().keySet()) {
            if (threadName.contains("-StreamThread-1-")) {
                final TaskId taskWithData = EosIntegrationTest.TASK_WITH_DATA.get();

                if (taskWithData != null && taskWithData.partition() == assignment.groupAssignment().get(threadName).partitions().get(0).partition()) {
                    EosIntegrationTest.DID_REVOKE_IDLE_TASK.set(true);
                }
            }
        }
    }
}

streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java

@@ -3099,7 +3099,7 @@ public class TaskManagerTest {
         assertThat(task00.commitNeeded, is(false));
         assertThat(task00.commitPrepared, is(true));
-        assertThat(task00.commitNeeded, is(false));
+        assertThat(task01.commitNeeded, is(false));
         assertThat(task01.commitPrepared, is(true));
         assertThat(task02.commitPrepared, is(false));
         assertThat(task10.commitPrepared, is(false));
@@ -3107,6 +3107,74 @@ public class TaskManagerTest {
         verify(consumer).commitSync(expectedCommittedOffsets);
     }

+    @Test
+    public void shouldNotCommitIfNoRevokedTasksNeedCommitting() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager);
+        task01.setCommitNeeded();
+        final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager);
+
+        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
+            mkEntry(taskId00, taskId00Partitions),
+            mkEntry(taskId01, taskId01Partitions),
+            mkEntry(taskId02, taskId02Partitions)
+        );
+
+        when(consumer.assignment()).thenReturn(assignment);
+        when(activeTaskCreator.createTasks(any(), eq(assignmentActive)))
+            .thenReturn(asList(task00, task01, task02));
+
+        taskManager.handleAssignment(assignmentActive, Collections.emptyMap());
+        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+        assertThat(task02.state(), is(Task.State.RUNNING));
+
+        taskManager.handleRevocation(taskId00Partitions);
+
+        assertThat(task00.commitPrepared, is(false));
+        assertThat(task01.commitPrepared, is(false));
+        assertThat(task02.commitPrepared, is(false));
+    }
+
+    @Test
+    public void shouldNotCommitIfNoRevokedTasksNeedCommittingWithEOSv2() {
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2, false);
+
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager);
+        task01.setCommitNeeded();
+        final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager);
+
+        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
+            mkEntry(taskId00, taskId00Partitions),
+            mkEntry(taskId01, taskId01Partitions),
+            mkEntry(taskId02, taskId02Partitions)
+        );
+
+        when(consumer.assignment()).thenReturn(assignment);
+        when(activeTaskCreator.createTasks(any(), eq(assignmentActive)))
+            .thenReturn(asList(task00, task01, task02));
+
+        taskManager.handleAssignment(assignmentActive, Collections.emptyMap());
+        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+        assertThat(task02.state(), is(Task.State.RUNNING));
+
+        taskManager.handleRevocation(taskId00Partitions);
+
+        assertThat(task00.commitPrepared, is(false));
+        assertThat(task01.commitPrepared, is(false));
+        assertThat(task02.commitPrepared, is(false));
+    }
+
     @Test
     public void shouldNotCommitOnHandleAssignmentIfNoTaskClosed() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager);