diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java index 78ff15cee5f..08d3ce12e81 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java @@ -102,6 +102,7 @@ import org.apache.kafka.common.utils.LogCaptureAppender; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.MockTime; import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; import org.apache.kafka.common.utils.Utils; import org.apache.kafka.test.MockConsumerInterceptor; import org.apache.kafka.test.MockDeserializer; @@ -935,7 +936,6 @@ public class KafkaConsumerTest { @ParameterizedTest @EnumSource(GroupProtocol.class) - @SuppressWarnings("unchecked") public void verifyNoCoordinatorLookupForManualAssignmentWithSeek(GroupProtocol groupProtocol) { ConsumerMetadata metadata = createMetadata(subscription); MockClient client = new MockClient(time, metadata); @@ -951,7 +951,7 @@ public class KafkaConsumerTest { client.prepareResponse(listOffsetsResponse(Map.of(tp0, 50L))); client.prepareResponse(fetchResponse(tp0, 50L, 5)); - ConsumerRecords records = (ConsumerRecords) consumer.poll(Duration.ofMillis(1)); + ConsumerRecords records = pollForRecords(); assertEquals(5, records.count()); assertEquals(55L, consumer.position(tp0)); assertEquals(1, records.nextOffsets().size()); @@ -1045,8 +1045,7 @@ public class KafkaConsumerTest { }, fetchResponse(tp0, 50L, 5)); - @SuppressWarnings("unchecked") - ConsumerRecords records = (ConsumerRecords) consumer.poll(Duration.ofMillis(1)); + ConsumerRecords records = pollForRecords(); assertEquals(5, records.count()); assertEquals(Set.of(tp0), records.partitions()); assertEquals(1, records.nextOffsets().size()); @@ -1731,7 +1730,6 @@ public class KafkaConsumerTest { @ParameterizedTest @EnumSource(GroupProtocol.class) - @SuppressWarnings("unchecked") public void testManualAssignmentChangeWithAutoCommitEnabled(GroupProtocol groupProtocol) { ConsumerMetadata metadata = createMetadata(subscription); MockClient client = new MockClient(time, metadata); @@ -1766,7 +1764,7 @@ public class KafkaConsumerTest { client.prepareResponse(listOffsetsResponse(Map.of(tp0, 10L))); client.prepareResponse(fetchResponse(tp0, 10L, 1)); - ConsumerRecords records = (ConsumerRecords) consumer.poll(Duration.ofMillis(100)); + ConsumerRecords records = pollForRecords(); assertEquals(1, records.count()); assertEquals(11L, consumer.position(tp0)); @@ -1825,8 +1823,7 @@ public class KafkaConsumerTest { client.prepareResponse(listOffsetsResponse(Map.of(tp0, 10L))); client.prepareResponse(fetchResponse(tp0, 10L, 1)); - @SuppressWarnings("unchecked") - ConsumerRecords records = (ConsumerRecords) consumer.poll(Duration.ofMillis(1)); + ConsumerRecords records = pollForRecords(); assertEquals(1, records.count()); assertEquals(11L, consumer.position(tp0)); assertEquals(1, records.nextOffsets().size()); @@ -2655,7 +2652,6 @@ public class KafkaConsumerTest { @ParameterizedTest @EnumSource(GroupProtocol.class) - @SuppressWarnings("unchecked") public void testCurrentLag(GroupProtocol groupProtocol) throws InterruptedException { final ConsumerMetadata metadata = createMetadata(subscription); final MockClient client = new MockClient(time, metadata); @@ -2715,7 +2711,7 @@ public class KafkaConsumerTest { final FetchInfo fetchInfo = new FetchInfo(1L, 99L, 50L, 5); client.respondToRequest(fetchRequest, fetchResponse(Map.of(tp0, fetchInfo))); - final ConsumerRecords records = (ConsumerRecords) consumer.poll(Duration.ofMillis(1)); + final ConsumerRecords records = pollForRecords(); assertEquals(5, records.count()); assertEquals(55L, consumer.position(tp0)); assertEquals(1, records.nextOffsets().size()); @@ -2725,6 +2721,22 @@ public class KafkaConsumerTest { assertEquals(OptionalLong.of(45L), consumer.currentLag(tp0)); } + @SuppressWarnings("unchecked") + private ConsumerRecords pollForRecords() { + Timer timer = time.timer(15000); + + while (timer.notExpired()) { + ConsumerRecords records = (ConsumerRecords) consumer.poll(Duration.ofMillis(1000)); + + if (!records.isEmpty()) + return records; + } + + throw new org.apache.kafka.common.errors.TimeoutException("no records to return"); + } + + + @ParameterizedTest @EnumSource(GroupProtocol.class) public void testListOffsetShouldUpdateSubscriptions(GroupProtocol groupProtocol) {