KAFKA-18962: Fix onBatchRestored call in GlobalStateManagerImpl (#19188)

When reprocessing state, call StateRestoreListener#onBatchRestored with the
number of records restored in the current batch (numRestored), not the
running total of records restored so far (totalRestored).

See: https://issues.apache.org/jira/browse/KAFKA-18962

Reviewers: Anna Sophie Blee-Goldman <ableegoldman@apache.org>, Matthias
Sax <mjsax@apache.org>
This commit is contained in:
Florian Hussonnois 2025-04-09 22:17:38 +02:00 committed by A. Sophie Blee-Goldman
parent d1b381a185
commit 5c2ca4bd58
5 changed files with 89 additions and 11 deletions

View File

@ -103,7 +103,7 @@
files="(AbstractRequest|AbstractResponse|KerberosLogin|WorkerSinkTaskTest|TransactionManagerTest|SenderTest|KafkaAdminClient|ConsumerCoordinatorTest|KafkaAdminClientTest|KafkaRaftClientTest).java"/> files="(AbstractRequest|AbstractResponse|KerberosLogin|WorkerSinkTaskTest|TransactionManagerTest|SenderTest|KafkaAdminClient|ConsumerCoordinatorTest|KafkaAdminClientTest|KafkaRaftClientTest).java"/>
<suppress checks="NPathComplexity" <suppress checks="NPathComplexity"
files="(ConsumerCoordinator|BufferPool|MetricName|Node|ConfigDef|RecordBatch|SslFactory|SslTransportLayer|MetadataResponse|KerberosLogin|Selector|Sender|Serdes|TokenInformation|Agent|PluginUtils|MiniTrogdorCluster|TasksRequest|KafkaProducer|AbstractStickyAssignor|KafkaRaftClient|Authorizer|FetchSessionHandler|RecordAccumulator|Shell).java"/> files="(AbstractMembershipManager|ConsumerCoordinator|BufferPool|MetricName|Node|ConfigDef|RecordBatch|SslFactory|SslTransportLayer|MetadataResponse|KerberosLogin|Selector|Sender|Serdes|TokenInformation|Agent|PluginUtils|MiniTrogdorCluster|TasksRequest|KafkaProducer|AbstractStickyAssignor|KafkaRaftClient|Authorizer|FetchSessionHandler|RecordAccumulator|Shell|MockConsumer).java"/>
<suppress checks="(JavaNCSS|CyclomaticComplexity|MethodLength)" <suppress checks="(JavaNCSS|CyclomaticComplexity|MethodLength)"
files="CoordinatorClient.java"/> files="CoordinatorClient.java"/>

View File

@ -34,6 +34,7 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -77,6 +78,8 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
private Uuid clientInstanceId; private Uuid clientInstanceId;
private int injectTimeoutExceptionCounter; private int injectTimeoutExceptionCounter;
private long maxPollRecords = Long.MAX_VALUE;
public MockConsumer(OffsetResetStrategy offsetResetStrategy) { public MockConsumer(OffsetResetStrategy offsetResetStrategy) {
this.subscriptions = new SubscriptionState(new LogContext(), offsetResetStrategy); this.subscriptions = new SubscriptionState(new LogContext(), offsetResetStrategy);
this.partitions = new HashMap<>(); this.partitions = new HashMap<>();
@ -229,14 +232,22 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
// update the consumed offset // update the consumed offset
final Map<TopicPartition, List<ConsumerRecord<K, V>>> results = new HashMap<>(); final Map<TopicPartition, List<ConsumerRecord<K, V>>> results = new HashMap<>();
final List<TopicPartition> toClear = new ArrayList<>(); long numPollRecords = 0L;
final Iterator<Map.Entry<TopicPartition, List<ConsumerRecord<K, V>>>> partitionsIter = this.records.entrySet().iterator();
while (partitionsIter.hasNext() && numPollRecords < this.maxPollRecords) {
Map.Entry<TopicPartition, List<ConsumerRecord<K, V>>> entry = partitionsIter.next();
for (Map.Entry<TopicPartition, List<ConsumerRecord<K, V>>> entry : this.records.entrySet()) {
if (!subscriptions.isPaused(entry.getKey())) { if (!subscriptions.isPaused(entry.getKey())) {
final List<ConsumerRecord<K, V>> recs = entry.getValue(); final Iterator<ConsumerRecord<K, V>> recIterator = entry.getValue().iterator();
for (final ConsumerRecord<K, V> rec : recs) { while (recIterator.hasNext()) {
if (numPollRecords >= this.maxPollRecords) {
break;
}
long position = subscriptions.position(entry.getKey()).offset; long position = subscriptions.position(entry.getKey()).offset;
final ConsumerRecord<K, V> rec = recIterator.next();
if (beginningOffsets.get(entry.getKey()) != null && beginningOffsets.get(entry.getKey()) > position) { if (beginningOffsets.get(entry.getKey()) != null && beginningOffsets.get(entry.getKey()) > position) {
throw new OffsetOutOfRangeException(Collections.singletonMap(entry.getKey(), position)); throw new OffsetOutOfRangeException(Collections.singletonMap(entry.getKey(), position));
} }
@ -247,13 +258,18 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition( SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition(
rec.offset() + 1, rec.leaderEpoch(), leaderAndEpoch); rec.offset() + 1, rec.leaderEpoch(), leaderAndEpoch);
subscriptions.position(entry.getKey(), newPosition); subscriptions.position(entry.getKey(), newPosition);
}
} numPollRecords++;
toClear.add(entry.getKey()); recIterator.remove();
}
}
if (entry.getValue().isEmpty()) {
partitionsIter.remove();
}
} }
} }
toClear.forEach(records::remove);
return new ConsumerRecords<>(results); return new ConsumerRecords<>(results);
} }
@ -275,6 +291,17 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
setPollException(exception); setPollException(exception);
} }
/**
 * Sets the maximum number of records returned in a single call to {@link #poll(Duration)}.
 *
 * @param maxPollRecords the max.poll.records; must be at least 1.
 * @throws IllegalArgumentException if {@code maxPollRecords} is less than 1.
 */
