This is an automated email from the ASF dual-hosted git repository.

leiyanfei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new d8b3243eaa5 [FLINK-35667][state/forst] Implement Reducing Async State 
API for ForStStateBackend (#25308)
d8b3243eaa5 is described below

commit d8b3243eaa5475eafb813c23419bbb55a4bebd23
Author: Yanfei Lei <fredia...@gmail.com>
AuthorDate: Sat Sep 14 10:10:20 2024 +0800

    [FLINK-35667][state/forst] Implement Reducing Async State API for 
ForStStateBackend (#25308)
---
 .../runtime/asyncprocessing/StateRequestType.java  |   3 +
 .../runtime/state/v2/InternalMergingState.java     |  41 ++++++
 .../runtime/state/v2/InternalReducingState.java    | 117 +++++++++++++++-
 .../state/v2/InternalAggregatingStateTest.java     |   6 +-
 .../state/v2/InternalKeyedStateTestBase.java       |   7 +-
 .../runtime/state/v2/InternalListStateTest.java    |  20 +--
 .../runtime/state/v2/InternalMapStateTest.java     |  42 +++---
 .../state/v2/InternalReducingStateTest.java        | 142 +++++++++++++++++++-
 .../runtime/state/v2/InternalValueStateTest.java   |  12 +-
 .../flink/state/forst/ForStKeyedStateBackend.java  |  12 ++
 .../flink/state/forst/ForStReducingState.java      | 149 +++++++++++++++++++++
 .../state/forst/ForStStateRequestClassifier.java   |   2 +
 12 files changed, 502 insertions(+), 51 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java
index 504115a48fa..a3a95cf9f69 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java
@@ -108,6 +108,9 @@ public enum StateRequestType {
     /** Add element into reducing state, {@link 
ReducingState#asyncAdd(Object)}. */
     REDUCING_ADD,
 
+    /** Remove element from reducing state. */
+    REDUCING_REMOVE,
+
     /** Get value from aggregating state by {@link 
AggregatingState#asyncGet()}. */
     AGGREGATING_GET,
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalMergingState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalMergingState.java
new file mode 100644
index 00000000000..3f5713bdb58
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalMergingState.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.v2;
+
+import org.apache.flink.api.common.state.v2.StateFuture;
+
+import java.util.Collection;
+
+/**
+ * This class defines the internal interface for merging state.
+ *
+ * @param <N> The type of the namespace
+ */
+public interface InternalMergingState<N> extends InternalPartitionedState<N> {
+
+    /**
+     * Merges the state of the current key for the given source namespaces 
into the state of the
+     * target namespace.
+     *
+     * @param target The target namespace where the merged state should be 
stored.
+     * @param sources The source namespaces whose state should be merged.
+     */
+    StateFuture<Void> asyncMergeNamespaces(N target, Collection<N> sources);
+
+    void mergeNamespaces(N target, Collection<N> sources);
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalReducingState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalReducingState.java
index b49962f9f5d..0c3b124a778 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalReducingState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalReducingState.java
@@ -20,9 +20,15 @@ package org.apache.flink.runtime.state.v2;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.state.v2.ReducingState;
 import org.apache.flink.api.common.state.v2.StateFuture;
+import org.apache.flink.core.state.StateFutureUtils;
 import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
 import org.apache.flink.runtime.asyncprocessing.StateRequestType;
 
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
 /**
  * A default implementation of {@link ReducingState} which delegates all async 
requests to {@link
  * StateRequestHandler}.
@@ -31,7 +37,7 @@ import 
org.apache.flink.runtime.asyncprocessing.StateRequestType;
  * @param <V> The type of values kept internally in state.
  */
 public class InternalReducingState<K, N, V> extends InternalKeyedState<K, N, V>
-        implements ReducingState<V> {
+        implements ReducingState<V>, InternalMergingState<N> {
 
     protected final ReduceFunction<V> reduceFunction;
 
@@ -48,7 +54,15 @@ public class InternalReducingState<K, N, V> extends 
InternalKeyedState<K, N, V>
 
     @Override
     public StateFuture<Void> asyncAdd(V value) {
-        return handleRequest(StateRequestType.REDUCING_ADD, value);
+        return handleRequest(StateRequestType.REDUCING_GET, null)
+                .thenAccept(
+                        oldValue -> {
+                            V newValue =
+                                    oldValue == null
+                                            ? value
+                                            : reduceFunction.reduce((V) 
oldValue, value);
+                            handleRequest(StateRequestType.REDUCING_ADD, 
newValue);
+                        });
     }
 
     @Override
@@ -58,6 +72,103 @@ public class InternalReducingState<K, N, V> extends 
InternalKeyedState<K, N, V>
 
     @Override
     public void add(V value) {
-        handleRequestSync(StateRequestType.REDUCING_ADD, value);
+        V oldValue = handleRequestSync(StateRequestType.REDUCING_GET, null);
+        try {
+            V newValue = oldValue == null ? value : 
reduceFunction.reduce(oldValue, value);
+            handleRequestSync(StateRequestType.REDUCING_ADD, newValue);
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public StateFuture<Void> asyncMergeNamespaces(N target, Collection<N> 
sources) {
+        if (sources == null || sources.isEmpty()) {
+            return StateFutureUtils.completedVoidFuture();
+        }
+        // phase 1: read from the sources and target
+        List<StateFuture<V>> futures = new ArrayList<>(sources.size() + 1);
+        for (N source : sources) {
+            if (source != null) {
+                setCurrentNamespace(source);
+                futures.add(handleRequest(StateRequestType.REDUCING_GET, 
null));
+            }
+        }
+        setCurrentNamespace(target);
+        futures.add(handleRequest(StateRequestType.REDUCING_GET, null));
+        // phase 2: merge the sources to the target
+        return StateFutureUtils.combineAll(futures)
+                .thenCompose(
+                        values -> {
+                            List<StateFuture<V>> updateFutures =
+                                    new ArrayList<>(sources.size() + 1);
+                            V current = null;
+                            Iterator<V> valueIterator = values.iterator();
+                            for (N source : sources) {
+                                V value = valueIterator.next();
+                                if (value != null) {
+                                    setCurrentNamespace(source);
+                                    updateFutures.add(
+                                            
handleRequest(StateRequestType.REDUCING_REMOVE, null));
+                                    if (current != null) {
+                                        current = 
reduceFunction.reduce(current, value);
+                                    } else {
+                                        current = value;
+                                    }
+                                }
+                            }
+                            V targetValue = valueIterator.next();
+                            if (current != null) {
+                                if (targetValue != null) {
+                                    current = reduceFunction.reduce(current, 
targetValue);
+                                }
+                                setCurrentNamespace(target);
+                                updateFutures.add(
+                                        
handleRequest(StateRequestType.REDUCING_ADD, current));
+                            }
+                            return StateFutureUtils.combineAll(updateFutures)
+                                    .thenAccept(ignores -> {});
+                        });
+    }
+
+    @Override
+    public void mergeNamespaces(N target, Collection<N> sources) {
+        if (sources == null || sources.isEmpty()) {
+            return;
+        }
+        try {
+            V current = null;
+            // merge the sources to the target
+            for (N source : sources) {
+                if (source != null) {
+                    setCurrentNamespace(source);
+                    V oldValue = 
handleRequestSync(StateRequestType.REDUCING_GET, null);
+
+                    if (oldValue != null) {
+                        handleRequestSync(StateRequestType.REDUCING_REMOVE, 
null);
+
+                        if (current != null) {
+                            current = reduceFunction.reduce(current, oldValue);
+                        } else {
+                            current = oldValue;
+                        }
+                    }
+                }
+            }
+
+            // if something came out of merging the sources, merge it or write 
it to the target
+            if (current != null) {
+                // create the target full-binary-key
+                setCurrentNamespace(target);
+                V targetValue = 
handleRequestSync(StateRequestType.REDUCING_GET, null);
+
+                if (targetValue != null) {
+                    current = reduceFunction.reduce(current, targetValue);
+                }
+                handleRequestSync(StateRequestType.REDUCING_ADD, current);
+            }
+        } catch (Exception e) {
+            throw new RuntimeException("merge namespace fail.", e);
+        }
     }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java
index 97410748189..da671c78a83 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java
@@ -61,12 +61,12 @@ class InternalAggregatingStateTest extends 
InternalKeyedStateTestBase {
         aec.setCurrentContext(aec.buildContext("test", "test"));
 
         state.asyncClear();
-        validateRequestRun(state, StateRequestType.CLEAR, null);
+        validateRequestRun(state, StateRequestType.CLEAR, null, 0);
 
         state.asyncGet();
-        validateRequestRun(state, StateRequestType.AGGREGATING_GET, null);
+        validateRequestRun(state, StateRequestType.AGGREGATING_GET, null, 0);
 
         state.asyncAdd(1);
-        validateRequestRun(state, StateRequestType.AGGREGATING_ADD, 1);
+        validateRequestRun(state, StateRequestType.AGGREGATING_ADD, 1, 0);
     }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalKeyedStateTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalKeyedStateTestBase.java
index 91923ff7705..9e645cf5a6b 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalKeyedStateTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalKeyedStateTestBase.java
@@ -86,10 +86,13 @@ public class InternalKeyedStateTestBase {
     }
 
     <IN> void validateRequestRun(
-            @Nullable State state, StateRequestType type, @Nullable IN 
payload) {
+            @Nullable State state,
+            StateRequestType type,
+            @Nullable IN payload,
+            int remainingRequests) {
         aec.triggerIfNeeded(true);
         testStateExecutor.validate(state, type, payload);
-        assertThat(testStateExecutor.receivedRequest.isEmpty()).isTrue();
+        
assertThat(testStateExecutor.receivedRequest.size()).isEqualTo(remainingRequests);
     }
 
     /**
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalListStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalListStateTest.java
index ef0de969fcf..7f6679d2399 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalListStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalListStateTest.java
@@ -39,37 +39,37 @@ public class InternalListStateTest extends 
InternalKeyedStateTestBase {
         aec.setCurrentContext(aec.buildContext("test", "test"));
 
         listState.asyncClear();
-        validateRequestRun(listState, StateRequestType.CLEAR, null);
+        validateRequestRun(listState, StateRequestType.CLEAR, null, 0);
 
         listState.asyncGet();
-        validateRequestRun(listState, StateRequestType.LIST_GET, null);
+        validateRequestRun(listState, StateRequestType.LIST_GET, null, 0);
 
         listState.asyncAdd(1);
-        validateRequestRun(listState, StateRequestType.LIST_ADD, 1);
+        validateRequestRun(listState, StateRequestType.LIST_ADD, 1, 0);
 
         List<Integer> list = new ArrayList<>();
         listState.asyncUpdate(list);
-        validateRequestRun(listState, StateRequestType.LIST_UPDATE, list);
+        validateRequestRun(listState, StateRequestType.LIST_UPDATE, list, 0);
 
         list = new ArrayList<>();
         listState.asyncAddAll(list);
-        validateRequestRun(listState, StateRequestType.LIST_ADD_ALL, list);
+        validateRequestRun(listState, StateRequestType.LIST_ADD_ALL, list, 0);
 
         listState.clear();
-        validateRequestRun(listState, StateRequestType.CLEAR, null);
+        validateRequestRun(listState, StateRequestType.CLEAR, null, 0);
 
         listState.get().iterator();
-        validateRequestRun(listState, StateRequestType.LIST_GET, null);
+        validateRequestRun(listState, StateRequestType.LIST_GET, null, 0);
 
         listState.add(1);
-        validateRequestRun(listState, StateRequestType.LIST_ADD, 1);
+        validateRequestRun(listState, StateRequestType.LIST_ADD, 1, 0);
 
         list = new ArrayList<>();
         listState.update(list);
-        validateRequestRun(listState, StateRequestType.LIST_UPDATE, list);
+        validateRequestRun(listState, StateRequestType.LIST_UPDATE, list, 0);
 
         list = new ArrayList<>();
         listState.addAll(list);
-        validateRequestRun(listState, StateRequestType.LIST_ADD_ALL, list);
+        validateRequestRun(listState, StateRequestType.LIST_ADD_ALL, list, 0);
     }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalMapStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalMapStateTest.java
index ae988326caa..52dd9b1c3ef 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalMapStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalMapStateTest.java
@@ -41,67 +41,67 @@ public class InternalMapStateTest extends 
InternalKeyedStateTestBase {
         aec.setCurrentContext(aec.buildContext("test", "test"));
 
         mapState.asyncClear();
-        validateRequestRun(mapState, StateRequestType.CLEAR, null);
+        validateRequestRun(mapState, StateRequestType.CLEAR, null, 0);
 
         mapState.asyncGet("key1");
-        validateRequestRun(mapState, StateRequestType.MAP_GET, "key1");
+        validateRequestRun(mapState, StateRequestType.MAP_GET, "key1", 0);
 
         mapState.asyncPut("key2", 2);
-        validateRequestRun(mapState, StateRequestType.MAP_PUT, 
Tuple2.of("key2", 2));
+        validateRequestRun(mapState, StateRequestType.MAP_PUT, 
Tuple2.of("key2", 2), 0);
 
         Map<String, Integer> map = new HashMap<>();
         mapState.asyncPutAll(map);
-        validateRequestRun(mapState, StateRequestType.MAP_PUT_ALL, map);
+        validateRequestRun(mapState, StateRequestType.MAP_PUT_ALL, map, 0);
 
         mapState.asyncRemove("key3");
-        validateRequestRun(mapState, StateRequestType.MAP_REMOVE, "key3");
+        validateRequestRun(mapState, StateRequestType.MAP_REMOVE, "key3", 0);
 
         mapState.asyncContains("key4");
-        validateRequestRun(mapState, StateRequestType.MAP_CONTAINS, "key4");
+        validateRequestRun(mapState, StateRequestType.MAP_CONTAINS, "key4", 0);
 
         mapState.asyncEntries();
-        validateRequestRun(mapState, StateRequestType.MAP_ITER, null);
+        validateRequestRun(mapState, StateRequestType.MAP_ITER, null, 0);
 
         mapState.asyncKeys();
-        validateRequestRun(mapState, StateRequestType.MAP_ITER_KEY, null);
+        validateRequestRun(mapState, StateRequestType.MAP_ITER_KEY, null, 0);
 
         mapState.asyncValues();
-        validateRequestRun(mapState, StateRequestType.MAP_ITER_VALUE, null);
+        validateRequestRun(mapState, StateRequestType.MAP_ITER_VALUE, null, 0);
 
         mapState.asyncIsEmpty();
-        validateRequestRun(mapState, StateRequestType.MAP_IS_EMPTY, null);
+        validateRequestRun(mapState, StateRequestType.MAP_IS_EMPTY, null, 0);
 
         mapState.clear();
-        validateRequestRun(mapState, StateRequestType.CLEAR, null);
+        validateRequestRun(mapState, StateRequestType.CLEAR, null, 0);
 
         mapState.get("key1");
-        validateRequestRun(mapState, StateRequestType.MAP_GET, "key1");
+        validateRequestRun(mapState, StateRequestType.MAP_GET, "key1", 0);
 
         mapState.put("key2", 2);
-        validateRequestRun(mapState, StateRequestType.MAP_PUT, 
Tuple2.of("key2", 2));
+        validateRequestRun(mapState, StateRequestType.MAP_PUT, 
Tuple2.of("key2", 2), 0);
 
         mapState.putAll(map);
-        validateRequestRun(mapState, StateRequestType.MAP_PUT_ALL, map);
+        validateRequestRun(mapState, StateRequestType.MAP_PUT_ALL, map, 0);
 
         mapState.remove("key3");
-        validateRequestRun(mapState, StateRequestType.MAP_REMOVE, "key3");
+        validateRequestRun(mapState, StateRequestType.MAP_REMOVE, "key3", 0);
 
         mapState.contains("key4");
-        validateRequestRun(mapState, StateRequestType.MAP_CONTAINS, "key4");
+        validateRequestRun(mapState, StateRequestType.MAP_CONTAINS, "key4", 0);
 
         mapState.iterator();
-        validateRequestRun(mapState, StateRequestType.MAP_ITER, null);
+        validateRequestRun(mapState, StateRequestType.MAP_ITER, null, 0);
 
         mapState.entries().iterator();
-        validateRequestRun(mapState, StateRequestType.MAP_ITER, null);
+        validateRequestRun(mapState, StateRequestType.MAP_ITER, null, 0);
 
         mapState.keys().iterator();
-        validateRequestRun(mapState, StateRequestType.MAP_ITER_KEY, null);
+        validateRequestRun(mapState, StateRequestType.MAP_ITER_KEY, null, 0);
 
         mapState.values().iterator();
-        validateRequestRun(mapState, StateRequestType.MAP_ITER_VALUE, null);
+        validateRequestRun(mapState, StateRequestType.MAP_ITER_VALUE, null, 0);
 
         mapState.isEmpty();
-        validateRequestRun(mapState, StateRequestType.MAP_IS_EMPTY, null);
+        validateRequestRun(mapState, StateRequestType.MAP_IS_EMPTY, null, 0);
     }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalReducingStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalReducingStateTest.java
index cf4657ef626..f923dd52128 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalReducingStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalReducingStateTest.java
@@ -20,10 +20,26 @@ package org.apache.flink.runtime.state.v2;
 
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController;
+import org.apache.flink.runtime.asyncprocessing.MockStateRequestContainer;
+import org.apache.flink.runtime.asyncprocessing.StateExecutor;
+import org.apache.flink.runtime.asyncprocessing.StateRequest;
+import org.apache.flink.runtime.asyncprocessing.StateRequestContainer;
 import org.apache.flink.runtime.asyncprocessing.StateRequestType;
+import org.apache.flink.runtime.mailbox.SyncMailboxExecutor;
+import org.apache.flink.util.Preconditions;
 
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+
+import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
+
 /** Tests for {@link InternalReducingState}. */
 public class InternalReducingStateTest extends InternalKeyedStateTestBase {
 
@@ -38,21 +54,135 @@ public class InternalReducingStateTest extends 
InternalKeyedStateTestBase {
         aec.setCurrentContext(aec.buildContext("test", "test"));
 
         reducingState.asyncClear();
-        validateRequestRun(reducingState, StateRequestType.CLEAR, null);
+        validateRequestRun(reducingState, StateRequestType.CLEAR, null, 0);
 
         reducingState.asyncGet();
-        validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null);
+        validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null, 
0);
 
         reducingState.asyncAdd(1);
-        validateRequestRun(reducingState, StateRequestType.REDUCING_ADD, 1);
+        validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null, 
1);
+        validateRequestRun(reducingState, StateRequestType.REDUCING_ADD, 1, 0);
 
         reducingState.clear();
-        validateRequestRun(reducingState, StateRequestType.CLEAR, null);
+        validateRequestRun(reducingState, StateRequestType.CLEAR, null, 0);
 
         reducingState.get();
-        validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null);
+        validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null, 
0);
 
         reducingState.add(1);
-        validateRequestRun(reducingState, StateRequestType.REDUCING_ADD, 1);
+        validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null, 
1);
+        validateRequestRun(reducingState, StateRequestType.REDUCING_ADD, 1, 0);
+    }
+
+    @Test
+    public void testMergeNamespace() throws Exception {
+        ReduceFunction<Integer> reducer = Integer::sum;
+        ReducingStateDescriptor<Integer> descriptor =
+                new ReducingStateDescriptor<>("testState", reducer, 
BasicTypeInfo.INT_TYPE_INFO);
+        AsyncExecutionController aec =
+                new AsyncExecutionController(
+                        new SyncMailboxExecutor(),
+                        (a, b) -> {},
+                        new ReducingStateExecutor(),
+                        1,
+                        100,
+                        10000,
+                        1);
+        InternalReducingState<String, String, Integer> reducingState =
+                new InternalReducingState<>(aec, descriptor);
+        aec.setCurrentContext(aec.buildContext("test", "test"));
+        aec.setCurrentNamespaceForState(reducingState, "1");
+        reducingState.asyncAdd(1);
+        aec.drainInflightRecords(0);
+        assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(1);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"1"))).isEqualTo(1);
+        aec.setCurrentNamespaceForState(reducingState, "2");
+        reducingState.asyncAdd(2);
+        aec.drainInflightRecords(0);
+        assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(2);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"1"))).isEqualTo(1);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"2"))).isEqualTo(2);
+        aec.setCurrentNamespaceForState(reducingState, "3");
+        reducingState.asyncAdd(3);
+        aec.drainInflightRecords(0);
+        assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(3);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"1"))).isEqualTo(1);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"2"))).isEqualTo(2);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"3"))).isEqualTo(3);
+
+        List<String> sources = new ArrayList<>(Arrays.asList("1", "2", "3"));
+        reducingState.asyncMergeNamespaces("0", sources);
+        aec.drainInflightRecords(0);
+        assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(1);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"0"))).isEqualTo(6);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"1"))).isNull();
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"2"))).isNull();
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"3"))).isNull();
+
+        aec.setCurrentNamespaceForState(reducingState, "4");
+        reducingState.asyncAdd(4);
+        aec.drainInflightRecords(0);
+        assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(2);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"0"))).isEqualTo(6);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"4"))).isEqualTo(4);
+
+        List<String> sources1 = new ArrayList<>(Arrays.asList("4"));
+        reducingState.asyncMergeNamespaces("0", sources1);
+        aec.drainInflightRecords(0);
+
+        assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(1);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"0"))).isEqualTo(10);
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"1"))).isNull();
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"2"))).isNull();
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"3"))).isNull();
+        assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", 
"4"))).isNull();
+    }
+
+    static class ReducingStateExecutor implements StateExecutor {
+
+        private static final HashMap<Tuple2<String, String>, Integer> hashMap 
= new HashMap<>();
+
+        public ReducingStateExecutor() {
+            hashMap.clear();
+        }
+
+        @Override
+        @SuppressWarnings({"unchecked", "rawtypes"})
+        public CompletableFuture<Void> executeBatchRequests(
+                StateRequestContainer stateRequestContainer) {
+            Preconditions.checkArgument(stateRequestContainer instanceof 
MockStateRequestContainer);
+            CompletableFuture<Void> future = new CompletableFuture<>();
+            for (StateRequest request :
+                    ((MockStateRequestContainer) 
stateRequestContainer).getStateRequestList()) {
+                if (request.getRequestType() == StateRequestType.REDUCING_GET) 
{
+                    String key = (String) request.getRecordContext().getKey();
+                    String namespace = (String) request.getNamespace();
+                    Integer val = hashMap.get(Tuple2.of(key, namespace));
+                    request.getFuture().complete(val);
+                } else if (request.getRequestType() == 
StateRequestType.REDUCING_ADD) {
+                    String key = (String) request.getRecordContext().getKey();
+                    String namespace = (String) request.getNamespace();
+                    hashMap.put(Tuple2.of(key, namespace), (Integer) 
request.getPayload());
+                    request.getFuture().complete(null);
+                } else if (request.getRequestType() == 
StateRequestType.REDUCING_REMOVE) {
+                    String key = (String) request.getRecordContext().getKey();
+                    String namespace = (String) request.getNamespace();
+                    hashMap.remove(Tuple2.of(key, namespace));
+                    request.getFuture().complete(null);
+                } else {
+                    throw new UnsupportedOperationException("Unsupported 
request type");
+                }
+            }
+            future.complete(null);
+            return future;
+        }
+
+        @Override
+        public StateRequestContainer createStateRequestContainer() {
+            return new MockStateRequestContainer();
+        }
+
+        @Override
+        public void shutdown() {}
     }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalValueStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalValueStateTest.java
index cb79b4ab7b2..58278ecbd39 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalValueStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalValueStateTest.java
@@ -36,21 +36,21 @@ public class InternalValueStateTest extends 
InternalKeyedStateTestBase {
         aec.setCurrentContext(aec.buildContext("test", "test"));
 
         valueState.asyncClear();
-        validateRequestRun(valueState, StateRequestType.CLEAR, null);
+        validateRequestRun(valueState, StateRequestType.CLEAR, null, 0);
 
         valueState.asyncValue();
-        validateRequestRun(valueState, StateRequestType.VALUE_GET, null);
+        validateRequestRun(valueState, StateRequestType.VALUE_GET, null, 0);
 
         valueState.asyncUpdate(1);
-        validateRequestRun(valueState, StateRequestType.VALUE_UPDATE, 1);
+        validateRequestRun(valueState, StateRequestType.VALUE_UPDATE, 1, 0);
 
         valueState.clear();
-        validateRequestRun(valueState, StateRequestType.CLEAR, null);
+        validateRequestRun(valueState, StateRequestType.CLEAR, null, 0);
 
         valueState.value();
-        validateRequestRun(valueState, StateRequestType.VALUE_GET, null);
+        validateRequestRun(valueState, StateRequestType.VALUE_GET, null, 0);
 
         valueState.update(1);
-        validateRequestRun(valueState, StateRequestType.VALUE_UPDATE, 1);
+        validateRequestRun(valueState, StateRequestType.VALUE_UPDATE, 1, 0);
     }
 }
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
index 7ec2f60c854..a04c7f2b5f4 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
@@ -28,6 +28,7 @@ import 
org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
 import org.apache.flink.runtime.state.AsyncKeyedStateBackend;
 import org.apache.flink.runtime.state.SerializedCompositeKeyBuilder;
 import org.apache.flink.runtime.state.v2.ListStateDescriptor;
