This is an automated email from the ASF dual-hosted git repository.

ableegoldman pushed a commit to branch 3.9
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.9 by this push:
     new 5c2ca4bd58c KAFKA-18962: Fix onBatchRestored call in GlobalStateManagerImpl (#19188)
5c2ca4bd58c is described below

commit 5c2ca4bd58ccce9b8611634192adfba8e3737d94
Author: Florian Hussonnois <florian.hussonn...@gmail.com>
AuthorDate: Wed Apr 9 22:17:38 2025 +0200

    KAFKA-18962: Fix onBatchRestored call in GlobalStateManagerImpl (#19188)
    
    Call the StateRestoreListener#onBatchRestored with numRestored and not
    the totalRestored when reprocessing state
    
    See: https://issues.apache.org/jira/browse/KAFKA-18962
    
    Reviewers: Anna Sophie Blee-Goldman <ableegold...@apache.org>, Matthias
    Sax <mj...@apache.org>
---
 checkstyle/suppressions.xml                        |  2 +-
 .../kafka/clients/consumer/MockConsumer.java       | 39 ++++++++++++++++++----
 .../kafka/clients/consumer/MockConsumerTest.java   | 30 +++++++++++++++++
 .../internals/GlobalStateManagerImpl.java          |  4 ++-
 .../internals/GlobalStateManagerImplTest.java      | 25 ++++++++++++--
 5 files changed, 89 insertions(+), 11 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 3195493d9b1..a0dadd66382 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -103,7 +103,7 @@
               
files="(AbstractRequest|AbstractResponse|KerberosLogin|WorkerSinkTaskTest|TransactionManagerTest|SenderTest|KafkaAdminClient|ConsumerCoordinatorTest|KafkaAdminClientTest|KafkaRaftClientTest).java"/>
 
     <suppress checks="NPathComplexity"
-              
files="(ConsumerCoordinator|BufferPool|MetricName|Node|ConfigDef|RecordBatch|SslFactory|SslTransportLayer|MetadataResponse|KerberosLogin|Selector|Sender|Serdes|TokenInformation|Agent|PluginUtils|MiniTrogdorCluster|TasksRequest|KafkaProducer|AbstractStickyAssignor|KafkaRaftClient|Authorizer|FetchSessionHandler|RecordAccumulator|Shell).java"/>
+              
files="(AbstractMembershipManager|ConsumerCoordinator|BufferPool|MetricName|Node|ConfigDef|RecordBatch|SslFactory|SslTransportLayer|MetadataResponse|KerberosLogin|Selector|Sender|Serdes|TokenInformation|Agent|PluginUtils|MiniTrogdorCluster|TasksRequest|KafkaProducer|AbstractStickyAssignor|KafkaRaftClient|Authorizer|FetchSessionHandler|RecordAccumulator|Shell|MockConsumer).java"/>
 
     <suppress checks="(JavaNCSS|CyclomaticComplexity|MethodLength)"
               files="CoordinatorClient.java"/>
diff --git 
a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java 
b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
index 600f8bbd07e..56e684e94c4 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
@@ -34,6 +34,7 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -77,6 +78,8 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
     private Uuid clientInstanceId;
     private int injectTimeoutExceptionCounter;
 
+    private long maxPollRecords = Long.MAX_VALUE;
+
     public MockConsumer(OffsetResetStrategy offsetResetStrategy) {
         this.subscriptions = new SubscriptionState(new LogContext(), 
offsetResetStrategy);
         this.partitions = new HashMap<>();
@@ -229,14 +232,22 @@ public class MockConsumer<K, V> implements Consumer<K, V> 
{
 
         // update the consumed offset
         final Map<TopicPartition, List<ConsumerRecord<K, V>>> results = new 
HashMap<>();
-        final List<TopicPartition> toClear = new ArrayList<>();
+        long numPollRecords = 0L;
+
+        final Iterator<Map.Entry<TopicPartition, List<ConsumerRecord<K, V>>>> 
partitionsIter = this.records.entrySet().iterator();
+        while (partitionsIter.hasNext() && numPollRecords < 
this.maxPollRecords) {
+            Map.Entry<TopicPartition, List<ConsumerRecord<K, V>>> entry = 
partitionsIter.next();
 
-        for (Map.Entry<TopicPartition, List<ConsumerRecord<K, V>>> entry : 
this.records.entrySet()) {
             if (!subscriptions.isPaused(entry.getKey())) {
-                final List<ConsumerRecord<K, V>> recs = entry.getValue();
-                for (final ConsumerRecord<K, V> rec : recs) {
+                final Iterator<ConsumerRecord<K, V>> recIterator = 
entry.getValue().iterator();
+                while (recIterator.hasNext()) {
+                    if (numPollRecords >= this.maxPollRecords) {
+                        break;
+                    }
                     long position = 
subscriptions.position(entry.getKey()).offset;
 
+                    final ConsumerRecord<K, V> rec = recIterator.next();
+
                     if (beginningOffsets.get(entry.getKey()) != null && 
beginningOffsets.get(entry.getKey()) > position) {
                         throw new 
OffsetOutOfRangeException(Collections.singletonMap(entry.getKey(), position));
                     }
@@ -247,13 +258,18 @@ public class MockConsumer<K, V> implements Consumer<K, V> 
{
                         SubscriptionState.FetchPosition newPosition = new 
SubscriptionState.FetchPosition(
                                 rec.offset() + 1, rec.leaderEpoch(), 
leaderAndEpoch);
                         subscriptions.position(entry.getKey(), newPosition);
+
+                        numPollRecords++;
+                        recIterator.remove();
                     }
                 }
-                toClear.add(entry.getKey());
+
+                if (entry.getValue().isEmpty()) {
+                    partitionsIter.remove();
+                }
             }
         }
 
-        toClear.forEach(records::remove);
         return new ConsumerRecords<>(results);
     }
 
@@ -275,6 +291,17 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
         setPollException(exception);
     }
 
+    /**
+     * Sets the maximum number of records returned in a single call to {@link #poll(Duration)}.
+     * Records beyond the cap stay buffered and are returned by subsequent polls.
+     *
+     * @param maxPollRecords the max.poll.records; must be at least 1
+     * @throws IllegalArgumentException if {@code maxPollRecords} is less than 1
+     */
+    public synchronized void setMaxPollRecords(long maxPollRecords) {
+        // Validate the argument, not the current field: checking the field would accept an
+        // invalid value (e.g. 0) and then reject every later, valid call.
+        if (maxPollRecords < 1) {
+            throw new IllegalArgumentException("MaxPollRecords must be strictly superior to 0");
+        }
+        this.maxPollRecords = maxPollRecords;
+    }
+
+
     public synchronized void setPollException(KafkaException exception) {
         this.pollException = exception;
     }
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java 
b/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java
index a9b0c2843d9..9e2ca5a2ae4 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java
@@ -31,6 +31,7 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Optional;
+import java.util.stream.IntStream;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -190,4 +191,33 @@ public class MockConsumerTest {
         assertTrue(revoked.contains(topicPartitionList.get(0)));
     }
 
+    // Verifies that setMaxPollRecords caps each poll: with 10 buffered records and a cap of 2,
+    // two polls return 2 records each; lifting the cap then drains the remaining 6, and a
+    // final poll finds nothing left.
+    @Test
+    public void shouldReturnMaxPollRecords() {
+        TopicPartition partition = new TopicPartition("test", 0);
+        consumer.assign(Collections.singleton(partition));
+        consumer.updateBeginningOffsets(Collections.singletonMap(partition, 
0L));
+
+        // Buffer ten records (offsets 0..9) on the single assigned partition.
+        IntStream.range(0, 10).forEach(offset -> {
+            consumer.addRecord(new ConsumerRecord<>("test", 0, offset, null, 
null));
+        });
+
+        consumer.setMaxPollRecords(2L);
+
+        ConsumerRecords<String, String> records;
+
+        // Each capped poll returns exactly 2 records; the rest stay buffered.
+        records = consumer.poll(Duration.ofMillis(1));
+        assertEquals(2, records.count());
+
+        records = consumer.poll(Duration.ofMillis(1));
+        assertEquals(2, records.count());
+
+        // Effectively remove the cap so the next poll returns everything still buffered.
+        consumer.setMaxPollRecords(Long.MAX_VALUE);
+
+        records = consumer.poll(Duration.ofMillis(1));
+        assertEquals(6, records.count());
+
+        // All 10 records have now been consumed.
+        records = consumer.poll(Duration.ofMillis(1));
+        assertTrue(records.isEmpty());
+    }
+
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
index 24cf51a67be..97ec387f0dc 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
@@ -300,6 +300,7 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
                     currentDeadline = NO_DEADLINE;
                 }
 
+                long batchRestoreCount = 0;
                 for (final ConsumerRecord<byte[], byte[]> record : 
records.records(topicPartition)) {
                     final ProcessorRecordContext recordContext =
                         new ProcessorRecordContext(
@@ -318,6 +319,7 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
                                 record.timestamp(),
                                 record.headers()));
                             restoreCount++;
+                            batchRestoreCount++;
                         }
                     } catch (final Exception deserializationException) {
                         // while Java distinguishes checked vs unchecked 
exceptions, other languages
@@ -341,7 +343,7 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
 
                 offset = getGlobalConsumerOffset(topicPartition);
 
-                stateRestoreListener.onBatchRestored(topicPartition, 
storeName, offset, restoreCount);
+                stateRestoreListener.onBatchRestored(topicPartition, 
storeName, offset, batchRestoreCount);
             }
             stateRestoreListener.onRestoreEnd(topicPartition, storeName, 
restoreCount);
             checkpointFileCache.put(topicPartition, offset);
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
index 88709ed9186..63889742d08 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
@@ -356,16 +356,35 @@ public class GlobalStateManagerImplTest {
         assertEquals(2, stateRestoreCallback.restored.size());
     }
 
+    // Regression test for KAFKA-18962: when restoring via reprocessing with
+    // max.poll.records = 2, the listener's batch count must be the per-batch count (2),
+    // not the running total, while the overall restored total is still 6.
+    @Test
+    public void shouldListenForRestoreEventsWhenReprocessing() {
+        // NOTE(review): presumably switches the state manager onto the reprocessing restore path.
+        setUpReprocessing();
+
+        // Presumably 6 records starting at offset 1 (matching start offset 1 / end offset 7 below).
+        initializeConsumer(6, 1, t1);
+        // Force restoration to proceed in several batches of 2 records.
+        consumer.setMaxPollRecords(2L);
+
+        stateManager.initialize();
+        stateManager.registerStore(store1, stateRestoreCallback, null);
+
+        assertThat(stateRestoreListener.numBatchRestored, equalTo(2L));
+        assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L));
+        assertThat(stateRestoreListener.restoreEndOffset, equalTo(7L));
+        assertThat(stateRestoreListener.totalNumRestored, equalTo(6L));
+    }
+
     @Test
     public void shouldListenForRestoreEvents() {
-        initializeConsumer(5, 1, t1);
+        initializeConsumer(6, 1, t1);
+        consumer.setMaxPollRecords(2L);
+
         stateManager.initialize();
 
         stateManager.registerStore(store1, stateRestoreCallback, null);
 
+        assertThat(stateRestoreListener.numBatchRestored, equalTo(2L));
         assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L));
-        assertThat(stateRestoreListener.restoreEndOffset, equalTo(6L));
-        assertThat(stateRestoreListener.totalNumRestored, equalTo(5L));
+        assertThat(stateRestoreListener.restoreEndOffset, equalTo(7L));
+        assertThat(stateRestoreListener.totalNumRestored, equalTo(6L));
 
 
         
assertThat(stateRestoreListener.storeNameCalledStates.get(RESTORE_START), 
equalTo(store1.name()));

Reply via email to