public synchronized void setMaxPollRecords(long maxPollRecords) {
    // Validate the incoming argument rather than the current field value:
    // checking the field would let an invalid value through and only fail
    // on the *next* call, even if that call passed a valid value.
    if (maxPollRecords < 1) {
        throw new IllegalArgumentException("MaxPollRecords must be strictly superior to 0");
    }
    this.maxPollRecords = maxPollRecords;
}
public synchronized void setPollException(KafkaException exception) { public synchronized void setPollException(KafkaException exception) {
this.pollException = exception; this.pollException = exception;
} }

View File

@ -31,6 +31,7 @@ import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
@ -190,4 +191,33 @@ public class MockConsumerTest {
assertTrue(revoked.contains(topicPartitionList.get(0))); assertTrue(revoked.contains(topicPartitionList.get(0)));
} }
@Test
public void shouldReturnMaxPollRecords() {
    // Assign a single partition and queue ten records at offsets 0..9.
    final TopicPartition tp = new TopicPartition("test", 0);
    consumer.assign(Collections.singleton(tp));
    consumer.updateBeginningOffsets(Collections.singletonMap(tp, 0L));
    for (int offset = 0; offset < 10; offset++) {
        consumer.addRecord(new ConsumerRecord<>("test", 0, offset, null, null));
    }

    // With the cap set to two, each poll returns exactly two records.
    consumer.setMaxPollRecords(2L);
    final ConsumerRecords<String, String> firstBatch = consumer.poll(Duration.ofMillis(1));
    assertEquals(2, firstBatch.count());
    final ConsumerRecords<String, String> secondBatch = consumer.poll(Duration.ofMillis(1));
    assertEquals(2, secondBatch.count());

    // Lifting the cap drains the six records that are still queued.
    consumer.setMaxPollRecords(Long.MAX_VALUE);
    final ConsumerRecords<String, String> remainder = consumer.poll(Duration.ofMillis(1));
    assertEquals(6, remainder.count());

    // Nothing is left for a further poll.
    final ConsumerRecords<String, String> afterDrain = consumer.poll(Duration.ofMillis(1));
    assertTrue(afterDrain.isEmpty());
}
} }

