This is an automated email from the ASF dual-hosted git repository.

cadonna pushed a commit to branch 3.7
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.7 by this push:
     new 64845b9b071 KAFKA-15625: Do not flush global state store at each commit (#15361)
64845b9b071 is described below

commit 64845b9b071c12b2866de85b3fde096da4b056a0
Author: Ayoub Omari <ayouboma...@outlook.fr>
AuthorDate: Mon Mar 4 10:19:59 2024 +0100

    KAFKA-15625: Do not flush global state store at each commit (#15361)
    
    Global state stores are currently flushed at each commit, which can hurt performance, especially under EOS (which commits every 200ms).
    The goal of this improvement is to flush global state stores only when the delta between the current offset and the last checkpointed offset exceeds a threshold.
    This is the same logic we already apply to local state stores, which use a threshold of 10000 records.
    The implementation flushes only if both the flush interval has elapsed and the threshold of 10000 records is exceeded.
    
    Reviewers: Jeff Kim <jeff....@confluent.io>, Bruno Cadonna <cado...@apache.org>
---
 .../processor/internals/GlobalStateMaintainer.java |   2 +
 .../processor/internals/GlobalStateUpdateTask.java |  20 +++-
 .../processor/internals/GlobalStreamThread.java    |  25 ++---
 .../processor/internals/GlobalStateTaskTest.java   | 111 +++++++++++++++++++--
 .../processor/internals/StateConsumerTest.java     |  31 ++----
 .../apache/kafka/test/GlobalStateManagerStub.java  |  10 +-
 .../apache/kafka/streams/TopologyTestDriver.java   |   4 +-
 7 files changed, 147 insertions(+), 56 deletions(-)
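
In condensed form, the gate this change introduces looks as follows. This is a
minimal, hypothetical single-partition sketch, not the committed code: the real
implementation tracks offsets per partition, delegates the delta test to
StateManagerUtil.checkpointNeeded, and flushes all global stores through the
GlobalStateManager. Per the tests in this commit, a delta of exactly 10000
records does not trigger a flush; 10001 does.

    // Hypothetical sketch of the new checkpoint gate (single partition).
    final class GlobalCheckpointGate {
        // Same delta threshold as local state stores; strictly-greater, per the tests.
        private static final long OFFSET_DELTA_THRESHOLD = 10_000L;

        private final long flushIntervalMs; // wired from COMMIT_INTERVAL_MS_CONFIG
        private long lastFlushMs;           // stamped in initialize()
        private long checkpointedOffset;

        GlobalCheckpointGate(final long flushIntervalMs, final long nowMs, final long startOffset) {
            this.flushIntervalMs = flushIntervalMs;
            this.lastFlushMs = nowMs;
            this.checkpointedOffset = startOffset;
        }

        // Mirrors maybeCheckpoint(): flush only when the interval has elapsed
        // AND enough records have arrived since the last checkpoint.
        boolean maybeCheckpoint(final long nowMs, final long currentOffset) {
            if (nowMs - lastFlushMs >= flushIntervalMs
                    && currentOffset - checkpointedOffset > OFFSET_DELTA_THRESHOLD) {
                lastFlushMs = nowMs;
                checkpointedOffset = currentOffset;
                return true; // caller flushes stores and writes the checkpoint file
            }
            return false;
        }
    }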

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateMaintainer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateMaintainer.java
index 9a8aab6eb3c..06afb6fde4f 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateMaintainer.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateMaintainer.java
@@ -34,4 +34,6 @@ interface GlobalStateMaintainer {
     void close(final boolean wipeStateStore) throws IOException;
 
     void update(ConsumerRecord<byte[], byte[]> record);
+
+    void maybeCheckpoint();
 }
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java
index 523228542a8..da7ebba209a 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java
@@ -19,6 +19,7 @@ package org.apache.kafka.streams.processor.internals;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.errors.DeserializationExceptionHandler;
 import org.apache.kafka.streams.errors.StreamsException;
@@ -45,18 +46,26 @@ public class GlobalStateUpdateTask implements GlobalStateMaintainer {
     private final Map<String, RecordDeserializer> deserializers = new HashMap<>();
     private final GlobalStateManager stateMgr;
     private final DeserializationExceptionHandler deserializationExceptionHandler;
+    private final Time time;
+    private final long flushInterval;
+    private long lastFlush;
 
     public GlobalStateUpdateTask(final LogContext logContext,
                                  final ProcessorTopology topology,
                                 final InternalProcessorContext processorContext,
                                  final GlobalStateManager stateMgr,
-                                 final DeserializationExceptionHandler deserializationExceptionHandler) {
+                                 final DeserializationExceptionHandler deserializationExceptionHandler,
+                                 final Time time,
+                                 final long flushInterval
+                                 ) {
         this.logContext = logContext;
         this.log = logContext.logger(getClass());
         this.topology = topology;
         this.stateMgr = stateMgr;
         this.processorContext = processorContext;
         this.deserializationExceptionHandler = deserializationExceptionHandler;
+        this.time = time;
+        this.flushInterval = flushInterval;
     }
 
     /**
@@ -86,6 +95,7 @@ public class GlobalStateUpdateTask implements GlobalStateMaintainer {
         }
         initTopology();
         processorContext.initialize();
+        lastFlush = time.milliseconds();
         return stateMgr.changelogOffsets();
     }
 
@@ -150,5 +160,13 @@ public class GlobalStateUpdateTask implements GlobalStateMaintainer {
         }
     }
 
+    @Override
+    public void maybeCheckpoint() {
+        final long now = time.milliseconds();
+        final long now = time.milliseconds();
+        if (now - flushInterval >= lastFlush && StateManagerUtil.checkpointNeeded(false, stateMgr.changelogOffsets(), offsets)) {
+            flushState();
+            lastFlush = now;
+        }
+    }
 
 }
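
Two reading aids for the hunk above. First, the elapsed-time test is written as
"now - flushInterval >= lastFlush", which is a plain rearrangement of the more
familiar "now - lastFlush >= flushInterval" (the helper below is hypothetical,
not in the commit, and only spells out the equivalence):

    // Equivalent formulations of the same inequality; overflow is immaterial
    // for epoch-millisecond values.
    static boolean intervalElapsed(final long now, final long lastFlush, final long flushInterval) {
        final boolean formInCommit = now - flushInterval >= lastFlush;
        final boolean conventional = now - lastFlush >= flushInterval;
        assert formInCommit == conventional;
        return conventional;
    }

Second, moving this logic into GlobalStateUpdateTask (rather than leaving it in
StateConsumer; see the next diff) puts the time check next to the offsets the
task already tracks, so the delta test can compare stateMgr.changelogOffsets()
against them directly.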
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
index 82a0cc51131..1ed517b15d4 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java
@@ -228,25 +228,17 @@ public class GlobalStreamThread extends Thread {
     static class StateConsumer {
         private final Consumer<byte[], byte[]> globalConsumer;
         private final GlobalStateMaintainer stateMaintainer;
-        private final Time time;
         private final Duration pollTime;
-        private final long flushInterval;
         private final Logger log;
 
-        private long lastFlush;
-
         StateConsumer(final LogContext logContext,
                       final Consumer<byte[], byte[]> globalConsumer,
                       final GlobalStateMaintainer stateMaintainer,
-                      final Time time,
-                      final Duration pollTime,
-                      final long flushInterval) {
+                      final Duration pollTime) {
             this.log = logContext.logger(getClass());
             this.globalConsumer = globalConsumer;
             this.stateMaintainer = stateMaintainer;
-            this.time = time;
             this.pollTime = pollTime;
-            this.flushInterval = flushInterval;
         }
 
         /**
@@ -259,7 +251,6 @@ public class GlobalStreamThread extends Thread {
             for (final Map.Entry<TopicPartition, Long> entry : partitionOffsets.entrySet()) {
                 globalConsumer.seek(entry.getKey(), entry.getValue());
             }
-            lastFlush = time.milliseconds();
         }
 
         void pollAndUpdate() {
@@ -267,11 +258,7 @@ public class GlobalStreamThread extends Thread {
             for (final ConsumerRecord<byte[], byte[]> record : received) {
                 stateMaintainer.update(record);
             }
-            final long now = time.milliseconds();
-            if (now - flushInterval >= lastFlush) {
-                stateMaintainer.flushState();
-                lastFlush = now;
-            }
+            stateMaintainer.maybeCheckpoint();
         }
 
         public void close(final boolean wipeStateStore) throws IOException {
@@ -418,11 +405,11 @@ public class GlobalStreamThread extends Thread {
                     topology,
                     globalProcessorContext,
                     stateMgr,
-                    config.defaultDeserializationExceptionHandler()
+                    config.defaultDeserializationExceptionHandler(),
+                    time,
+                    config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG)
                 ),
-                time,
-                Duration.ofMillis(config.getLong(StreamsConfig.POLL_MS_CONFIG)),
-                config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG)
+                Duration.ofMillis(config.getLong(StreamsConfig.POLL_MS_CONFIG))
             );
 
             try {
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java
index 31be9dc2a4d..af5dc68103c 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java
@@ -25,6 +25,7 @@ import org.apache.kafka.common.serialization.IntegerSerializer;
 import org.apache.kafka.common.serialization.LongSerializer;
 import org.apache.kafka.common.serialization.StringDeserializer;
 import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
 import org.apache.kafka.streams.errors.LogAndFailExceptionHandler;
@@ -46,8 +47,6 @@ import java.util.Set;
 
 import static java.util.Arrays.asList;
 import static org.apache.kafka.streams.processor.internals.testutil.ConsumerRecordUtil.record;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
@@ -71,8 +70,12 @@ public class GlobalStateTaskTest {
     private final MockProcessorNode<?, ?, ?, ?> processorTwo = new MockProcessorNode<>();
 
     private final Map<TopicPartition, Long> offsets = new HashMap<>();
-    private File testDirectory = TestUtils.tempDirectory("global-store");
+    private final File testDirectory = TestUtils.tempDirectory("global-store");
     private final NoOpProcessorContext context = new NoOpProcessorContext();
+    private final MockTime time = new MockTime();
+    private final long flushInterval = 1000L;
+    private final long currentOffsetT1 = 50;
+    private final long currentOffsetT2 = 100;
 
     private ProcessorTopology topology;
     private GlobalStateManagerStub stateMgr;
@@ -101,7 +104,9 @@ public class GlobalStateTaskTest {
             topology,
             context,
             stateMgr,
-            new LogAndFailExceptionHandler()
+            new LogAndFailExceptionHandler(),
+            time,
+            flushInterval
         );
     }
 
@@ -188,7 +193,9 @@ public class GlobalStateTaskTest {
             topology,
             context,
             stateMgr,
-            new LogAndContinueExceptionHandler()
+            new LogAndContinueExceptionHandler(),
+            time,
+            flushInterval
         );
         final byte[] key = new LongSerializer().serialize(topic2, 1L);
         final byte[] recordValue = new IntegerSerializer().serialize(topic2, 10);
@@ -203,7 +210,9 @@ public class GlobalStateTaskTest {
             topology,
             context,
             stateMgr,
-            new LogAndContinueExceptionHandler()
+            new LogAndContinueExceptionHandler(),
+            time,
+            flushInterval
         );
         final byte[] key = new IntegerSerializer().serialize(topic2, 1);
         final byte[] recordValue = new LongSerializer().serialize(topic2, 10L);
@@ -217,10 +226,13 @@ public class GlobalStateTaskTest {
         final Map<TopicPartition, Long> expectedOffsets = new HashMap<>();
         expectedOffsets.put(t1, 52L);
         expectedOffsets.put(t2, 100L);
+
         globalStateTask.initialize();
-        globalStateTask.update(record(topic1, 1, 51, "foo".getBytes(), "foo".getBytes()));
+        globalStateTask.update(record(topic1, 1, currentOffsetT1 + 1, "foo".getBytes(), "foo".getBytes()));
         globalStateTask.flushState();
+
         assertEquals(expectedOffsets, stateMgr.changelogOffsets());
+        assertTrue(stateMgr.flushed);
     }
 
     @Test
@@ -228,12 +240,93 @@ public class GlobalStateTaskTest {
         final Map<TopicPartition, Long> expectedOffsets = new HashMap<>();
         expectedOffsets.put(t1, 102L);
         expectedOffsets.put(t2, 100L);
+
         globalStateTask.initialize();
-        globalStateTask.update(record(topic1, 1, 101, "foo".getBytes(), "foo".getBytes()));
+        globalStateTask.update(record(topic1, 1, currentOffsetT1 + 51L, "foo".getBytes(), "foo".getBytes()));
         globalStateTask.flushState();
-        assertThat(stateMgr.changelogOffsets(), equalTo(expectedOffsets));
+
+        assertEquals(expectedOffsets, stateMgr.changelogOffsets());
+        assertTrue(stateMgr.checkpointWritten);
+    }
+
+    @Test
+    public void shouldNotCheckpointIfNotReceivedEnoughRecords() {
+        globalStateTask.initialize();
+        globalStateTask.update(record(topic1, 1, currentOffsetT1 + 9000L, "foo".getBytes(), "foo".getBytes()));
+        time.sleep(flushInterval); // flush interval elapsed
+        globalStateTask.maybeCheckpoint();
+
+        assertEquals(offsets, stateMgr.changelogOffsets());
+        assertFalse(stateMgr.flushed);
+        assertFalse(stateMgr.checkpointWritten);
+    }
+
+    @Test
+    public void shouldNotCheckpointWhenFlushIntervalHasNotLapsed() {
+        globalStateTask.initialize();
+
+        // offset delta exceeded
+        globalStateTask.update(record(topic1, 1, currentOffsetT1 + 10000L, "foo".getBytes(), "foo".getBytes()));
+
+        time.sleep(flushInterval / 2);
+        globalStateTask.maybeCheckpoint();
+
+        assertEquals(offsets, stateMgr.changelogOffsets());
+        assertFalse(stateMgr.flushed);
+        assertFalse(stateMgr.checkpointWritten);
+    }
+
+    @Test
+    public void shouldCheckpointIfReceivedEnoughRecordsAndFlushIntervalHasElapsed() {
+        final Map<TopicPartition, Long> expectedOffsets = new HashMap<>();
+        expectedOffsets.put(t1, 10051L); // topic1 advanced with 10001 records
+        expectedOffsets.put(t2, 100L);
+
+        globalStateTask.initialize();
+
+        time.sleep(flushInterval); // flush interval elapsed
+
+        // 10000 records received since last flush => do not flush
+        globalStateTask.update(record(topic1, 1, currentOffsetT1 + 9999L, "foo".getBytes(), "foo".getBytes()));
+        globalStateTask.maybeCheckpoint();
+
+        assertEquals(offsets, stateMgr.changelogOffsets());
+        assertFalse(stateMgr.flushed);
+        assertFalse(stateMgr.checkpointWritten);
+
+        // 1 more record received => triggers the flush
+        globalStateTask.update(record(topic1, 1, currentOffsetT1 + 10000L, "foo".getBytes(), "foo".getBytes()));
+        globalStateTask.maybeCheckpoint();
+
+        assertEquals(expectedOffsets, stateMgr.changelogOffsets());
+        assertTrue(stateMgr.flushed);
+        assertTrue(stateMgr.checkpointWritten);
     }
 
+    @Test
+    public void shouldCheckpointIfReceivedEnoughRecordsFromMultipleTopicsAndFlushIntervalElapsed() {
+        final byte[] integerBytes = new IntegerSerializer().serialize(topic2, 1);
+
+        final Map<TopicPartition, Long> expectedOffsets = new HashMap<>();
+        expectedOffsets.put(t1, 9050L); // topic1 advanced with 9000 records
+        expectedOffsets.put(t2, 1101L); // topic2 advanced with 1001 records
+
+        globalStateTask.initialize();
+
+        time.sleep(flushInterval);
+
+        // received 9000 records in topic1
+        globalStateTask.update(record(topic1, 1, currentOffsetT1 + 8999L, "foo".getBytes(), "foo".getBytes()));
+        // received 1001 records in topic2
+        globalStateTask.update(record(topic2, 1, currentOffsetT2 + 1000L, integerBytes, integerBytes));
+        globalStateTask.maybeCheckpoint();
+
+        assertEquals(expectedOffsets, stateMgr.changelogOffsets());
+        assertTrue(stateMgr.flushed);
+        assertTrue(stateMgr.checkpointWritten);
+    }
+
+
     @Test
     public void shouldWipeGlobalStateDirectory() throws Exception {
         assertTrue(stateMgr.baseDir().exists());
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateConsumerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateConsumerTest.java
index 1f98eb456d9..5e579394833 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateConsumerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateConsumerTest.java
@@ -21,7 +21,6 @@ import org.apache.kafka.clients.consumer.MockConsumer;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.junit.Before;
 import org.junit.Test;
@@ -32,16 +31,13 @@ import java.util.HashMap;
 import java.util.Map;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
 
 public class StateConsumerTest {
 
-    private static final long FLUSH_INTERVAL = 1000L;
     private final TopicPartition topicOne = new TopicPartition("topic-one", 1);
     private final TopicPartition topicTwo = new TopicPartition("topic-two", 1);
-    private final MockTime time = new MockTime();
     private final MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
     private final Map<TopicPartition, Long> partitionOffsets = new HashMap<>();
     private final LogContext logContext = new LogContext("test ");
@@ -53,7 +49,7 @@ public class StateConsumerTest {
         partitionOffsets.put(topicOne, 20L);
         partitionOffsets.put(topicTwo, 30L);
         stateMaintainer = new TaskStub(partitionOffsets);
-        stateConsumer = new GlobalStreamThread.StateConsumer(logContext, consumer, stateMaintainer, time, Duration.ofMillis(10L), FLUSH_INTERVAL);
+        stateConsumer = new GlobalStreamThread.StateConsumer(logContext, consumer, stateMaintainer, Duration.ofMillis(10L));
     }
 
     @Test
@@ -76,6 +72,7 @@ public class StateConsumerTest {
         consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 21L, new 
byte[0], new byte[0]));
         stateConsumer.pollAndUpdate();
         assertEquals(2, stateMaintainer.updatedPartitions.get(topicOne).intValue());
+        assertTrue(stateMaintainer.flushed);
     }
 
     @Test
@@ -87,27 +84,9 @@ public class StateConsumerTest {
         stateConsumer.pollAndUpdate();
         assertEquals(1, stateMaintainer.updatedPartitions.get(topicOne).intValue());
         assertEquals(2, stateMaintainer.updatedPartitions.get(topicTwo).intValue());
-    }
-
-    @Test
-    public void shouldFlushStoreWhenFlushIntervalHasLapsed() {
-        stateConsumer.initialize();
-        consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 20L, new 
byte[0], new byte[0]));
-        time.sleep(FLUSH_INTERVAL);
-
-        stateConsumer.pollAndUpdate();
         assertTrue(stateMaintainer.flushed);
     }
 
-    @Test
-    public void shouldNotFlushOffsetsWhenFlushIntervalHasNotLapsed() {
-        stateConsumer.initialize();
-        consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 20L, new 
byte[0], new byte[0]));
-        time.sleep(FLUSH_INTERVAL / 2);
-        stateConsumer.pollAndUpdate();
-        assertFalse(stateMaintainer.flushed);
-    }
-
     @Test
     public void shouldCloseConsumer() throws IOException {
         stateConsumer.close(false);
@@ -161,6 +140,10 @@ public class StateConsumerTest {
             updatedPartitions.put(tp, updatedPartitions.get(tp) + 1);
         }
 
+        @Override
+        public void maybeCheckpoint() {
+            flushState();
+        }
     }
 
-}
\ No newline at end of file
+}
diff --git a/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java b/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java
index d34b3c8029f..30316499447 100644
--- a/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java
+++ b/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java
@@ -35,6 +35,8 @@ public class GlobalStateManagerStub implements GlobalStateManager {
     private final File baseDirectory;
     public boolean initialized;
     public boolean closed;
+    public boolean flushed;
+    public boolean checkpointWritten;
 
     public GlobalStateManagerStub(final Set<String> storeNames,
                                   final Map<TopicPartition, Long> offsets,
@@ -64,7 +66,9 @@ public class GlobalStateManagerStub implements GlobalStateManager {
                               final CommitCallback checkpoint) {}
 
     @Override
-    public void flush() {}
+    public void flush() {
+        flushed = true;
+    }
 
     @Override
     public void close() {
@@ -77,7 +81,9 @@ public class GlobalStateManagerStub implements GlobalStateManager {
     }
 
     @Override
-    public void checkpoint() {}
+    public void checkpoint() {
+        checkpointWritten = true;
+    }
 
     @Override
     public StateStore getStore(final String name) {
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index 20abaa54072..5767ed9d20e 100644
--- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -459,7 +459,9 @@ public class TopologyTestDriver implements Closeable {
                 globalTopology,
                 globalProcessorContext,
                 globalStateManager,
-                new LogAndContinueExceptionHandler()
+                new LogAndContinueExceptionHandler(),
+                mockWallClockTime,
+                streamsConfig.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG)
             );
             globalStateTask.initialize();
             globalProcessorContext.setRecordContext(null);
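
A downstream note, not part of the commit: the global flush cadence is now tied
to commit.interval.ms and additionally gated by the hard-coded 10000-record
delta, so lowering the interval only tightens the time half of the gate; it
does not restore per-commit flushing. A hypothetical configuration sketch
(application id and bootstrap server are placeholders):

    import java.util.Properties;
    import org.apache.kafka.streams.StreamsConfig;

    // commit.interval.ms now also bounds how often the global thread even
    // considers flushing and checkpointing its state stores.
    final Properties props = new Properties();
    props.put(StreamsConfig.APPLICATION_ID_CONFIG, "my-app");
    props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
    props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 30_000L);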
