This is an automated email from the ASF dual-hosted git repository.

leiyanfei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new b0ad87a76f5 [FLINK-37524][state/forst] Make ForStState Restore from previous StateMetaInfo (#26391)
b0ad87a76f5 is described below

commit b0ad87a76f5ece023a97ce8c2b86c7c9e1a47f7b
Author: mayuehappy <mayue.fi...@bytedance.com>
AuthorDate: Mon Apr 7 14:22:00 2025 +0800

    [FLINK-37524][state/forst] Make ForStState Restore from previous StateMetaInfo (#26391)
---
 .../runtime/state/v2/AbstractAggregatingState.java |  11 +-
 .../flink/runtime/state/v2/AbstractKeyedState.java |  14 +-
 .../flink/runtime/state/v2/AbstractListState.java  |   6 +-
 .../flink/runtime/state/v2/AbstractMapState.java   |   7 +-
 .../runtime/state/v2/AbstractReducingState.java    |  10 +-
 .../flink/runtime/state/v2/AbstractValueState.java |   6 +-
 .../AsyncExecutionControllerTest.java              |   2 +-
 .../state/v2/AbstractAggregatingStateTest.java     |  15 +-
 .../runtime/state/v2/AbstractListStateTest.java    |   4 +-
 .../runtime/state/v2/AbstractMapStateTest.java     |   4 +-
 .../state/v2/AbstractReducingStateTest.java        |   9 +-
 .../runtime/state/v2/AbstractValueStateTest.java   |   4 +-
 .../flink/state/forst/ForStAggregatingState.java   |   7 +-
 .../flink/state/forst/ForStKeyedStateBackend.java  |  39 ++---
 .../apache/flink/state/forst/ForStListState.java   |   5 +-
 .../apache/flink/state/forst/ForStMapState.java    |  17 ++-
 .../flink/state/forst/ForStReducingState.java      |   7 +-
 .../apache/flink/state/forst/ForStValueState.java  |   5 +-
 .../state/forst/ForStDBOperationTestBase.java      |  10 +-
 .../flink/state/forst/ForStStateMigrationTest.java | 166 +++++++++++++++++++++
 .../flink/state/forst/ForStStateTestBase.java      |   2 +-
 21 files changed, 268 insertions(+), 82 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractAggregatingState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractAggregatingState.java
index 556d63dad5e..f59677aeac1 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractAggregatingState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractAggregatingState.java
@@ -20,8 +20,8 @@ package org.apache.flink.runtime.state.v2;
 
 import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.state.v2.AggregatingState;
-import org.apache.flink.api.common.state.v2.AggregatingStateDescriptor;
 import org.apache.flink.api.common.state.v2.StateFuture;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.state.StateFutureUtils;
 import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
 import org.apache.flink.runtime.asyncprocessing.StateRequestType;
@@ -50,13 +50,14 @@ public class AbstractAggregatingState<K, N, IN, ACC, OUT> 
extends AbstractKeyedS
      * Creates a new AbstractKeyedState with the given 
asyncExecutionController and stateDescriptor.
      *
      * @param stateRequestHandler The async request handler for handling all 
requests.
-     * @param stateDescriptor The properties of the state.
+     * @param valueSerializer The type serializer for the values in the state.
      */
     public AbstractAggregatingState(
             StateRequestHandler stateRequestHandler,
-            AggregatingStateDescriptor<IN, ACC, OUT> stateDescriptor) {
-        super(stateRequestHandler, stateDescriptor);
-        this.aggregateFunction = stateDescriptor.getAggregateFunction();
+            AggregateFunction<IN, ACC, OUT> aggregateFunction,
+            TypeSerializer<ACC> valueSerializer) {
+        super(stateRequestHandler, valueSerializer);
+        this.aggregateFunction = aggregateFunction;
     }
 
     @Override
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractKeyedState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractKeyedState.java
index 6c27130c9c2..2dbb8815608 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractKeyedState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractKeyedState.java
@@ -19,7 +19,6 @@ package org.apache.flink.runtime.state.v2;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.state.v2.State;
-import org.apache.flink.api.common.state.v2.StateDescriptor;
 import org.apache.flink.api.common.state.v2.StateFuture;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController;
@@ -42,19 +41,15 @@ import 
org.apache.flink.runtime.state.v2.internal.InternalKeyedState;
 public abstract class AbstractKeyedState<K, N, V> implements 
InternalKeyedState<K, N, V> {
 
     protected final StateRequestHandler stateRequestHandler;
-
-    private final StateDescriptor<V> stateDescriptor;
-
     private final ThreadLocal<TypeSerializer<V>> valueSerializer;
 
     /**
      * Creates a new AbstractKeyedState with the given 
asyncExecutionController and stateDescriptor.
      */
     public AbstractKeyedState(
-            StateRequestHandler stateRequestHandler, StateDescriptor<V> 
stateDescriptor) {
+            StateRequestHandler stateRequestHandler, TypeSerializer<V> 
valueSerializer) {
         this.stateRequestHandler = stateRequestHandler;
-        this.stateDescriptor = stateDescriptor;
-        this.valueSerializer = 
ThreadLocal.withInitial(stateDescriptor::getSerializer);
+        this.valueSerializer = 
ThreadLocal.withInitial(valueSerializer::duplicate);
     }
 
     /**
@@ -87,11 +82,6 @@ public abstract class AbstractKeyedState<K, N, V> implements 
InternalKeyedState<
         handleRequestSync(StateRequestType.CLEAR, null);
     }
 
-    /** Return specific {@code StateDescriptor}. */
-    public StateDescriptor<V> getStateDescriptor() {
-        return stateDescriptor;
-    }
-
     /** Return related value serializer. */
     public TypeSerializer<V> getValueSerializer() {
         return valueSerializer.get();
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractListState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractListState.java
index 263de23f31a..dcc238fdeda 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractListState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractListState.java
@@ -18,9 +18,9 @@
 package org.apache.flink.runtime.state.v2;
 
 import org.apache.flink.api.common.state.v2.ListState;
-import org.apache.flink.api.common.state.v2.ListStateDescriptor;
 import org.apache.flink.api.common.state.v2.StateFuture;
 import org.apache.flink.api.common.state.v2.StateIterator;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
 import org.apache.flink.runtime.asyncprocessing.StateRequestType;
 import org.apache.flink.runtime.state.v2.internal.InternalListState;
@@ -39,8 +39,8 @@ public class AbstractListState<K, N, V> extends 
AbstractKeyedState<K, N, V>
         implements InternalListState<K, N, V> {
 
     public AbstractListState(
-            StateRequestHandler stateRequestHandler, ListStateDescriptor<V> 
stateDescriptor) {
-        super(stateRequestHandler, stateDescriptor);
+            StateRequestHandler stateRequestHandler, TypeSerializer<V> 
serializer) {
+        super(stateRequestHandler, serializer);
     }
 
     @Override
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractMapState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractMapState.java
index fb3ad546aed..f8c1b458294 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractMapState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractMapState.java
@@ -18,9 +18,9 @@
 package org.apache.flink.runtime.state.v2;
 
 import org.apache.flink.api.common.state.v2.MapState;
-import org.apache.flink.api.common.state.v2.MapStateDescriptor;
 import org.apache.flink.api.common.state.v2.StateFuture;
 import org.apache.flink.api.common.state.v2.StateIterator;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
 import org.apache.flink.runtime.asyncprocessing.StateRequestType;
@@ -40,9 +40,8 @@ import java.util.Map;
 public class AbstractMapState<K, N, UK, V> extends AbstractKeyedState<K, N, V>
         implements InternalMapState<K, N, UK, V> {
 
-    public AbstractMapState(
-            StateRequestHandler stateRequestHandler, MapStateDescriptor<UK, V> 
stateDescriptor) {
-        super(stateRequestHandler, stateDescriptor);
+    public AbstractMapState(StateRequestHandler stateRequestHandler, 
TypeSerializer<V> serializer) {
+        super(stateRequestHandler, serializer);
     }
 
     @Override
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractReducingState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractReducingState.java
index 45eb7f841b2..4e269d94bc1 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractReducingState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractReducingState.java
@@ -19,8 +19,8 @@ package org.apache.flink.runtime.state.v2;
 
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.state.v2.ReducingState;
-import org.apache.flink.api.common.state.v2.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.v2.StateFuture;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.state.StateFutureUtils;
 import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
 import org.apache.flink.runtime.asyncprocessing.StateRequestType;
@@ -44,9 +44,11 @@ public class AbstractReducingState<K, N, V> extends 
AbstractKeyedState<K, N, V>
     protected final ReduceFunction<V> reduceFunction;
 
     public AbstractReducingState(
-            StateRequestHandler stateRequestHandler, 
ReducingStateDescriptor<V> stateDescriptor) {
-        super(stateRequestHandler, stateDescriptor);
-        this.reduceFunction = stateDescriptor.getReduceFunction();
+            StateRequestHandler stateRequestHandler,
+            ReduceFunction<V> reduceFunction,
+            TypeSerializer<V> valueSerializer) {
+        super(stateRequestHandler, valueSerializer);
+        this.reduceFunction = reduceFunction;
     }
 
     @Override
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractValueState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractValueState.java
index eda66c9cea2..558644cd871 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractValueState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/AbstractValueState.java
@@ -19,7 +19,7 @@ package org.apache.flink.runtime.state.v2;
 
 import org.apache.flink.api.common.state.v2.StateFuture;
 import org.apache.flink.api.common.state.v2.ValueState;
-import org.apache.flink.api.common.state.v2.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController;
 import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
 import org.apache.flink.runtime.asyncprocessing.StateRequestType;
@@ -36,8 +36,8 @@ public class AbstractValueState<K, N, V> extends 
AbstractKeyedState<K, N, V>
         implements InternalValueState<K, N, V> {
 
     public AbstractValueState(
-            StateRequestHandler stateRequestHandler, ValueStateDescriptor<V> 
valueStateDescriptor) {
-        super(stateRequestHandler, valueStateDescriptor);
+            StateRequestHandler stateRequestHandler, TypeSerializer<V> 
valueSerializer) {
+        super(stateRequestHandler, valueSerializer);
     }
 
     @Override
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java
index e77e3c522de..a9b4200fbb5 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java
@@ -866,7 +866,7 @@ class AsyncExecutionControllerTest {
                 StateRequestHandler stateRequestHandler,
                 TestUnderlyingState underlyingState,
                 ValueStateDescriptor<Integer> stateDescriptor) {
-            super(stateRequestHandler, stateDescriptor);
+            super(stateRequestHandler, stateDescriptor.getSerializer());
             this.underlyingState = underlyingState;
             
assertThat(this.getValueSerializer()).isEqualTo(IntSerializer.INSTANCE);
         }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractAggregatingStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractAggregatingStateTest.java
index 3be5ed9f418..52b1e228b82 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractAggregatingStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractAggregatingStateTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state.v2;
 
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.state.v2.AggregatingStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
@@ -79,8 +80,12 @@ class AbstractAggregatingStateTest extends 
AbstractKeyedStateTestBase {
         AggregatingStateDescriptor<Integer, Integer, Integer> descriptor =
                 new AggregatingStateDescriptor<>(
                         "testAggState", aggregator, 
BasicTypeInfo.INT_TYPE_INFO);
+        descriptor.initializeSerializerUnlessSet(new ExecutionConfig());
         AbstractAggregatingState<String, Void, Integer, Integer, Integer> 
state =
-                new AbstractAggregatingState<>(aec, descriptor);
+                new AbstractAggregatingState<>(
+                        aec,
+                        descriptor.getAggregateFunction(),
+                        descriptor.getSerializer().duplicate());
 
         aec.setCurrentContext(aec.buildContext("test", "test"));
 
@@ -107,6 +112,7 @@ class AbstractAggregatingStateTest extends 
AbstractKeyedStateTestBase {
         AggregatingStateDescriptor<Integer, Integer, Integer> descriptor =
                 new AggregatingStateDescriptor<>(
                         "testState", aggregator, BasicTypeInfo.INT_TYPE_INFO);
+        descriptor.initializeSerializerUnlessSet(new ExecutionConfig());
         AggregatingStateExecutor aggregatingStateExecutor = new 
AggregatingStateExecutor();
         AsyncExecutionController<String> aec =
                 new AsyncExecutionController<>(
@@ -121,7 +127,8 @@ class AbstractAggregatingStateTest extends 
AbstractKeyedStateTestBase {
                         null,
                         null);
         AbstractAggregatingState<String, String, Integer, Integer, Integer> 
aggregatingState =
-                new AbstractAggregatingState<>(aec, descriptor);
+                new AbstractAggregatingState<>(
+                        aec, descriptor.getAggregateFunction(), 
descriptor.getSerializer());
         aec.setCurrentContext(aec.buildContext("test", "test"));
         aec.setCurrentNamespaceForState(aggregatingState, "1");
         aggregatingState.add(1);
@@ -143,6 +150,7 @@ class AbstractAggregatingStateTest extends 
AbstractKeyedStateTestBase {
         AggregatingStateDescriptor<Integer, Integer, Integer> descriptor =
                 new AggregatingStateDescriptor<>(
                         "testState", aggregator, BasicTypeInfo.INT_TYPE_INFO);
+        descriptor.initializeSerializerUnlessSet(new ExecutionConfig());
         AggregatingStateExecutor aggregatingStateExecutor = new 
AggregatingStateExecutor();
         AsyncExecutionController<String> aec =
                 new AsyncExecutionController<>(
@@ -157,7 +165,8 @@ class AbstractAggregatingStateTest extends 
AbstractKeyedStateTestBase {
                         null,
                         null);
         AbstractAggregatingState<String, String, Integer, Integer, Integer> 
aggregatingState =
-                new AbstractAggregatingState<>(aec, descriptor);
+                new AbstractAggregatingState<>(
+                        aec, descriptor.getAggregateFunction(), 
descriptor.getSerializer());
         aec.setCurrentContext(aec.buildContext("test", "test"));
         aec.setCurrentNamespaceForState(aggregatingState, "1");
         aggregatingState.asyncAdd(1);
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractListStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractListStateTest.java
index 456b23dfd3b..060d317305d 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractListStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractListStateTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state.v2;
 
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.v2.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.runtime.asyncprocessing.StateRequestType;
@@ -35,8 +36,9 @@ public class AbstractListStateTest extends 
AbstractKeyedStateTestBase {
     public void testEachOperation() {
         ListStateDescriptor<Integer> descriptor =
                 new ListStateDescriptor<>("testState", 
BasicTypeInfo.INT_TYPE_INFO);
+        descriptor.initializeSerializerUnlessSet(new ExecutionConfig());
         AbstractListState<String, Void, Integer> listState =
-                new AbstractListState<>(aec, descriptor);
+                new AbstractListState<>(aec, descriptor.getSerializer());
         aec.setCurrentContext(aec.buildContext("test", "test"));
 
         listState.asyncClear();
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractMapStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractMapStateTest.java
index 5d08c561b0c..fa9fda312ed 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractMapStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractMapStateTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state.v2;
 
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.v2.MapStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -37,8 +38,9 @@ public class AbstractMapStateTest extends 
AbstractKeyedStateTestBase {
         MapStateDescriptor<String, Integer> descriptor =
                 new MapStateDescriptor<>(
                         "testState", BasicTypeInfo.STRING_TYPE_INFO, 
BasicTypeInfo.INT_TYPE_INFO);
+        descriptor.initializeSerializerUnlessSet(new ExecutionConfig());
         AbstractMapState<String, Void, String, Integer> mapState =
-                new AbstractMapState<>(aec, descriptor);
+                new AbstractMapState<>(aec, descriptor.getSerializer());
         aec.setCurrentContext(aec.buildContext("test", "test"));
 
         mapState.asyncClear();
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractReducingStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractReducingStateTest.java
index cac5f8631b2..b60e52eb143 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractReducingStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractReducingStateTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state.v2;
 
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.state.v2.ReducingStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
@@ -51,8 +52,10 @@ public class AbstractReducingStateTest extends 
AbstractKeyedStateTestBase {
         ReduceFunction<Integer> reducer = Integer::sum;
         ReducingStateDescriptor<Integer> descriptor =
                 new ReducingStateDescriptor<>("testState", reducer, 
BasicTypeInfo.INT_TYPE_INFO);
+        descriptor.initializeSerializerUnlessSet(new ExecutionConfig());
         AbstractReducingState<String, Void, Integer> reducingState =
-                new AbstractReducingState<>(aec, descriptor);
+                new AbstractReducingState<>(
+                        aec, descriptor.getReduceFunction(), 
descriptor.getSerializer());
         aec.setCurrentContext(aec.buildContext("test", "test"));
 
         reducingState.asyncClear();
@@ -81,6 +84,7 @@ public class AbstractReducingStateTest extends 
AbstractKeyedStateTestBase {
         ReduceFunction<Integer> reducer = Integer::sum;
         ReducingStateDescriptor<Integer> descriptor =
                 new ReducingStateDescriptor<>("testState", reducer, 
BasicTypeInfo.INT_TYPE_INFO);
+        descriptor.initializeSerializerUnlessSet(new ExecutionConfig());
         AsyncExecutionController<String> aec =
                 new AsyncExecutionController<>(
                         new SyncMailboxExecutor(),
@@ -94,7 +98,8 @@ public class AbstractReducingStateTest extends 
AbstractKeyedStateTestBase {
                         null,
                         null);
         AbstractReducingState<String, String, Integer> reducingState =
-                new AbstractReducingState<>(aec, descriptor);
+                new AbstractReducingState<>(
+                        aec, descriptor.getReduceFunction(), 
descriptor.getSerializer());
         aec.setCurrentContext(aec.buildContext("test", "test"));
         aec.setCurrentNamespaceForState(reducingState, "1");
         reducingState.asyncAdd(1);
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractValueStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractValueStateTest.java
index 307e7063e32..f71280d4e28 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractValueStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/AbstractValueStateTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state.v2;
 
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.v2.ValueStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.runtime.asyncprocessing.StateRequestType;
@@ -32,8 +33,9 @@ public class AbstractValueStateTest extends 
AbstractKeyedStateTestBase {
     public void testEachOperation() {
         ValueStateDescriptor<Integer> descriptor =
                 new ValueStateDescriptor<>("testState", 
BasicTypeInfo.INT_TYPE_INFO);
+        descriptor.initializeSerializerUnlessSet(new ExecutionConfig());
         AbstractValueState<String, Void, Integer> valueState =
-                new AbstractValueState<>(aec, descriptor);
+                new AbstractValueState<>(aec, descriptor.getSerializer());
         aec.setCurrentContext(aec.buildContext("test", "test"));
 
         valueState.asyncClear();
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStAggregatingState.java
 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStAggregatingState.java
index ba05170bf20..63ade85b437 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStAggregatingState.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStAggregatingState.java
@@ -18,8 +18,8 @@
 
 package org.apache.flink.state.forst;
 
+import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.state.v2.AggregatingState;
-import org.apache.flink.api.common.state.v2.AggregatingStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.memory.DataInputDeserializer;
 import org.apache.flink.core.memory.DataOutputSerializer;
@@ -77,7 +77,8 @@ public class ForStAggregatingState<K, N, IN, ACC, OUT>
      * @param stateDescriptor     The properties of the state.
      */
     public ForStAggregatingState(
-            AggregatingStateDescriptor<IN, ACC, OUT> stateDescriptor,
+            AggregateFunction<IN, ACC, OUT> aggregateFunction,
+            TypeSerializer<ACC> valueSerializer,
             StateRequestHandler stateRequestHandler,
             ColumnFamilyHandle columnFamily,
             Supplier<SerializedCompositeKeyBuilder<K>> 
serializedKeyBuilderInitializer,
@@ -85,7 +86,7 @@ public class ForStAggregatingState<K, N, IN, ACC, OUT>
             Supplier<TypeSerializer<N>> namespaceSerializerInitializer,
             Supplier<DataOutputSerializer> valueSerializerViewInitializer,
             Supplier<DataInputDeserializer> valueDeserializerViewInitializer) {
-        super(stateRequestHandler, stateDescriptor);
+        super(stateRequestHandler, aggregateFunction, valueSerializer);
         this.columnFamilyHandle = columnFamily;
         this.serializedKeyBuilder = 
ThreadLocal.withInitial(serializedKeyBuilderInitializer);
         this.namespaceSerializer = 
ThreadLocal.withInitial(namespaceSerializerInitializer);
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
index 1d78c5bd1a5..65eb52907d4 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
@@ -20,12 +20,10 @@ package org.apache.flink.state.forst;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.v2.AggregatingStateDescriptor;
-import org.apache.flink.api.common.state.v2.ListStateDescriptor;
 import org.apache.flink.api.common.state.v2.MapStateDescriptor;
 import org.apache.flink.api.common.state.v2.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.v2.State;
 import org.apache.flink.api.common.state.v2.StateDescriptor;
-import org.apache.flink.api.common.state.v2.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -300,7 +298,7 @@ public class ForStKeyedStateBackend<K> implements 
AsyncKeyedStateBackend<K> {
                         new ForStValueState<>(
                                 stateRequestHandler,
                                 columnFamilyHandle,
-                                (ValueStateDescriptor<SV>) stateDesc,
+                                registerResult.f1.getStateSerializer(),
                                 serializedKeyBuilder,
                                 defaultNamespace,
                                 namespaceSerializer::duplicate,
@@ -312,7 +310,7 @@ public class ForStKeyedStateBackend<K> implements 
AsyncKeyedStateBackend<K> {
                         new ForStListState<>(
                                 stateRequestHandler,
                                 columnFamilyHandle,
-                                (ListStateDescriptor<SV>) stateDesc,
+                                registerResult.f1.getStateSerializer(),
                                 serializedKeyBuilder,
                                 defaultNamespace,
                                 namespaceSerializer::duplicate,
@@ -320,23 +318,28 @@ public class ForStKeyedStateBackend<K> implements 
AsyncKeyedStateBackend<K> {
                                 valueDeserializerView);
             case MAP:
                 Supplier<DataInputDeserializer> keyDeserializerView = 
DataInputDeserializer::new;
-                return ForStMapState.create(
-                        stateDesc,
-                        stateRequestHandler,
-                        columnFamilyHandle,
-                        serializedKeyBuilder,
-                        defaultNamespace,
-                        namespaceSerializer::duplicate,
-                        valueSerializerView,
-                        keyDeserializerView,
-                        valueDeserializerView,
-                        keyGroupPrefixBytes);
+                RegisteredKeyAndUserKeyValueStateBackendMetaInfo 
mapStateMetaInfo =
+                        (RegisteredKeyAndUserKeyValueStateBackendMetaInfo) 
registerResult.f1;
+                return (S)
+                        ForStMapState.create(
+                                mapStateMetaInfo.getUserKeySerializer(),
+                                mapStateMetaInfo.getStateSerializer(),
+                                stateRequestHandler,
+                                columnFamilyHandle,
+                                serializedKeyBuilder,
+                                defaultNamespace,
+                                namespaceSerializer::duplicate,
+                                valueSerializerView,
+                                keyDeserializerView,
+                                valueDeserializerView,
+                                keyGroupPrefixBytes);
             case REDUCING:
                 return (S)
                         new ForStReducingState<>(
                                 stateRequestHandler,
                                 columnFamilyHandle,
-                                (ReducingStateDescriptor<SV>) stateDesc,
+                                ((ReducingStateDescriptor<SV>) 
stateDesc).getReduceFunction(),
+                                registerResult.f1.getStateSerializer(),
                                 serializedKeyBuilder,
                                 defaultNamespace,
                                 namespaceSerializer::duplicate,
@@ -345,7 +348,9 @@ public class ForStKeyedStateBackend<K> implements 
AsyncKeyedStateBackend<K> {
             case AGGREGATING:
                 return (S)
                         new ForStAggregatingState<>(
-                                (AggregatingStateDescriptor<?, SV, ?>) 
stateDesc,
+                                ((AggregatingStateDescriptor<?, SV, ?>) 
stateDesc)
+                                        .getAggregateFunction(),
+                                registerResult.f1.getStateSerializer(),
                                 stateRequestHandler,
                                 columnFamilyHandle,
                                 serializedKeyBuilder,
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStListState.java
 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStListState.java
index e906ad0ffb4..dad01fefb5a 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStListState.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStListState.java
@@ -19,7 +19,6 @@
 package org.apache.flink.state.forst;
 
 import org.apache.flink.api.common.state.v2.ListState;
-import org.apache.flink.api.common.state.v2.ListStateDescriptor;
 import org.apache.flink.api.common.state.v2.StateFuture;
 import org.apache.flink.api.common.state.v2.StateIterator;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -83,13 +82,13 @@ public class ForStListState<K, N, V> extends 
AbstractListState<K, N, V>
     public ForStListState(
             StateRequestHandler stateRequestHandler,
             ColumnFamilyHandle columnFamily,
-            ListStateDescriptor<V> listStateDescriptor,
+            TypeSerializer<V> valueSerializer,
             Supplier<SerializedCompositeKeyBuilder<K>> 
serializedKeyBuilderInitializer,
             N defaultNamespace,
             Supplier<TypeSerializer<N>> namespaceSerializerInitializer,
             Supplier<DataOutputSerializer> valueSerializerViewInitializer,
             Supplier<DataInputDeserializer> valueDeserializerViewInitializer) {
-        super(stateRequestHandler, listStateDescriptor);
+        super(stateRequestHandler, valueSerializer);
         this.columnFamilyHandle = columnFamily;
         this.serializedKeyBuilder = 
ThreadLocal.withInitial(serializedKeyBuilderInitializer);
         this.defaultNamespace = defaultNamespace;
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStMapState.java
 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStMapState.java
index 43f22dc734e..bb9fb9a931a 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStMapState.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStMapState.java
@@ -18,9 +18,7 @@
 package org.apache.flink.state.forst;
 
 import org.apache.flink.api.common.state.v2.MapState;
-import org.apache.flink.api.common.state.v2.MapStateDescriptor;
 import org.apache.flink.api.common.state.v2.State;
-import org.apache.flink.api.common.state.v2.StateDescriptor;
 import org.apache.flink.api.common.state.v2.StateIterator;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -84,7 +82,8 @@ public class ForStMapState<K, N, UK, UV> extends 
AbstractMapState<K, N, UK, UV>
     public ForStMapState(
             StateRequestHandler stateRequestHandler,
             ColumnFamilyHandle columnFamily,
-            MapStateDescriptor<UK, UV> stateDescriptor,
+            TypeSerializer<UK> userKeySerializer,
+            TypeSerializer<UV> valueSerializer,
             Supplier<SerializedCompositeKeyBuilder<K>> 
serializedKeyBuilderInitializer,
             N defaultNamespace,
             Supplier<TypeSerializer<N>> namespaceSerializerInitializer,
@@ -92,7 +91,7 @@ public class ForStMapState<K, N, UK, UV> extends 
AbstractMapState<K, N, UK, UV>
             Supplier<DataInputDeserializer> keyDeserializerViewInitializer,
             Supplier<DataInputDeserializer> valueDeserializerViewInitializer,
             int keyGroupPrefixBytes) {
-        super(stateRequestHandler, stateDescriptor);
+        super(stateRequestHandler, valueSerializer);
         this.columnFamilyHandle = columnFamily;
         this.serializedKeyBuilder = 
ThreadLocal.withInitial(serializedKeyBuilderInitializer);
         this.defaultNamespace = defaultNamespace;
@@ -100,8 +99,8 @@ public class ForStMapState<K, N, UK, UV> extends 
AbstractMapState<K, N, UK, UV>
         this.valueSerializerView = 
ThreadLocal.withInitial(valueSerializerViewInitializer);
         this.keyDeserializerView = 
ThreadLocal.withInitial(keyDeserializerViewInitializer);
         this.valueDeserializerView = 
ThreadLocal.withInitial(valueDeserializerViewInitializer);
-        this.userKeySerializer = 
ThreadLocal.withInitial(stateDescriptor::getUserKeySerializer);
-        this.userValueSerializer = 
ThreadLocal.withInitial(stateDescriptor::getSerializer);
+        this.userKeySerializer = 
ThreadLocal.withInitial(userKeySerializer::duplicate);
+        this.userValueSerializer = 
ThreadLocal.withInitial(valueSerializer::duplicate);
         this.keyGroupPrefixBytes = keyGroupPrefixBytes;
     }
 
@@ -299,7 +298,8 @@ public class ForStMapState<K, N, UK, UV> extends 
AbstractMapState<K, N, UK, UV>
 
     @SuppressWarnings("unchecked")
     static <N, UK, UV, K, SV, S extends State> S create(
-            StateDescriptor<SV> stateDescriptor,
+            TypeSerializer<UK> userKeySerializer,
+            TypeSerializer<UV> valueSerializer,
             StateRequestHandler stateRequestHandler,
             ColumnFamilyHandle columnFamily,
             Supplier<SerializedCompositeKeyBuilder<K>> 
serializedKeyBuilderInitializer,
@@ -313,7 +313,8 @@ public class ForStMapState<K, N, UK, UV> extends 
AbstractMapState<K, N, UK, UV>
                 new ForStMapState<>(
                         stateRequestHandler,
                         columnFamily,
-                        (MapStateDescriptor<UK, UV>) stateDescriptor,
+                        userKeySerializer,
+                        valueSerializer,
                         serializedKeyBuilderInitializer,
                         defaultNamespace,
                         namespaceSerializerInitializer,
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStReducingState.java
 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStReducingState.java
index 2f32fc3577d..a091be5ffe7 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStReducingState.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStReducingState.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.state.forst;
 
-import org.apache.flink.api.common.state.v2.ReducingStateDescriptor;
+import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.memory.DataInputDeserializer;
 import org.apache.flink.core.memory.DataOutputSerializer;
@@ -72,13 +72,14 @@ public class ForStReducingState<K, N, V> extends 
AbstractReducingState<K, N, V>
     public ForStReducingState(
             StateRequestHandler stateRequestHandler,
             ColumnFamilyHandle columnFamily,
-            ReducingStateDescriptor<V> reducingStateDescriptor,
+            ReduceFunction<V> reduceFunction,
+            TypeSerializer<V> valueSerializer,
             Supplier<SerializedCompositeKeyBuilder<K>> 
serializedKeyBuilderInitializer,
             N defaultNamespace,
             Supplier<TypeSerializer<N>> namespaceSerializerInitializer,
             Supplier<DataOutputSerializer> valueSerializerViewInitializer,
             Supplier<DataInputDeserializer> valueDeserializerViewInitializer) {
-        super(stateRequestHandler, reducingStateDescriptor);
+        super(stateRequestHandler, reduceFunction, valueSerializer);
         this.columnFamilyHandle = columnFamily;
         this.serializedKeyBuilder = 
ThreadLocal.withInitial(serializedKeyBuilderInitializer);
         this.defaultNamespace = defaultNamespace;
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStValueState.java
 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStValueState.java
index 8a6a4dc0c08..807ff9fe350 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStValueState.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStValueState.java
@@ -19,7 +19,6 @@
 package org.apache.flink.state.forst;
 
 import org.apache.flink.api.common.state.v2.ValueState;
-import org.apache.flink.api.common.state.v2.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.memory.DataInputDeserializer;
 import org.apache.flink.core.memory.DataOutputSerializer;
@@ -72,13 +71,13 @@ public class ForStValueState<K, N, V> extends 
AbstractValueState<K, N, V>
     public ForStValueState(
             StateRequestHandler stateRequestHandler,
             ColumnFamilyHandle columnFamily,
-            ValueStateDescriptor<V> valueStateDescriptor,
+            TypeSerializer<V> valueSerializer,
             Supplier<SerializedCompositeKeyBuilder<K>> 
serializedKeyBuilderInitializer,
             N defaultNamespace,
             Supplier<TypeSerializer<N>> namespaceSerializerInitializer,
             Supplier<DataOutputSerializer> valueSerializerViewInitializer,
             Supplier<DataInputDeserializer> valueDeserializerViewInitializer) {
-        super(stateRequestHandler, valueStateDescriptor);
+        super(stateRequestHandler, valueSerializer);
         this.columnFamilyHandle = columnFamily;
         this.serializedKeyBuilder = 
ThreadLocal.withInitial(serializedKeyBuilderInitializer);
         this.defaultNamespace = defaultNamespace;
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStDBOperationTestBase.java
 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStDBOperationTestBase.java
index 5db0c25aa09..074d428f86a 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStDBOperationTestBase.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStDBOperationTestBase.java
@@ -136,7 +136,7 @@ public class ForStDBOperationTestBase {
         return new ForStValueState<>(
                 stateRequestHandler,
                 cf,
-                valueStateDescriptor,
+                valueStateDescriptor.getSerializer(),
                 serializedKeyBuilder,
                 VoidNamespace.INSTANCE,
                 () -> VoidNamespaceSerializer.INSTANCE,
@@ -157,7 +157,7 @@ public class ForStDBOperationTestBase {
         return new ForStListState<>(
                 buildMockStateRequestHandler(),
                 cf,
-                valueStateDescriptor,
+                valueStateDescriptor.getSerializer(),
                 serializedKeyBuilder,
                 VoidNamespace.INSTANCE,
                 () -> VoidNamespaceSerializer.INSTANCE,
@@ -200,7 +200,8 @@ public class ForStDBOperationTestBase {
                 () -> new DataInputDeserializer(new byte[128]);
 
         return new ForStAggregatingState<>(
-                valueStateDescriptor,
+                valueStateDescriptor.getAggregateFunction(),
+                valueStateDescriptor.getSerializer(),
                 buildMockStateRequestHandler(),
                 cf,
                 serializedKeyBuilder,
@@ -226,7 +227,8 @@ public class ForStDBOperationTestBase {
         return new ForStMapState<>(
                 stateRequestHandler,
                 cf,
-                mapStateDescriptor,
+                mapStateDescriptor.getUserKeySerializer(),
+                mapStateDescriptor.getSerializer(),
                 serializedKeyBuilder,
                 VoidNamespace.INSTANCE,
                 () -> VoidNamespaceSerializer.INSTANCE,
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateMigrationTest.java
 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateMigrationTest.java
index a9120a4edc6..48f431b2209 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateMigrationTest.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateMigrationTest.java
@@ -18,10 +18,16 @@
 
 package org.apache.flink.state.forst;
 
+import org.apache.flink.api.common.serialization.SerializerConfigImpl;
 import org.apache.flink.api.common.state.v2.MapState;
 import org.apache.flink.api.common.state.v2.MapStateDescriptor;
+import org.apache.flink.api.common.state.v2.ValueState;
+import org.apache.flink.api.common.state.v2.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.api.java.typeutils.GenericTypeInfo;
+import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.testutils.CommonTestUtils;
@@ -29,15 +35,20 @@ import 
org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.StateBackendTestBase;
+import org.apache.flink.runtime.state.v2.AbstractKeyedState;
 import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.StateMigrationException;
 
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
 
+import java.io.File;
 import java.util.Collections;
 import java.util.concurrent.RunnableFuture;
 
 import static 
org.apache.flink.state.forst.ForStTestUtils.createKeyedStateBackend;
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.junit.Assert.fail;
 
 /** Tests for {@link ForStListState}. */
@@ -99,4 +110,159 @@ public class ForStStateMigrationTest extends 
ForStStateTestBase {
             }
         }
     }
+
+    /**
+     * Checks that a ForSt state restored from a snapshot keeps the Kryo registration ids of the
+     * previous run, even when the restoring job registers the Kryo types in a different order.
+     *
+     * <p>The first backend registers {@code TestNestedPojoClassA} before {@code
+     * TestNestedPojoClassB}; the restored backend registers them the other way round. The value
+     * serializer of the re-created state must still map every class to its original id.
+     */
+    @Test
+    void testKryoRestoreResilienceWithDifferentRegistrationOrder(@TempDir File newTmpDir)
+            throws Exception {
+
+        // register A first then B
+        registerKryoTypes(
+                StateBackendTestBase.TestNestedPojoClassA.class,
+                StateBackendTestBase.TestNestedPojoClassB.class);
+
+        TypeInformation<StateBackendTestBase.TestPojo> pojoType =
+                new GenericTypeInfo<>(StateBackendTestBase.TestPojo.class);
+
+        // make sure that we are in fact using the KryoSerializer
+        assertThat(pojoType.createSerializer(env.getExecutionConfig().getSerializerConfig()))
+                .isInstanceOf(KryoSerializer.class);
+
+        ValueStateDescriptor<StateBackendTestBase.TestPojo> stateDescriptor =
+                new ValueStateDescriptor<>("id", pojoType);
+
+        ValueState<StateBackendTestBase.TestPojo> state =
+                keyedBackend.getOrCreateKeyedState(1, IntSerializer.INSTANCE, stateDescriptor);
+
+        // Remember the Kryo registration ids chosen by the original serializer; after restore the
+        // new serializer must have reconfigured itself to use identical mappings.
+        KryoSerializer<?> kryoSerializer = kryoValueSerializerOf(state);
+        int mainPojoClassRegistrationId =
+                kryoRegistrationId(kryoSerializer, StateBackendTestBase.TestPojo.class);
+        int nestedPojoClassARegistrationId =
+                kryoRegistrationId(
+                        kryoSerializer, StateBackendTestBase.TestNestedPojoClassA.class);
+        int nestedPojoClassBRegistrationId =
+                kryoRegistrationId(
+                        kryoSerializer, StateBackendTestBase.TestNestedPojoClassB.class);
+
+        // ============== create snapshot of current configuration ==============
+
+        // make some more modifications
+        setCurrentContext("test", "test");
+        state.asyncUpdate(
+                new StateBackendTestBase.TestPojo(
+                        "u1",
+                        1,
+                        new StateBackendTestBase.TestNestedPojoClassA(1.0, 2),
+                        new StateBackendTestBase.TestNestedPojoClassB(2.3, "foo")));
+
+        drain();
+        RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot =
+                keyedBackend.snapshot(
+                        1L,
+                        System.currentTimeMillis(),
+                        env.getCheckpointStorageAccess()
+                                .resolveCheckpointStorageLocation(
+                                        1L, CheckpointStorageLocationReference.getDefault()),
+                        CheckpointOptions.forCheckpointWithDefaultLocation());
+
+        if (!snapshot.isDone()) {
+            snapshot.run();
+        }
+        SnapshotResult<KeyedStateHandle> snapshotResult = snapshot.get();
+        KeyedStateHandle stateHandle = snapshotResult.getJobManagerOwnedSnapshot();
+        IOUtils.closeQuietly(keyedBackend);
+        keyedBackend.dispose();
+
+        // ============== restore with the opposite registration order ==============
+
+        env = getMockEnvironment(newTmpDir);
+        // this time register B first
+        registerKryoTypes(
+                StateBackendTestBase.TestNestedPojoClassB.class,
+                StateBackendTestBase.TestNestedPojoClassA.class);
+
+        FileSystem.initialize(new Configuration(), null);
+        Configuration configuration = new Configuration();
+        ForStStateBackend forStStateBackend =
+                new ForStStateBackend().configure(configuration, null);
+        ForStKeyedStateBackend<String> restoredKeyedStateBackend =
+                createKeyedStateBackend(
+                        forStStateBackend,
+                        env,
+                        StringSerializer.INSTANCE,
+                        Collections.singletonList(stateHandle));
+        // The restored backend owns a native ForSt instance; release it even if an assertion
+        // below fails (mirrors the close/dispose of the first backend above).
+        try {
+            restoredKeyedStateBackend.setup(aec);
+
+            // re-initialize to ensure that we create the KryoSerializer from scratch, otherwise
+            // initializeSerializerUnlessSet would not pick up our new config
+            stateDescriptor = new ValueStateDescriptor<>("id", pojoType);
+            state =
+                    restoredKeyedStateBackend.getOrCreateKeyedState(
+                            1, IntSerializer.INSTANCE, stateDescriptor);
+
+            // the restored serializer must use the same registration ids as the previous run
+            kryoSerializer = kryoValueSerializerOf(state);
+            assertThat(kryoRegistrationId(kryoSerializer, StateBackendTestBase.TestPojo.class))
+                    .isEqualTo(mainPojoClassRegistrationId);
+            assertThat(
+                            kryoRegistrationId(
+                                    kryoSerializer,
+                                    StateBackendTestBase.TestNestedPojoClassA.class))
+                    .isEqualTo(nestedPojoClassARegistrationId);
+            assertThat(
+                            kryoRegistrationId(
+                                    kryoSerializer,
+                                    StateBackendTestBase.TestNestedPojoClassB.class))
+                    .isEqualTo(nestedPojoClassBRegistrationId);
+
+            setCurrentContext("test", "test");
+
+            // update to test state backends that eagerly serialize, such as RocksDB
+            // NOTE(review): unlike the first update, this one is not followed by drain() before
+            // the snapshot below, so it may still be buffered in the AEC when the snapshot is
+            // taken; confirm whether it is expected to actually reach ForSt.
+            state.asyncUpdate(
+                    new StateBackendTestBase.TestPojo(
+                            "u1",
+                            11,
+                            new StateBackendTestBase.TestNestedPojoClassA(22.1, 12),
+                            new StateBackendTestBase.TestNestedPojoClassB(1.23, "foobar")));
+
+            // checkpoint 1 was restored from, so the next checkpoint taken uses id 2
+            RunnableFuture<SnapshotResult<KeyedStateHandle>> restoredSnapshot =
+                    restoredKeyedStateBackend.snapshot(
+                            2L,
+                            System.currentTimeMillis(),
+                            env.getCheckpointStorageAccess()
+                                    .resolveCheckpointStorageLocation(
+                                            2L, CheckpointStorageLocationReference.getDefault()),
+                            CheckpointOptions.forCheckpointWithDefaultLocation());
+
+            if (!restoredSnapshot.isDone()) {
+                restoredSnapshot.run();
+            }
+            restoredSnapshot.get().discardState();
+        } finally {
+            IOUtils.closeQuietly(restoredKeyedStateBackend);
+            restoredKeyedStateBackend.dispose();
+        }
+    }
+
+    /** Registers the given classes with Kryo, in order, on the current environment's config. */
+    private void registerKryoTypes(Class<?>... types) {
+        for (Class<?> type : types) {
+            ((SerializerConfigImpl) env.getExecutionConfig().getSerializerConfig())
+                    .registerKryoType(type);
+        }
+    }
+
+    /** Returns the value serializer of {@code state}, which is expected to be Kryo-based. */
+    private static KryoSerializer<?> kryoValueSerializerOf(ValueState<?> state) {
+        return (KryoSerializer<?>) ((AbstractKeyedState<?, ?, ?>) state).getValueSerializer();
+    }
+
+    /** Returns the id under which {@code type} is registered in the serializer's Kryo instance. */
+    private static int kryoRegistrationId(KryoSerializer<?> serializer, Class<?> type) {
+        return serializer.getKryo().getRegistration(type).getId();
+    }
 }
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateTestBase.java
 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateTestBase.java
index 8a385b027f4..05ae5837306 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateTestBase.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateTestBase.java
@@ -106,7 +106,7 @@ public class ForStStateTestBase {
         aec.drainInflightRecords(0);
     }
 
-    private static MockEnvironment getMockEnvironment(File tempDir) throws 
IOException {
+    protected static MockEnvironment getMockEnvironment(File tempDir) throws 
IOException {
         MockEnvironment env =
                 MockEnvironment.builder()
                         
.setUserCodeClassLoader(ForStStateBackendConfigTest.class.getClassLoader())

Reply via email to