View File

@ -300,6 +300,7 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
currentDeadline = NO_DEADLINE; currentDeadline = NO_DEADLINE;
} }
long batchRestoreCount = 0;
for (final ConsumerRecord<byte[], byte[]> record : records.records(topicPartition)) { for (final ConsumerRecord<byte[], byte[]> record : records.records(topicPartition)) {
final ProcessorRecordContext recordContext = final ProcessorRecordContext recordContext =
new ProcessorRecordContext( new ProcessorRecordContext(
@ -318,6 +319,7 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
record.timestamp(), record.timestamp(),
record.headers())); record.headers()));
restoreCount++; restoreCount++;
batchRestoreCount++;
} }
} catch (final Exception deserializationException) { } catch (final Exception deserializationException) {
// while Java distinguishes checked vs unchecked exceptions, other languages // while Java distinguishes checked vs unchecked exceptions, other languages
@ -341,7 +343,7 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
offset = getGlobalConsumerOffset(topicPartition); offset = getGlobalConsumerOffset(topicPartition);
stateRestoreListener.onBatchRestored(topicPartition, storeName, offset, restoreCount); stateRestoreListener.onBatchRestored(topicPartition, storeName, offset, batchRestoreCount);
} }
stateRestoreListener.onRestoreEnd(topicPartition, storeName, restoreCount); stateRestoreListener.onRestoreEnd(topicPartition, storeName, restoreCount);
checkpointFileCache.put(topicPartition, offset); checkpointFileCache.put(topicPartition, offset);

View File

@ -356,16 +356,35 @@ public class GlobalStateManagerImplTest {
assertEquals(2, stateRestoreCallback.restored.size()); assertEquals(2, stateRestoreCallback.restored.size());
} }
// Checks the values the restore listener records when a store is registered
// after setUpReprocessing() has been applied.
// NOTE(review): setUpReprocessing() and initializeConsumer(...) are defined
// outside this view; presumably the former selects the reprocessing restore
// path and the latter queues 6 records starting at offset 1 - confirm.
@Test
public void shouldListenForRestoreEventsWhenReprocessing() {
setUpReprocessing();
initializeConsumer(6, 1, t1);
// Cap each poll at two records so the restore spans several batches.
consumer.setMaxPollRecords(2L);
stateManager.initialize();
stateManager.registerStore(store1, stateRestoreCallback, null);
// Expect the per-batch count (2), distinct from the overall total (6) asserted below.
assertThat(stateRestoreListener.numBatchRestored, equalTo(2L));
assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L));
assertThat(stateRestoreListener.restoreEndOffset, equalTo(7L));
assertThat(stateRestoreListener.totalNumRestored, equalTo(6L));
}
@Test @Test
public void shouldListenForRestoreEvents() { public void shouldListenForRestoreEvents() {
initializeConsumer(5, 1, t1); initializeConsumer(6, 1, t1);
consumer.setMaxPollRecords(2L);
stateManager.initialize(); stateManager.initialize();
stateManager.registerStore(store1, stateRestoreCallback, null); stateManager.registerStore(store1, stateRestoreCallback, null);
assertThat(stateRestoreListener.numBatchRestored, equalTo(2L));
assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L)); assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L));
assertThat(stateRestoreListener.restoreEndOffset, equalTo(6L)); assertThat(stateRestoreListener.restoreEndOffset, equalTo(7L));
assertThat(stateRestoreListener.totalNumRestored, equalTo(5L)); assertThat(stateRestoreListener.totalNumRestored, equalTo(6L));
assertThat(stateRestoreListener.storeNameCalledStates.get(RESTORE_START), equalTo(store1.name())); assertThat(stateRestoreListener.storeNameCalledStates.get(RESTORE_START), equalTo(store1.name()));