This is an automated email from the ASF dual-hosted git repository.

leiyanfei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a85f8047b2 [FLINK-34974][state] Support getOrCreateKeyedState for AsyncKeyedStateBackend (#25745)
6a85f8047b2 is described below

commit 6a85f8047b2bf3121a3961f6bd5905b0d33403e8
Author: Yanfei Lei <fredia...@gmail.com>
AuthorDate: Wed Dec 11 10:14:36 2024 +0800

    [FLINK-34974][state] Support getOrCreateKeyedState for AsyncKeyedStateBackend (#25745)
---
 .../runtime/state/AsyncKeyedStateBackend.java      | 29 ++++++++++++++++------
 .../runtime/state/v2/DefaultKeyedStateStore.java   | 10 ++++----
 .../v2/adaptor/AsyncKeyedStateBackendAdaptor.java  | 15 ++++++-----
 .../api/operators/StreamOperatorStateHandler.java  |  2 +-
 .../AsyncExecutionControllerTest.java              |  2 +-
 .../flink/runtime/state/StateBackendTestUtils.java | 12 ++++-----
 .../state/v2/AbstractKeyedStateTestBase.java       |  9 +++----
 .../v2/AsyncKeyedStateBackendAdaptorTest.java      | 10 ++++----
 .../runtime/state/v2/StateBackendTestV2Base.java   |  6 ++---
 .../api/operators/StreamingRuntimeContextTest.java |  2 +-
 .../flink/state/forst/ForStKeyedStateBackend.java  | 29 ++++++++++++++++++++--
 11 files changed, 84 insertions(+), 42 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java
index d6e606529aa..133e107e768 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java
@@ -56,7 +56,7 @@ public interface AsyncKeyedStateBackend<K>
     void setup(@Nonnull StateRequestHandler stateRequestHandler);
 
     /**
-     * Creates and returns a new state.
+     * Creates or retrieves a keyed state backed by this state backend.
      *
      * @param <N> the type of namespace for partitioning.
      * @param <S> The type of the public API state.
@@ -64,13 +64,14 @@ public interface AsyncKeyedStateBackend<K>
      * @param defaultNamespace the default namespace for this state.
      * @param namespaceSerializer the serializer for namespace.
      * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
-     * @throws Exception Exceptions may occur during initialization of the state.
+     * @return A new key/value state backed by this backend.
+     * @throws Exception Exceptions may occur during initialization of the state and should be
+     *     forwarded.
      */
-    @Nonnull
-    <N, S extends State, SV> S createState(
-            @Nonnull N defaultNamespace,
-            @Nonnull TypeSerializer<N> namespaceSerializer,
-            @Nonnull StateDescriptor<SV> stateDesc)
+    <N, S extends State, SV> S getOrCreateKeyedState(
+            N defaultNamespace,
+            TypeSerializer<N> namespaceSerializer,
+            StateDescriptor<SV> stateDesc)
             throws Exception;
 
     /**
@@ -122,6 +123,20 @@ public interface AsyncKeyedStateBackend<K>
         return true;
     }
 
+    /**
+     * Whether it's safe to reuse key-values from the state-backend, e.g for the purpose of
+     * optimization.
+     *
+     * <p>NOTE: this method should not be used to check for {@link InternalPriorityQueue}, as the
+     * priority queue could be stored on different locations, e.g ForSt state-backend could store
+     * that on JVM heap if configuring HEAP as the time-service factory.
+     *
+     * @return returns ture if safe to reuse the key-values from the state-backend.
+     */
+    default boolean isSafeToReuseKVState() {
+        return false;
+    }
+
     @Override
     void dispose();
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStore.java
index 3a96b950ecb..9eb655bc671 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStore.java
@@ -43,7 +43,7 @@ public class DefaultKeyedStateStore implements KeyedStateStore {
     public <T> ValueState<T> getValueState(@Nonnull ValueStateDescriptor<T> stateProperties) {
         Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
         try {
-            return asyncKeyedStateBackend.createState(
+            return asyncKeyedStateBackend.getOrCreateKeyedState(
                     VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties);
         } catch (Exception e) {
             throw new RuntimeException("Error while getting state", e);
@@ -54,7 +54,7 @@ public class DefaultKeyedStateStore implements KeyedStateStore {
     public <T> ListState<T> getListState(@Nonnull ListStateDescriptor<T> stateProperties) {
         Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
         try {
-            return asyncKeyedStateBackend.createState(
+            return asyncKeyedStateBackend.getOrCreateKeyedState(
                     VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties);
         } catch (Exception e) {
             throw new RuntimeException("Error while getting state", e);
@@ -66,7 +66,7 @@ public class DefaultKeyedStateStore implements KeyedStateStore {
             @Nonnull MapStateDescriptor<UK, UV> stateProperties) {
         Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
         try {
-            return asyncKeyedStateBackend.createState(
+            return asyncKeyedStateBackend.getOrCreateKeyedState(
                     VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties);
         } catch (Exception e) {
             throw new RuntimeException("Error while getting state", e);
@@ -78,7 +78,7 @@ public class DefaultKeyedStateStore implements KeyedStateStore {
             @Nonnull ReducingStateDescriptor<T> stateProperties) {
         Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
         try {
-            return asyncKeyedStateBackend.createState(
+            return asyncKeyedStateBackend.getOrCreateKeyedState(
                     VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties);
         } catch (Exception e) {
             throw new RuntimeException("Error while getting state", e);
@@ -90,7 +90,7 @@ public class DefaultKeyedStateStore implements KeyedStateStore {
             @Nonnull AggregatingStateDescriptor<IN, ACC, OUT> stateProperties) {
         Preconditions.checkNotNull(stateProperties, "The state properties must not be null");
         try {
-            return asyncKeyedStateBackend.createState(
+            return asyncKeyedStateBackend.getOrCreateKeyedState(
                     VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties);
         } catch (Exception e) {
             throw new RuntimeException("Error while getting state", e);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java
index 28c6993c9c4..79b98491bac 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java
@@ -68,13 +68,11 @@ public class AsyncKeyedStateBackendAdaptor<K> implements AsyncKeyedStateBackend<
     @Override
     public void setup(@Nonnull StateRequestHandler stateRequestHandler) {}
 
-    @Nonnull
     @Override
-    @SuppressWarnings({"rawtypes", "unchecked"})
-    public <N, S extends State, SV> S createState(
-            @Nonnull N defaultNamespace,
-            @Nonnull TypeSerializer<N> namespaceSerializer,
-            @Nonnull StateDescriptor<SV> stateDesc)
+    public <N, S extends State, SV> S getOrCreateKeyedState(
+            N defaultNamespace,
+            TypeSerializer<N> namespaceSerializer,
+            StateDescriptor<SV> stateDesc)
             throws Exception {
         return createStateInternal(defaultNamespace, namespaceSerializer, stateDesc);
     }
@@ -191,4 +189,9 @@ public class AsyncKeyedStateBackendAdaptor<K> implements AsyncKeyedStateBackend<
         }
         return false;
     }
+
+    @Override
+    public boolean isSafeToReuseKVState() {
+        return keyedStateBackend.isSafeToReuseKVState();
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java
index 7ab8636f2fe..aebfdada84a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java
@@ -415,7 +415,7 @@ public class StreamOperatorStateHandler {
             throws Exception {
 
         if (asyncKeyedStateBackend != null) {
-            return asyncKeyedStateBackend.createState(
+            return asyncKeyedStateBackend.getOrCreateKeyedState(
                     defaultNamespace, namespaceSerializer, stateDescriptor);
         } else {
             throw new IllegalStateException(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java
index 8d7853eb7ef..ae0c96cb3d9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java
@@ -110,7 +110,7 @@ class AsyncExecutionControllerTest {
 
         try {
             valueState =
-                    asyncKeyedStateBackend.createState(
+                    asyncKeyedStateBackend.getOrCreateKeyedState(
                             VoidNamespace.INSTANCE,
                             VoidNamespaceSerializer.INSTANCE,
                             stateDescriptor);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java
index eca13bc180f..e30ba257bc8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java
@@ -131,13 +131,13 @@ public class StateBackendTestUtils {
             // do nothing
         }
 
-        @Nonnull
         @Override
-        @SuppressWarnings("unchecked")
-        public <N, S extends org.apache.flink.api.common.state.v2.State, SV> S createState(
-                @Nonnull N defaultNamespace,
-                @Nonnull TypeSerializer<N> namespaceSerializer,
-                @Nonnull org.apache.flink.runtime.state.v2.StateDescriptor<SV> stateDesc) {
+        public <N, S extends org.apache.flink.api.common.state.v2.State, SV>
+                S getOrCreateKeyedState(
+                        N defaultNamespace,
+                        TypeSerializer<N> namespaceSerializer,
+                        org.apache.flink.runtime.state.v2.StateDescriptor<SV> stateDesc)
+                        throws Exception {
             return (S) innerStateSupplier.get();
         }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java
index 56ae6c13470..d413f07d2ea 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java
@@ -173,12 +173,11 @@ public class AbstractKeyedStateTestBase {
                 @Override
                 public void setup(@Nonnull StateRequestHandler stateRequestHandler) {}
 
-                @Nonnull
                 @Override
-                public <N, S extends State, SV> S createState(
-                        @Nonnull N defaultNamespace,
-                        @Nonnull TypeSerializer<N> namespaceSerializer,
-                        @Nonnull StateDescriptor<SV> stateDesc)
+                public <N, S extends State, SV> S getOrCreateKeyedState(
+                        N defaultNamespace,
+                        TypeSerializer<N> namespaceSerializer,
+                        StateDescriptor<SV> stateDesc)
                         throws Exception {
                     return null;
                 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AsyncKeyedStateBackendAdaptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AsyncKeyedStateBackendAdaptorTest.java
index b1dbf0e6174..371e4ef3ab3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AsyncKeyedStateBackendAdaptorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AsyncKeyedStateBackendAdaptorTest.java
@@ -63,7 +63,7 @@ public class AsyncKeyedStateBackendAdaptorTest {
                 new ValueStateDescriptor<>("testState", BasicTypeInfo.INT_TYPE_INFO);
 
         org.apache.flink.api.common.state.v2.ValueState<Integer> valueState =
-                adaptor.createState(
+                adaptor.getOrCreateKeyedState(
                         VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);
 
         // test synchronous interfaces.
@@ -102,7 +102,7 @@ public class AsyncKeyedStateBackendAdaptorTest {
                 new ListStateDescriptor<>("testState", BasicTypeInfo.INT_TYPE_INFO);
 
         org.apache.flink.api.common.state.v2.ListState<Integer> listState =
-                adaptor.createState(
+                adaptor.getOrCreateKeyedState(
                         VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);
 
         // test synchronous interfaces.
@@ -154,7 +154,7 @@ public class AsyncKeyedStateBackendAdaptorTest {
                         "testState", BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO);
 
         org.apache.flink.api.common.state.v2.MapState<Integer, Integer> mapState =
-                adaptor.createState(
+                adaptor.getOrCreateKeyedState(
                         VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);
 
         final HashMap<Integer, Integer> groundTruth =
@@ -247,7 +247,7 @@ public class AsyncKeyedStateBackendAdaptorTest {
                         "testState", Integer::sum, BasicTypeInfo.INT_TYPE_INFO);
 
         InternalReducingState<String, Long, Integer> reducingState =
-                adaptor.createState(0L, LongSerializer.INSTANCE, descriptor);
+                adaptor.getOrCreateKeyedState(0L, LongSerializer.INSTANCE, descriptor);
 
         // test synchronous interfaces.
         reducingState.clear();
@@ -353,7 +353,7 @@ public class AsyncKeyedStateBackendAdaptorTest {
                         BasicTypeInfo.INT_TYPE_INFO);
 
         InternalAggregatingState<String, Long, Integer, Integer, String> aggState =
-                adaptor.createState(0L, LongSerializer.INSTANCE, descriptor);
+                adaptor.getOrCreateKeyedState(0L, LongSerializer.INSTANCE, descriptor);
 
         // test synchronous interfaces.
         aggState.clear();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/StateBackendTestV2Base.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/StateBackendTestV2Base.java
index 3e4a26fe83f..c6ec6366c4a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/StateBackendTestV2Base.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/StateBackendTestV2Base.java
@@ -229,7 +229,7 @@ public abstract class StateBackendTestV2Base<B extends AbstractStateBackend> {
                     new ValueStateDescriptor<>("test", BasicTypeInfo.INT_TYPE_INFO);
 
             ValueState<Integer> valueState =
-                    backend.createState(
+                    backend.getOrCreateKeyedState(
                             VoidNamespace.INSTANCE,
                             VoidNamespaceSerializer.INSTANCE,
                             stateDescriptor);
@@ -318,7 +318,7 @@ public abstract class StateBackendTestV2Base<B extends AbstractStateBackend> {
                     new ValueStateDescriptor<>("test", BasicTypeInfo.INT_TYPE_INFO);
 
             ValueState<Integer> valueState =
-                    backend.createState(
+                    backend.getOrCreateKeyedState(
                             VoidNamespace.INSTANCE,
                             VoidNamespaceSerializer.INSTANCE,
                             stateDescriptor);
@@ -432,7 +432,7 @@ public abstract class StateBackendTestV2Base<B extends AbstractStateBackend> {
             kvId.enableTimeToLive(StateTtlConfig.newBuilder(Duration.ofSeconds(1)).build());
 
             ValueState<Long> state =
-                    backend.createState(
+                    backend.getOrCreateKeyedState(
                             VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
             RecordContext recordContext = aec.buildContext("record-1", 1L);
             recordContext.retain();
diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
index 1098fc67ad7..b97f33431d7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
@@ -493,7 +493,7 @@ class StreamingRuntimeContextTest {
                                     return null;
                                 })
                 .when(asyncKeyedStateBackend)
-                .createState(
+                .getOrCreateKeyedState(
                         any(),
                         any(TypeSerializer.class),
                         any(org.apache.flink.runtime.state.v2.StateDescriptor.class));
diff --git a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
index 6e0c7613ab1..6f09c2dc315 100644
--- a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
@@ -71,6 +71,7 @@ import javax.annotation.Nonnull;
 import javax.annotation.concurrent.GuardedBy;
 
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.Map;
@@ -81,6 +82,7 @@ import java.util.function.Function;
 import java.util.function.Supplier;
 
 import static org.apache.flink.runtime.state.SnapshotExecutionType.ASYNCHRONOUS;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
  * A KeyedStateBackend that stores its state in {@code ForSt}. This state backend can store very
@@ -158,6 +160,9 @@ public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> {
      */
     private final LinkedHashMap<String, ForStOperationUtils.ForStKvStateInfo> kvStateInformation;
 
+    /** So that we can give out state when the user uses the same key. */
+    private final HashMap<String, InternalKeyedState<K, ?, ?>> keyValueStatesByName;
+
     /** Lock guarding the {@code managedStateExecutors} and {@code disposed}. */
     private final Object lock = new Object();
 
@@ -201,6 +206,7 @@ public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> {
         this.valueDeserializerView = valueDeserializerView;
         this.db = db;
         this.kvStateInformation = kvStateInformation;
+        this.keyValueStatesByName = new HashMap<>();
         this.columnFamilyOptionsFactory = columnFamilyOptionsFactory;
         this.defaultColumnFamily = defaultColumnFamilyHandle;
         this.snapshotStrategy = snapshotStrategy;
@@ -227,10 +233,24 @@ public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> {
         this.stateRequestHandler = stateRequestHandler;
     }
 
-    @Nonnull
     @Override
+    public <N, S extends State, SV> S getOrCreateKeyedState(
+            N defaultNamespace,
+            TypeSerializer<N> namespaceSerializer,
+            StateDescriptor<SV> stateDesc)
+            throws Exception {
+        checkNotNull(namespaceSerializer, "Namespace serializer");
+        InternalKeyedState<K, ?, ?> kvState = keyValueStatesByName.get(stateDesc.getStateId());
+        if (kvState == null) {
+            kvState = createState(defaultNamespace, namespaceSerializer, stateDesc);
+            keyValueStatesByName.put(stateDesc.getStateId(), kvState);
+        }
+        return (S) kvState;
+    }
+
+    @Nonnull
     @SuppressWarnings("unchecked")
-    public <N, S extends State, SV> S createState(
+    protected <N, S extends State, SV> S createState(
             @Nonnull N defaultNamespace,
             @Nonnull TypeSerializer<N> namespaceSerializer,
             @Nonnull StateDescriptor<SV> stateDesc)
@@ -523,6 +543,11 @@ public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> {
         }
     }
 
+    @Override
+    public boolean isSafeToReuseKVState() {
+        return true;
+    }
+
     @VisibleForTesting
     Path getLocalBasePath() {
         return optionsContainer.getLocalBasePath();

Reply via email to