This is an automated email from the ASF dual-hosted git repository.

tangyun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 73c103b  [FLINK-23018][state] Enable state factories to handle extended state descriptors
73c103b is described below

commit 73c103b6b117fe3996eedfb9d04e926f00c70996
Author: Yun Tang <[email protected]>
AuthorDate: Thu Jun 17 16:28:02 2021 +0800

    [FLINK-23018][state] Enable state factories to handle extended state descriptors
---
 .../client/QueryableStateClient.java               | 19 +++++++-----------
 .../runtime/state/heap/HeapKeyedStateBackend.java  | 17 +++++++---------
 .../flink/runtime/state/ttl/TtlStateFactory.java   | 23 ++++++++++------------
 .../state/ttl/mock/MockKeyedStateBackend.java      | 17 +++++++---------
 .../changelog/ChangelogKeyedStateBackend.java      | 19 +++++++-----------
 .../streaming/state/RocksDBKeyedStateBackend.java  | 22 ++++++++-------------
 .../state/BatchExecutionKeyedStateBackend.java     | 20 +++++++------------
 7 files changed, 53 insertions(+), 84 deletions(-)

diff --git 
a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/QueryableStateClient.java
 
b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/QueryableStateClient.java
index ecb3f39..aee16cf 100644
--- 
a/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/QueryableStateClient.java
+++ 
b/flink-queryable-state/flink-queryable-state-client-java/src/main/java/org/apache/flink/queryablestate/client/QueryableStateClient.java
@@ -21,13 +21,8 @@ package org.apache.flink.queryablestate.client;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.state.AggregatingStateDescriptor;
-import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.MapStateDescriptor;
-import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeHint;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -82,22 +77,22 @@ public class QueryableStateClient {
 
     private static final Logger LOG = 
LoggerFactory.getLogger(QueryableStateClient.class);
 
-    private static final Map<Class<? extends StateDescriptor>, StateFactory> 
STATE_FACTORIES =
+    private static final Map<StateDescriptor.Type, StateFactory> 
STATE_FACTORIES =
             Stream.of(
                             Tuple2.of(
-                                    ValueStateDescriptor.class,
+                                    StateDescriptor.Type.VALUE,
                                     (StateFactory) 
ImmutableValueState::createState),
                             Tuple2.of(
-                                    ListStateDescriptor.class,
+                                    StateDescriptor.Type.LIST,
                                     (StateFactory) 
ImmutableListState::createState),
                             Tuple2.of(
-                                    MapStateDescriptor.class,
+                                    StateDescriptor.Type.MAP,
                                     (StateFactory) 
ImmutableMapState::createState),
                             Tuple2.of(
-                                    AggregatingStateDescriptor.class,
+                                    StateDescriptor.Type.AGGREGATING,
                                     (StateFactory) 
ImmutableAggregatingState::createState),
                             Tuple2.of(
-                                    ReducingStateDescriptor.class,
+                                    StateDescriptor.Type.REDUCING,
                                     (StateFactory) 
ImmutableReducingState::createState))
                     .collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
@@ -322,7 +317,7 @@ public class QueryableStateClient {
 
     private <T, S extends State> S createState(
             KvStateResponse stateResponse, StateDescriptor<S, T> 
stateDescriptor) {
-        StateFactory stateFactory = 
STATE_FACTORIES.get(stateDescriptor.getClass());
+        StateFactory stateFactory = 
STATE_FACTORIES.get(stateDescriptor.getType());
         if (stateFactory == null) {
             String message =
                     String.format(
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index df86582..ad4eb2b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -20,13 +20,10 @@ package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.state.AggregatingStateDescriptor;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapStateDescriptor;
-import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -78,21 +75,21 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
 
     private static final Logger LOG = 
LoggerFactory.getLogger(HeapKeyedStateBackend.class);
 
-    private static final Map<Class<? extends StateDescriptor>, StateFactory> 
STATE_FACTORIES =
+    private static final Map<StateDescriptor.Type, StateFactory> 
STATE_FACTORIES =
             Stream.of(
                             Tuple2.of(
-                                    ValueStateDescriptor.class,
+                                    StateDescriptor.Type.VALUE,
                                     (StateFactory) HeapValueState::create),
                             Tuple2.of(
-                                    ListStateDescriptor.class,
+                                    StateDescriptor.Type.LIST,
                                     (StateFactory) HeapListState::create),
                             Tuple2.of(
-                                    MapStateDescriptor.class, (StateFactory) 
HeapMapState::create),
+                                    StateDescriptor.Type.MAP, (StateFactory) 
HeapMapState::create),
                             Tuple2.of(
-                                    AggregatingStateDescriptor.class,
+                                    StateDescriptor.Type.AGGREGATING,
                                     (StateFactory) 
HeapAggregatingState::create),
                             Tuple2.of(
-                                    ReducingStateDescriptor.class,
+                                    StateDescriptor.Type.REDUCING,
                                     (StateFactory) HeapReducingState::create))
                     .collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
@@ -267,7 +264,7 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
             @Nonnull StateDescriptor<S, SV> stateDesc,
             @Nonnull StateSnapshotTransformFactory<SEV> 
snapshotTransformFactory)
             throws Exception {
-        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
+        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getType());
         if (stateFactory == null) {
             String message =
                     String.format(
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
index 5969db4..72501aa 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
@@ -71,8 +71,7 @@ public class TtlStateFactory<K, N, SV, TTLSV, S extends 
State, IS extends S> {
                 : stateBackend.createInternalState(namespaceSerializer, 
stateDesc);
     }
 
-    private final Map<Class<? extends StateDescriptor>, 
SupplierWithException<IS, Exception>>
-            stateFactories;
+    private final Map<StateDescriptor.Type, SupplierWithException<IS, 
Exception>> stateFactories;
 
     @Nonnull private final TypeSerializer<N> namespaceSerializer;
     @Nonnull private final StateDescriptor<S, SV> stateDesc;
@@ -97,23 +96,22 @@ public class TtlStateFactory<K, N, SV, TTLSV, S extends 
State, IS extends S> {
         this.incrementalCleanup = getTtlIncrementalCleanup();
     }
 
-    private Map<Class<? extends StateDescriptor>, SupplierWithException<IS, 
Exception>>
-            createStateFactories() {
+    private Map<StateDescriptor.Type, SupplierWithException<IS, Exception>> 
createStateFactories() {
         return Stream.of(
                         Tuple2.of(
-                                ValueStateDescriptor.class,
+                                StateDescriptor.Type.VALUE,
                                 (SupplierWithException<IS, Exception>) 
this::createValueState),
                         Tuple2.of(
-                                ListStateDescriptor.class,
+                                StateDescriptor.Type.LIST,
                                 (SupplierWithException<IS, Exception>) 
this::createListState),
                         Tuple2.of(
-                                MapStateDescriptor.class,
+                                StateDescriptor.Type.MAP,
                                 (SupplierWithException<IS, Exception>) 
this::createMapState),
                         Tuple2.of(
-                                ReducingStateDescriptor.class,
+                                StateDescriptor.Type.REDUCING,
                                 (SupplierWithException<IS, Exception>) 
this::createReducingState),
                         Tuple2.of(
-                                AggregatingStateDescriptor.class,
+                                StateDescriptor.Type.AGGREGATING,
                                 (SupplierWithException<IS, Exception>)
                                         this::createAggregatingState))
                 .collect(Collectors.toMap(t -> t.f0, t -> t.f1));
@@ -121,13 +119,12 @@ public class TtlStateFactory<K, N, SV, TTLSV, S extends 
State, IS extends S> {
 
     @SuppressWarnings("unchecked")
     private IS createState() throws Exception {
-        SupplierWithException<IS, Exception> stateFactory =
-                stateFactories.get(stateDesc.getClass());
+        SupplierWithException<IS, Exception> stateFactory = 
stateFactories.get(stateDesc.getType());
         if (stateFactory == null) {
             String message =
                     String.format(
-                            "State %s is not supported by %s",
-                            stateDesc.getClass(), TtlStateFactory.class);
+                            "State type: %s is not supported by %s",
+                            stateDesc.getType(), TtlStateFactory.class);
             throw new FlinkRuntimeException(message);
         }
         IS state = stateFactory.get();
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index 5ead66a..2f1aab4 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -19,13 +19,10 @@
 package org.apache.flink.runtime.state.ttl.mock;
 
 import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.state.AggregatingStateDescriptor;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapStateDescriptor;
-import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.CloseableRegistry;
@@ -75,22 +72,22 @@ public class MockKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                 throws Exception;
     }
 
-    private static final Map<Class<? extends StateDescriptor>, StateFactory> 
STATE_FACTORIES =
+    private static final Map<StateDescriptor.Type, StateFactory> 
STATE_FACTORIES =
             Stream.of(
                             Tuple2.of(
-                                    ValueStateDescriptor.class,
+                                    StateDescriptor.Type.VALUE,
                                     (StateFactory) 
MockInternalValueState::createState),
                             Tuple2.of(
-                                    ListStateDescriptor.class,
+                                    StateDescriptor.Type.LIST,
                                     (StateFactory) 
MockInternalListState::createState),
                             Tuple2.of(
-                                    MapStateDescriptor.class,
+                                    StateDescriptor.Type.MAP,
                                     (StateFactory) 
MockInternalMapState::createState),
                             Tuple2.of(
-                                    ReducingStateDescriptor.class,
+                                    StateDescriptor.Type.REDUCING,
                                     (StateFactory) 
MockInternalReducingState::createState),
                             Tuple2.of(
-                                    AggregatingStateDescriptor.class,
+                                    StateDescriptor.Type.AGGREGATING,
                                     (StateFactory) 
MockInternalAggregatingState::createState))
                     .collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
@@ -130,7 +127,7 @@ public class MockKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
             @Nonnull StateDescriptor<S, SV> stateDesc,
             @Nonnull StateSnapshotTransformFactory<SEV> 
snapshotTransformFactory)
             throws Exception {
-        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
+        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getType());
         if (stateFactory == null) {
             String message =
                     String.format(
diff --git 
a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
 
b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
index e234e97..abe7568 100644
--- 
a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
+++ 
b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
@@ -21,14 +21,9 @@ package org.apache.flink.state.changelog;
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.state.AggregatingStateDescriptor;
 import org.apache.flink.api.common.state.CheckpointListener;
-import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.MapStateDescriptor;
-import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
@@ -80,22 +75,22 @@ class ChangelogKeyedStateBackend<K>
                 CheckpointListener,
                 TestableKeyedStateBackend<K> {
 
-    private static final Map<Class<? extends StateDescriptor>, StateFactory> 
STATE_FACTORIES =
+    private static final Map<StateDescriptor.Type, StateFactory> 
STATE_FACTORIES =
             Stream.of(
                             Tuple2.of(
-                                    ValueStateDescriptor.class,
+                                    StateDescriptor.Type.VALUE,
                                     (StateFactory) 
ChangelogValueState::create),
                             Tuple2.of(
-                                    ListStateDescriptor.class,
+                                    StateDescriptor.Type.LIST,
                                     (StateFactory) ChangelogListState::create),
                             Tuple2.of(
-                                    ReducingStateDescriptor.class,
+                                    StateDescriptor.Type.REDUCING,
                                     (StateFactory) 
ChangelogReducingState::create),
                             Tuple2.of(
-                                    AggregatingStateDescriptor.class,
+                                    StateDescriptor.Type.AGGREGATING,
                                     (StateFactory) 
ChangelogAggregatingState::create),
                             Tuple2.of(
-                                    MapStateDescriptor.class,
+                                    StateDescriptor.Type.MAP,
                                     (StateFactory) ChangelogMapState::create))
                     .collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
@@ -335,7 +330,7 @@ class ChangelogKeyedStateBackend<K>
                     StateSnapshotTransformer.StateSnapshotTransformFactory<SEV>
                             snapshotTransformFactory)
             throws Exception {
-        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
+        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getType());
         if (stateFactory == null) {
             String message =
                     String.format(
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 5ec7f41..ea83250 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -19,13 +19,8 @@ package org.apache.flink.contrib.streaming.state;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.state.AggregatingStateDescriptor;
-import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.MapStateDescriptor;
-import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
@@ -125,23 +120,22 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
      */
     public static final String MERGE_OPERATOR_NAME = "stringappendtest";
 
-    @SuppressWarnings("deprecation")
-    private static final Map<Class<? extends StateDescriptor>, StateFactory> 
STATE_FACTORIES =
+    private static final Map<StateDescriptor.Type, StateFactory> 
STATE_FACTORIES =
             Stream.of(
                             Tuple2.of(
-                                    ValueStateDescriptor.class,
+                                    StateDescriptor.Type.VALUE,
                                     (StateFactory) RocksDBValueState::create),
                             Tuple2.of(
-                                    ListStateDescriptor.class,
+                                    StateDescriptor.Type.LIST,
                                     (StateFactory) RocksDBListState::create),
                             Tuple2.of(
-                                    MapStateDescriptor.class,
+                                    StateDescriptor.Type.MAP,
                                     (StateFactory) RocksDBMapState::create),
                             Tuple2.of(
-                                    AggregatingStateDescriptor.class,
+                                    StateDescriptor.Type.AGGREGATING,
                                     (StateFactory) 
RocksDBAggregatingState::create),
                             Tuple2.of(
-                                    ReducingStateDescriptor.class,
+                                    StateDescriptor.Type.REDUCING,
                                     (StateFactory) 
RocksDBReducingState::create))
                     .collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
@@ -760,7 +754,7 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
         // we need to get an actual state instance because migration is 
different
         // for different state types. For example, ListState needs to deal with
         // individual elements
-        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
+        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getType());
         if (stateFactory == null) {
             String message =
                     String.format(
@@ -831,7 +825,7 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
             @Nonnull StateDescriptor<S, SV> stateDesc,
             @Nonnull StateSnapshotTransformFactory<SEV> 
snapshotTransformFactory)
             throws Exception {
-        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
+        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getType());
         if (stateFactory == null) {
             String message =
                     String.format(
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionKeyedStateBackend.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionKeyedStateBackend.java
index f5a8fab..a458d42 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionKeyedStateBackend.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/sorted/state/BatchExecutionKeyedStateBackend.java
@@ -19,13 +19,8 @@
 package org.apache.flink.streaming.api.operators.sorted.state;
 
 import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.state.AggregatingStateDescriptor;
-import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.MapStateDescriptor;
-import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
@@ -72,23 +67,22 @@ public class BatchExecutionKeyedStateBackend<K> implements 
CheckpointableKeyedSt
     private static final Logger LOG =
             LoggerFactory.getLogger(BatchExecutionKeyedStateBackend.class);
 
-    @SuppressWarnings("rawtypes")
-    private static final Map<Class<? extends StateDescriptor>, StateFactory> 
STATE_FACTORIES =
+    private static final Map<StateDescriptor.Type, StateFactory> 
STATE_FACTORIES =
             Stream.of(
                             Tuple2.of(
-                                    ValueStateDescriptor.class,
+                                    StateDescriptor.Type.VALUE,
                                     (StateFactory) 
BatchExecutionKeyValueState::create),
                             Tuple2.of(
-                                    ListStateDescriptor.class,
+                                    StateDescriptor.Type.LIST,
                                     (StateFactory) 
BatchExecutionKeyListState::create),
                             Tuple2.of(
-                                    MapStateDescriptor.class,
+                                    StateDescriptor.Type.MAP,
                                     (StateFactory) 
BatchExecutionKeyMapState::create),
                             Tuple2.of(
-                                    AggregatingStateDescriptor.class,
+                                    StateDescriptor.Type.AGGREGATING,
                                     (StateFactory) 
BatchExecutionKeyAggregatingState::create),
                             Tuple2.of(
-                                    ReducingStateDescriptor.class,
+                                    StateDescriptor.Type.REDUCING,
                                     (StateFactory) 
BatchExecutionKeyReducingState::create))
                     .collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
@@ -232,7 +226,7 @@ public class BatchExecutionKeyedStateBackend<K> implements 
CheckpointableKeyedSt
             @Nonnull TypeSerializer<N> namespaceSerializer,
             @Nonnull StateDescriptor<S, SV> stateDesc)
             throws Exception {
-        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
+        StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getType());
         if (stateFactory == null) {
             String message =
                     String.format(

Reply via email to