This is an automated email from the ASF dual-hosted git repository. leiyanfei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 6a85f8047b2 [FLINK-34974][state] Support getOrCreateKeyedState for AsyncKeyedStateBackend (#25745) 6a85f8047b2 is described below commit 6a85f8047b2bf3121a3961f6bd5905b0d33403e8 Author: Yanfei Lei <fredia...@gmail.com> AuthorDate: Wed Dec 11 10:14:36 2024 +0800 [FLINK-34974][state] Support getOrCreateKeyedState for AsyncKeyedStateBackend (#25745) --- .../runtime/state/AsyncKeyedStateBackend.java | 29 ++++++++++++++++------ .../runtime/state/v2/DefaultKeyedStateStore.java | 10 ++++---- .../v2/adaptor/AsyncKeyedStateBackendAdaptor.java | 15 ++++++----- .../api/operators/StreamOperatorStateHandler.java | 2 +- .../AsyncExecutionControllerTest.java | 2 +- .../flink/runtime/state/StateBackendTestUtils.java | 12 ++++----- .../state/v2/AbstractKeyedStateTestBase.java | 9 +++---- .../v2/AsyncKeyedStateBackendAdaptorTest.java | 10 ++++---- .../runtime/state/v2/StateBackendTestV2Base.java | 6 ++--- .../api/operators/StreamingRuntimeContextTest.java | 2 +- .../flink/state/forst/ForStKeyedStateBackend.java | 29 ++++++++++++++++++++-- 11 files changed, 84 insertions(+), 42 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java index d6e606529aa..133e107e768 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AsyncKeyedStateBackend.java @@ -56,7 +56,7 @@ public interface AsyncKeyedStateBackend<K> void setup(@Nonnull StateRequestHandler stateRequestHandler); /** - * Creates and returns a new state. + * Creates or retrieves a keyed state backed by this state backend. * * @param <N> the type of namespace for partitioning. * @param <S> The type of the public API state. 
@@ -64,13 +64,14 @@ public interface AsyncKeyedStateBackend<K> * @param defaultNamespace the default namespace for this state. * @param namespaceSerializer the serializer for namespace. * @param stateDesc The {@code StateDescriptor} that contains the name of the state. - * @throws Exception Exceptions may occur during initialization of the state. + * @return The existing or newly created key/value state backed by this backend. + * @throws Exception Exceptions may occur during initialization of the state and should be + * forwarded. */ - @Nonnull - <N, S extends State, SV> S createState( - @Nonnull N defaultNamespace, - @Nonnull TypeSerializer<N> namespaceSerializer, - @Nonnull StateDescriptor<SV> stateDesc) + <N, S extends State, SV> S getOrCreateKeyedState( + N defaultNamespace, + TypeSerializer<N> namespaceSerializer, + StateDescriptor<SV> stateDesc) throws Exception; /** @@ -122,6 +123,20 @@ public interface AsyncKeyedStateBackend<K> return true; } + /** + * Whether it's safe to reuse key-values from the state-backend, e.g., for the purpose of + * optimization. + * + * <p>NOTE: this method should not be used to check for {@link InternalPriorityQueue}, as the + * priority queue could be stored in different locations, e.g., ForSt state-backend could store + * that on JVM heap if configuring HEAP as the time-service factory. + * + * @return true if it is safe to reuse the key-values from the state-backend.
+ */ + default boolean isSafeToReuseKVState() { + return false; + } + @Override void dispose(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStore.java index 3a96b950ecb..9eb655bc671 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStore.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/DefaultKeyedStateStore.java @@ -43,7 +43,7 @@ public class DefaultKeyedStateStore implements KeyedStateStore { public <T> ValueState<T> getValueState(@Nonnull ValueStateDescriptor<T> stateProperties) { Preconditions.checkNotNull(stateProperties, "The state properties must not be null"); try { - return asyncKeyedStateBackend.createState( + return asyncKeyedStateBackend.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties); } catch (Exception e) { throw new RuntimeException("Error while getting state", e); @@ -54,7 +54,7 @@ public class DefaultKeyedStateStore implements KeyedStateStore { public <T> ListState<T> getListState(@Nonnull ListStateDescriptor<T> stateProperties) { Preconditions.checkNotNull(stateProperties, "The state properties must not be null"); try { - return asyncKeyedStateBackend.createState( + return asyncKeyedStateBackend.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties); } catch (Exception e) { throw new RuntimeException("Error while getting state", e); @@ -66,7 +66,7 @@ public class DefaultKeyedStateStore implements KeyedStateStore { @Nonnull MapStateDescriptor<UK, UV> stateProperties) { Preconditions.checkNotNull(stateProperties, "The state properties must not be null"); try { - return asyncKeyedStateBackend.createState( + return asyncKeyedStateBackend.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties); } catch (Exception e) 
{ throw new RuntimeException("Error while getting state", e); @@ -78,7 +78,7 @@ public class DefaultKeyedStateStore implements KeyedStateStore { @Nonnull ReducingStateDescriptor<T> stateProperties) { Preconditions.checkNotNull(stateProperties, "The state properties must not be null"); try { - return asyncKeyedStateBackend.createState( + return asyncKeyedStateBackend.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties); } catch (Exception e) { throw new RuntimeException("Error while getting state", e); @@ -90,7 +90,7 @@ public class DefaultKeyedStateStore implements KeyedStateStore { @Nonnull AggregatingStateDescriptor<IN, ACC, OUT> stateProperties) { Preconditions.checkNotNull(stateProperties, "The state properties must not be null"); try { - return asyncKeyedStateBackend.createState( + return asyncKeyedStateBackend.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateProperties); } catch (Exception e) { throw new RuntimeException("Error while getting state", e); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java index 28c6993c9c4..79b98491bac 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/adaptor/AsyncKeyedStateBackendAdaptor.java @@ -68,13 +68,11 @@ public class AsyncKeyedStateBackendAdaptor<K> implements AsyncKeyedStateBackend< @Override public void setup(@Nonnull StateRequestHandler stateRequestHandler) {} - @Nonnull @Override - @SuppressWarnings({"rawtypes", "unchecked"}) - public <N, S extends State, SV> S createState( - @Nonnull N defaultNamespace, - @Nonnull TypeSerializer<N> namespaceSerializer, - @Nonnull StateDescriptor<SV> stateDesc) + public <N, S extends State, SV> S 
getOrCreateKeyedState( + N defaultNamespace, + TypeSerializer<N> namespaceSerializer, + StateDescriptor<SV> stateDesc) throws Exception { return createStateInternal(defaultNamespace, namespaceSerializer, stateDesc); } @@ -191,4 +189,9 @@ public class AsyncKeyedStateBackendAdaptor<K> implements AsyncKeyedStateBackend< } return false; } + + @Override + public boolean isSafeToReuseKVState() { + return keyedStateBackend.isSafeToReuseKVState(); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java index 7ab8636f2fe..aebfdada84a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/operators/StreamOperatorStateHandler.java @@ -415,7 +415,7 @@ public class StreamOperatorStateHandler { throws Exception { if (asyncKeyedStateBackend != null) { - return asyncKeyedStateBackend.createState( + return asyncKeyedStateBackend.getOrCreateKeyedState( defaultNamespace, namespaceSerializer, stateDescriptor); } else { throw new IllegalStateException( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java index 8d7853eb7ef..ae0c96cb3d9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java @@ -110,7 +110,7 @@ class AsyncExecutionControllerTest { try { valueState = - asyncKeyedStateBackend.createState( + asyncKeyedStateBackend.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescriptor); diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java index eca13bc180f..e30ba257bc8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestUtils.java @@ -131,13 +131,13 @@ public class StateBackendTestUtils { // do nothing } - @Nonnull @Override - @SuppressWarnings("unchecked") - public <N, S extends org.apache.flink.api.common.state.v2.State, SV> S createState( - @Nonnull N defaultNamespace, - @Nonnull TypeSerializer<N> namespaceSerializer, - @Nonnull org.apache.flink.runtime.state.v2.StateDescriptor<SV> stateDesc) { + public <N, S extends org.apache.flink.api.common.state.v2.State, SV> + S getOrCreateKeyedState( + N defaultNamespace, + TypeSerializer<N> namespaceSerializer, + org.apache.flink.runtime.state.v2.StateDescriptor<SV> stateDesc) + throws Exception { return (S) innerStateSupplier.get(); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java index 56ae6c13470..d413f07d2ea 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractKeyedStateTestBase.java @@ -173,12 +173,11 @@ public class AbstractKeyedStateTestBase { @Override public void setup(@Nonnull StateRequestHandler stateRequestHandler) {} - @Nonnull @Override - public <N, S extends State, SV> S createState( - @Nonnull N defaultNamespace, - @Nonnull TypeSerializer<N> namespaceSerializer, - @Nonnull StateDescriptor<SV> stateDesc) + public <N, S extends State, SV> S getOrCreateKeyedState( + N defaultNamespace, + TypeSerializer<N> namespaceSerializer, + StateDescriptor<SV> stateDesc) 
throws Exception { return null; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AsyncKeyedStateBackendAdaptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AsyncKeyedStateBackendAdaptorTest.java index b1dbf0e6174..371e4ef3ab3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AsyncKeyedStateBackendAdaptorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AsyncKeyedStateBackendAdaptorTest.java @@ -63,7 +63,7 @@ public class AsyncKeyedStateBackendAdaptorTest { new ValueStateDescriptor<>("testState", BasicTypeInfo.INT_TYPE_INFO); org.apache.flink.api.common.state.v2.ValueState<Integer> valueState = - adaptor.createState( + adaptor.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor); // test synchronous interfaces. @@ -102,7 +102,7 @@ public class AsyncKeyedStateBackendAdaptorTest { new ListStateDescriptor<>("testState", BasicTypeInfo.INT_TYPE_INFO); org.apache.flink.api.common.state.v2.ListState<Integer> listState = - adaptor.createState( + adaptor.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor); // test synchronous interfaces. @@ -154,7 +154,7 @@ public class AsyncKeyedStateBackendAdaptorTest { "testState", BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO); org.apache.flink.api.common.state.v2.MapState<Integer, Integer> mapState = - adaptor.createState( + adaptor.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor); final HashMap<Integer, Integer> groundTruth = @@ -247,7 +247,7 @@ public class AsyncKeyedStateBackendAdaptorTest { "testState", Integer::sum, BasicTypeInfo.INT_TYPE_INFO); InternalReducingState<String, Long, Integer> reducingState = - adaptor.createState(0L, LongSerializer.INSTANCE, descriptor); + adaptor.getOrCreateKeyedState(0L, LongSerializer.INSTANCE, descriptor); // test synchronous interfaces. 
reducingState.clear(); @@ -353,7 +353,7 @@ public class AsyncKeyedStateBackendAdaptorTest { BasicTypeInfo.INT_TYPE_INFO); InternalAggregatingState<String, Long, Integer, Integer, String> aggState = - adaptor.createState(0L, LongSerializer.INSTANCE, descriptor); + adaptor.getOrCreateKeyedState(0L, LongSerializer.INSTANCE, descriptor); // test synchronous interfaces. aggState.clear(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/StateBackendTestV2Base.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/StateBackendTestV2Base.java index 3e4a26fe83f..c6ec6366c4a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/StateBackendTestV2Base.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/StateBackendTestV2Base.java @@ -229,7 +229,7 @@ public abstract class StateBackendTestV2Base<B extends AbstractStateBackend> { new ValueStateDescriptor<>("test", BasicTypeInfo.INT_TYPE_INFO); ValueState<Integer> valueState = - backend.createState( + backend.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescriptor); @@ -318,7 +318,7 @@ public abstract class StateBackendTestV2Base<B extends AbstractStateBackend> { new ValueStateDescriptor<>("test", BasicTypeInfo.INT_TYPE_INFO); ValueState<Integer> valueState = - backend.createState( + backend.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescriptor); @@ -432,7 +432,7 @@ public abstract class StateBackendTestV2Base<B extends AbstractStateBackend> { kvId.enableTimeToLive(StateTtlConfig.newBuilder(Duration.ofSeconds(1)).build()); ValueState<Long> state = - backend.createState( + backend.getOrCreateKeyedState( VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId); RecordContext recordContext = aec.buildContext("record-1", 1L); recordContext.retain(); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java 
b/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java index 1098fc67ad7..b97f33431d7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java @@ -493,7 +493,7 @@ class StreamingRuntimeContextTest { return null; }) .when(asyncKeyedStateBackend) - .createState( + .getOrCreateKeyedState( any(), any(TypeSerializer.class), any(org.apache.flink.runtime.state.v2.StateDescriptor.class)); diff --git a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java index 6e0c7613ab1..6f09c2dc315 100644 --- a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java +++ b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java @@ -71,6 +71,7 @@ import javax.annotation.Nonnull; import javax.annotation.concurrent.GuardedBy; import java.io.IOException; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; @@ -81,6 +82,7 @@ import java.util.function.Function; import java.util.function.Supplier; import static org.apache.flink.runtime.state.SnapshotExecutionType.ASYNCHRONOUS; +import static org.apache.flink.util.Preconditions.checkNotNull; /** * A KeyedStateBackend that stores its state in {@code ForSt}. This state backend can store very @@ -158,6 +160,9 @@ public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> { */ private final LinkedHashMap<String, ForStOperationUtils.ForStKvStateInfo> kvStateInformation; + /** So that we can give out state when the user uses the same key. 
*/ + private final HashMap<String, InternalKeyedState<K, ?, ?>> keyValueStatesByName; + /** Lock guarding the {@code managedStateExecutors} and {@code disposed}. */ private final Object lock = new Object(); @@ -201,6 +206,7 @@ public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> { this.valueDeserializerView = valueDeserializerView; this.db = db; this.kvStateInformation = kvStateInformation; + this.keyValueStatesByName = new HashMap<>(); this.columnFamilyOptionsFactory = columnFamilyOptionsFactory; this.defaultColumnFamily = defaultColumnFamilyHandle; this.snapshotStrategy = snapshotStrategy; @@ -227,10 +233,24 @@ public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> { this.stateRequestHandler = stateRequestHandler; } - @Nonnull @Override + public <N, S extends State, SV> S getOrCreateKeyedState( + N defaultNamespace, + TypeSerializer<N> namespaceSerializer, + StateDescriptor<SV> stateDesc) + throws Exception { + checkNotNull(namespaceSerializer, "Namespace serializer"); + InternalKeyedState<K, ?, ?> kvState = keyValueStatesByName.get(stateDesc.getStateId()); + if (kvState == null) { + kvState = createState(defaultNamespace, namespaceSerializer, stateDesc); + keyValueStatesByName.put(stateDesc.getStateId(), kvState); + } + return (S) kvState; + } + + @Nonnull @SuppressWarnings("unchecked") - public <N, S extends State, SV> S createState( + protected <N, S extends State, SV> S createState( @Nonnull N defaultNamespace, @Nonnull TypeSerializer<N> namespaceSerializer, @Nonnull StateDescriptor<SV> stateDesc) @@ -523,6 +543,11 @@ public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> { } } + @Override + public boolean isSafeToReuseKVState() { + return true; + } + @VisibleForTesting Path getLocalBasePath() { return optionsContainer.getLocalBasePath();