+import org.apache.flink.runtime.state.v2.ReducingStateDescriptor;
 import org.apache.flink.runtime.state.v2.StateDescriptor;
 import org.apache.flink.runtime.state.v2.ValueStateDescriptor;
 import org.apache.flink.util.FlinkRuntimeException;
@@ -191,6 +192,17 @@ public class ForStKeyedStateBackend<K> implements 
AsyncKeyedStateBackend {
                         keyDeserializerView,
                         valueDeserializerView,
                         keyGroupPrefixBytes);
+            case REDUCING:
+                return (S)
+                        new ForStReducingState<>(
+                                stateRequestHandler,
+                                columnFamilyHandle,
+                                (ReducingStateDescriptor<SV>) stateDesc,
+                                serializedKeyBuilder,
+                                defaultNamespace,
+                                namespaceSerializer::duplicate,
+                                valueSerializerView,
+                                valueDeserializerView);
             default:
                 throw new UnsupportedOperationException(
                         String.format("Unsupported state type: %s", 
stateDesc.getType()));
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStReducingState.java
 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStReducingState.java
new file mode 100644
index 00000000000..3fe2329bd7b
--- /dev/null
+++ 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStReducingState.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.forst;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.core.state.InternalStateFuture;
+import org.apache.flink.runtime.asyncprocessing.RecordContext;
+import org.apache.flink.runtime.asyncprocessing.StateRequest;
+import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
+import org.apache.flink.runtime.asyncprocessing.StateRequestType;
+import org.apache.flink.runtime.state.SerializedCompositeKeyBuilder;
+import org.apache.flink.runtime.state.v2.InternalReducingState;
+import org.apache.flink.runtime.state.v2.ReducingStateDescriptor;
+import org.apache.flink.util.Preconditions;
+
+import org.rocksdb.ColumnFamilyHandle;
+
+import java.io.IOException;
+import java.util.function.Supplier;
+
+/**
+ * The {@link InternalReducingState} implement for ForStDB.
+ *
+ * @param <K> The type of the key.
+ * @param <N> The type of the namespace.
+ * @param <V> The type of the value.
+ */
+public class ForStReducingState<K, N, V> extends InternalReducingState<K, N, V>
+        implements ForStInnerTable<K, N, V> {
+
+    /** The column family which this internal value state belongs to. */
+    private final ColumnFamilyHandle columnFamilyHandle;
+
+    /** The serialized key builder which should be thread-safe. */
+    private final ThreadLocal<SerializedCompositeKeyBuilder<K>> 
serializedKeyBuilder;
+
+    /** The default namespace if not set. * */
+    private final N defaultNamespace;
+
+    /** The serializer for namespace. * */
+    private final ThreadLocal<TypeSerializer<N>> namespaceSerializer;
+
+    /** The data outputStream used for value serializer, which should be 
thread-safe. */
+    private final ThreadLocal<DataOutputSerializer> valueSerializerView;
+
+    /** The data inputStream used for value deserializer, which should be 
thread-safe. */
+    private final ThreadLocal<DataInputDeserializer> valueDeserializerView;
+
+    public ForStReducingState(
+            StateRequestHandler stateRequestHandler,
+            ColumnFamilyHandle columnFamily,
+            ReducingStateDescriptor<V> reducingStateDescriptor,
+            Supplier<SerializedCompositeKeyBuilder<K>> 
serializedKeyBuilderInitializer,
+            N defaultNamespace,
+            Supplier<TypeSerializer<N>> namespaceSerializerInitializer,
+            Supplier<DataOutputSerializer> valueSerializerViewInitializer,
+            Supplier<DataInputDeserializer> valueDeserializerViewInitializer) {
+        super(stateRequestHandler, reducingStateDescriptor);
+        this.columnFamilyHandle = columnFamily;
+        this.serializedKeyBuilder = 
ThreadLocal.withInitial(serializedKeyBuilderInitializer);
+        this.defaultNamespace = defaultNamespace;
+        this.namespaceSerializer = 
ThreadLocal.withInitial(namespaceSerializerInitializer);
+        this.valueSerializerView = 
ThreadLocal.withInitial(valueSerializerViewInitializer);
+        this.valueDeserializerView = 
ThreadLocal.withInitial(valueDeserializerViewInitializer);
+    }
+
+    @Override
+    public ColumnFamilyHandle getColumnFamilyHandle() {
+        return columnFamilyHandle;
+    }
+
+    @Override
+    public byte[] serializeKey(ContextKey<K, N> contextKey) throws IOException 
{
+        return contextKey.getOrCreateSerializedKey(
+                ctxKey -> {
+                    SerializedCompositeKeyBuilder<K> builder = 
serializedKeyBuilder.get();
+                    builder.setKeyAndKeyGroup(ctxKey.getRawKey(), 
ctxKey.getKeyGroup());
+                    N namespace = ctxKey.getNamespace();
+                    return builder.buildCompositeKeyNamespace(
+                            namespace == null ? defaultNamespace : namespace,
+                            namespaceSerializer.get());
+                });
+    }
+
+    @Override
+    public byte[] serializeValue(V value) throws IOException {
+        DataOutputSerializer outputView = valueSerializerView.get();
+        outputView.clear();
+        getValueSerializer().serialize(value, outputView);
+        return outputView.getCopyOfBuffer();
+    }
+
+    @Override
+    public V deserializeValue(byte[] valueBytes) throws IOException {
+        DataInputDeserializer inputView = valueDeserializerView.get();
+        inputView.setBuffer(valueBytes);
+        return getValueSerializer().deserialize(inputView);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public ForStDBGetRequest<K, N, V, V> buildDBGetRequest(StateRequest<?, ?, 
?, ?> stateRequest) {
+        Preconditions.checkArgument(stateRequest.getRequestType() == 
StateRequestType.REDUCING_GET);
+        ContextKey<K, N> contextKey =
+                new ContextKey<>(
+                        (RecordContext<K>) stateRequest.getRecordContext(),
+                        (N) stateRequest.getNamespace());
+        return new ForStDBSingleGetRequest<>(
+                contextKey, this, (InternalStateFuture<V>) 
stateRequest.getFuture());
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public ForStDBPutRequest<K, N, V> buildDBPutRequest(StateRequest<?, ?, ?, 
?> stateRequest) {
+        Preconditions.checkArgument(
+                stateRequest.getRequestType() == StateRequestType.REDUCING_ADD
+                        || stateRequest.getRequestType() == 
StateRequestType.REDUCING_REMOVE
+                        || stateRequest.getRequestType() == 
StateRequestType.CLEAR);
+        ContextKey<K, N> contextKey =
+                new ContextKey<>(
+                        (RecordContext<K>) stateRequest.getRecordContext(),
+                        (N) stateRequest.getNamespace());
+        V value =
+                (stateRequest.getRequestType() == 
StateRequestType.REDUCING_REMOVE
+                                || stateRequest.getRequestType() == 
StateRequestType.CLEAR)
+                        ? null // "Delete(key)" is equivalent to "Put(key, 
null)"
+                        : (V) stateRequest.getPayload();
+        return ForStDBPutRequest.of(
+                contextKey, value, this, (InternalStateFuture<Void>) 
stateRequest.getFuture());
+    }
+}
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStStateRequestClassifier.java
 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStStateRequestClassifier.java
index ea6ede88acf..c7205f0791c 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStStateRequestClassifier.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStStateRequestClassifier.java
@@ -62,6 +62,7 @@ public class ForStStateRequestClassifier implements 
StateRequestContainer {
             case MAP_GET:
             case MAP_IS_EMPTY:
             case MAP_CONTAINS:
+            case REDUCING_GET:
                 {
                     ForStInnerTable<?, ?, ?> innerTable =
                             (ForStInnerTable<?, ?, ?>) stateRequest.getState();
@@ -74,6 +75,7 @@ public class ForStStateRequestClassifier implements 
StateRequestContainer {
             case LIST_ADD_ALL:
             case MAP_PUT:
             case MAP_REMOVE:
+            case REDUCING_ADD:
                 {
                     ForStInnerTable<?, ?, ?> innerTable =
                             (ForStInnerTable<?, ?, ?>) stateRequest.getState();


Reply via email to