diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 3195493d9b1..a0dadd66382 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -103,7 +103,7 @@
files="(AbstractRequest|AbstractResponse|KerberosLogin|WorkerSinkTaskTest|TransactionManagerTest|SenderTest|KafkaAdminClient|ConsumerCoordinatorTest|KafkaAdminClientTest|KafkaRaftClientTest).java"/>
+ files="(AbstractMembershipManager|ConsumerCoordinator|BufferPool|MetricName|Node|ConfigDef|RecordBatch|SslFactory|SslTransportLayer|MetadataResponse|KerberosLogin|Selector|Sender|Serdes|TokenInformation|Agent|PluginUtils|MiniTrogdorCluster|TasksRequest|KafkaProducer|AbstractStickyAssignor|KafkaRaftClient|Authorizer|FetchSessionHandler|RecordAccumulator|Shell|MockConsumer).java"/>
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
index 600f8bbd07e..56e684e94c4 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
@@ -34,6 +34,7 @@ import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -77,6 +78,8 @@ public class MockConsumer implements Consumer {
private Uuid clientInstanceId;
private int injectTimeoutExceptionCounter;
+ private long maxPollRecords = Long.MAX_VALUE;
+
public MockConsumer(OffsetResetStrategy offsetResetStrategy) {
this.subscriptions = new SubscriptionState(new LogContext(), offsetResetStrategy);
this.partitions = new HashMap<>();
@@ -229,14 +232,22 @@ public class MockConsumer implements Consumer {
// update the consumed offset
final Map>> results = new HashMap<>();
- final List toClear = new ArrayList<>();
+ long numPollRecords = 0L;
+
+ final Iterator>>> partitionsIter = this.records.entrySet().iterator();
+ while (partitionsIter.hasNext() && numPollRecords < this.maxPollRecords) {
+ Map.Entry>> entry = partitionsIter.next();
- for (Map.Entry>> entry : this.records.entrySet()) {
if (!subscriptions.isPaused(entry.getKey())) {
- final List> recs = entry.getValue();
- for (final ConsumerRecord rec : recs) {
+ final Iterator> recIterator = entry.getValue().iterator();
+ while (recIterator.hasNext()) {
+ if (numPollRecords >= this.maxPollRecords) {
+ break;
+ }
long position = subscriptions.position(entry.getKey()).offset;
+ final ConsumerRecord rec = recIterator.next();
+
if (beginningOffsets.get(entry.getKey()) != null && beginningOffsets.get(entry.getKey()) > position) {
throw new OffsetOutOfRangeException(Collections.singletonMap(entry.getKey(), position));
}
@@ -247,13 +258,18 @@ public class MockConsumer implements Consumer {
SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition(
rec.offset() + 1, rec.leaderEpoch(), leaderAndEpoch);
subscriptions.position(entry.getKey(), newPosition);
+
+ numPollRecords++;
+ recIterator.remove();
}
}
- toClear.add(entry.getKey());
+
+ if (entry.getValue().isEmpty()) {
+ partitionsIter.remove();
+ }
}
}
- toClear.forEach(records::remove);
return new ConsumerRecords<>(results);
}
@@ -275,6 +291,17 @@ public class MockConsumer implements Consumer {
setPollException(exception);
}
+ /* Sets the maximum number of records returned in a single call to {@link #poll(Duration)}.
+ *
+ * @param maxPollRecords the max.poll.records.
+ */
+ public synchronized void setMaxPollRecords(long maxPollRecords) {
+ if (this.maxPollRecords < 1) {
+ throw new IllegalArgumentException("MaxPollRecords must be strictly superior to 0");
+ }
+ this.maxPollRecords = maxPollRecords;
+ }
+
public synchronized void setPollException(KafkaException exception) {
this.pollException = exception;
}
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java
index a9b0c2843d9..9e2ca5a2ae4 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java
@@ -31,6 +31,7 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
+import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -190,4 +191,33 @@ public class MockConsumerTest {
assertTrue(revoked.contains(topicPartitionList.get(0)));
}
+ @Test
+ public void shouldReturnMaxPollRecords() {
+ TopicPartition partition = new TopicPartition("test", 0);
+ consumer.assign(Collections.singleton(partition));
+ consumer.updateBeginningOffsets(Collections.singletonMap(partition, 0L));
+
+ IntStream.range(0, 10).forEach(offset -> {
+ consumer.addRecord(new ConsumerRecord<>("test", 0, offset, null, null));
+ });
+
+ consumer.setMaxPollRecords(2L);
+
+ ConsumerRecords records;
+
+ records = consumer.poll(Duration.ofMillis(1));
+ assertEquals(2, records.count());
+
+ records = consumer.poll(Duration.ofMillis(1));
+ assertEquals(2, records.count());
+
+ consumer.setMaxPollRecords(Long.MAX_VALUE);
+
+ records = consumer.poll(Duration.ofMillis(1));
+ assertEquals(6, records.count());
+
+ records = consumer.poll(Duration.ofMillis(1));
+ assertTrue(records.isEmpty());
+ }
+
}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
index 24cf51a67be..97ec387f0dc 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
@@ -300,6 +300,7 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
currentDeadline = NO_DEADLINE;
}
+ long batchRestoreCount = 0;
for (final ConsumerRecord record : records.records(topicPartition)) {
final ProcessorRecordContext recordContext =
new ProcessorRecordContext(
@@ -318,6 +319,7 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
record.timestamp(),
record.headers()));
restoreCount++;
+ batchRestoreCount++;
}
} catch (final Exception deserializationException) {
// while Java distinguishes checked vs unchecked exceptions, other languages
@@ -341,7 +343,7 @@ public class GlobalStateManagerImpl implements GlobalStateManager {
offset = getGlobalConsumerOffset(topicPartition);
- stateRestoreListener.onBatchRestored(topicPartition, storeName, offset, restoreCount);
+ stateRestoreListener.onBatchRestored(topicPartition, storeName, offset, batchRestoreCount);
}
stateRestoreListener.onRestoreEnd(topicPartition, storeName, restoreCount);
checkpointFileCache.put(topicPartition, offset);
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
index 88709ed9186..63889742d08 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
@@ -356,16 +356,35 @@ public class GlobalStateManagerImplTest {
assertEquals(2, stateRestoreCallback.restored.size());
}
+ @Test
+ public void shouldListenForRestoreEventsWhenReprocessing() {
+ setUpReprocessing();
+
+ initializeConsumer(6, 1, t1);
+ consumer.setMaxPollRecords(2L);
+
+ stateManager.initialize();
+ stateManager.registerStore(store1, stateRestoreCallback, null);
+
+ assertThat(stateRestoreListener.numBatchRestored, equalTo(2L));
+ assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L));
+ assertThat(stateRestoreListener.restoreEndOffset, equalTo(7L));
+ assertThat(stateRestoreListener.totalNumRestored, equalTo(6L));
+ }
+
@Test
public void shouldListenForRestoreEvents() {
- initializeConsumer(5, 1, t1);
+ initializeConsumer(6, 1, t1);
+ consumer.setMaxPollRecords(2L);
+
stateManager.initialize();
stateManager.registerStore(store1, stateRestoreCallback, null);
+ assertThat(stateRestoreListener.numBatchRestored, equalTo(2L));
assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L));
- assertThat(stateRestoreListener.restoreEndOffset, equalTo(6L));
- assertThat(stateRestoreListener.totalNumRestored, equalTo(5L));
+ assertThat(stateRestoreListener.restoreEndOffset, equalTo(7L));
+ assertThat(stateRestoreListener.totalNumRestored, equalTo(6L));
assertThat(stateRestoreListener.storeNameCalledStates.get(RESTORE_START), equalTo(store1.name()));