This is an automated email from the ASF dual-hosted git repository.

zakelly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 41379fbdbde17e9fd85979bfbaace05a94db27c9
Author: Xiangyu Feng <[email protected]>
AuthorDate: Wed Feb 5 21:28:46 2025 +0800

    [hotfix] Improve TtlAwareSerializer/TtlAwareSerializerSnapshot/TtlAwareSerializerSnapshotWrapper
    
    - Introduce TtlAwareListSerializer/TtlAwareMapSerializer in TtlAwareSerializer
    - Use the original serializer type as a new generic parameter in TtlAwareSerializer
    - Add TTL migration check and serializer snapshot unit tests in TtlAwareSerializerTest
    - Add serializer restore unit tests in TtlAwareSerializerSnapshotWrapperTest
---
 .../runtime/state/ttl/TtlAwareSerializer.java      | 148 ++++++++++++++-------
 .../state/ttl/TtlAwareSerializerSnapshot.java      |  29 ++++
 .../ttl/TtlAwareSerializerSnapshotWrapper.java     |  76 ++---------
 .../ttl/TtlAwareSerializerSnapshotWrapperTest.java |  81 +++++++----
 .../runtime/state/ttl/TtlAwareSerializerTest.java  | 110 +++++++++------
 .../state/ttl/TtlAwareSerializerUpgradeTest.java   |   4 +-
 .../TypeSerializerTestCoverageTest.java            |   4 +
 7 files changed, 268 insertions(+), 184 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializer.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializer.java
index 0aff73eb97e..2f50ba6c668 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializer.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializer.java
@@ -27,22 +27,42 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.util.function.SupplierWithException;
 
 import java.io.IOException;
+import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
+
 /**
  * This class wraps a {@link TypeSerializer} with ttl awareness. It will 
return true when the
  * wrapped {@link TypeSerializer} is instance of {@link 
TtlStateFactory.TtlSerializer}. Also, it
  * wraps the value migration process between TtlSerializer and non-ttl 
typeSerializer.
+ *
+ * @param <T> The data type that the serializer serializes.
+ * @param <S> The original serializer the TtlAwareSerializer wraps.
  */
-public class TtlAwareSerializer<T> extends TypeSerializer<T> {
+public class TtlAwareSerializer<T, S extends TypeSerializer<T>> extends 
TypeSerializer<T> {
 
     private final boolean isTtlEnabled;
 
-    private final TypeSerializer<T> typeSerializer;
+    private final S typeSerializer;
 
-    public TtlAwareSerializer(TypeSerializer<T> typeSerializer) {
+    public TtlAwareSerializer(S typeSerializer) {
+        checkArgument(
+                !(typeSerializer instanceof TtlAwareSerializer),
+                typeSerializer
+                        + " is already instance of TtlAwareSerializer, should 
not be wrapped repeatedly.");
         this.typeSerializer = typeSerializer;
-        this.isTtlEnabled = typeSerializer instanceof 
TtlStateFactory.TtlSerializer;
+        this.isTtlEnabled = 
TtlStateFactory.TtlSerializer.isTtlStateSerializer(typeSerializer);
+    }
+
+    public TtlAwareSerializer(S typeSerializer, boolean isTtlEnabled) {
+        checkArgument(
+                !(typeSerializer instanceof TtlAwareSerializer),
+                typeSerializer
+                        + " is already instance of TtlAwareSerializer, should 
not be wrapped repeatedly.");
+        this.typeSerializer = typeSerializer;
+        this.isTtlEnabled = isTtlEnabled;
     }
 
     @Override
@@ -52,7 +72,7 @@ public class TtlAwareSerializer<T> extends TypeSerializer<T> {
 
     @Override
     public TypeSerializer<T> duplicate() {
-        return new TtlAwareSerializer<>(typeSerializer.duplicate());
+        return new TtlAwareSerializer<>(typeSerializer.duplicate(), 
isTtlEnabled);
     }
 
     @Override
@@ -98,7 +118,7 @@ public class TtlAwareSerializer<T> extends TypeSerializer<T> 
{
         if (o == null || getClass() != o.getClass()) {
             return false;
         }
-        TtlAwareSerializer<?> that = (TtlAwareSerializer<?>) o;
+        TtlAwareSerializer<?, ?> that = (TtlAwareSerializer<?, ?>) o;
         return isTtlEnabled == that.isTtlEnabled
                 && Objects.equals(typeSerializer, that.typeSerializer);
     }
@@ -110,7 +130,7 @@ public class TtlAwareSerializer<T> extends 
TypeSerializer<T> {
 
     @SuppressWarnings("unchecked")
     public void migrateValueFromPriorSerializer(
-            TtlAwareSerializer<T> priorTtlAwareSerializer,
+            TtlAwareSerializer<T, ?> priorTtlAwareSerializer,
             SupplierWithException<T, IOException> inputSupplier,
             DataOutputView target,
             TtlTimeProvider ttlTimeProvider)
@@ -142,64 +162,96 @@ public class TtlAwareSerializer<T> extends 
TypeSerializer<T> {
         return isTtlEnabled;
     }
 
-    public TypeSerializer<T> getOriginalTypeSerializer() {
+    public S getOriginalTypeSerializer() {
         return typeSerializer;
     }
 
     @Override
     public TypeSerializerSnapshot<T> snapshotConfiguration() {
-        return new TtlAwareSerializerSnapshot<>(
-                typeSerializer.snapshotConfiguration(), isTtlEnabled);
+        return new 
TtlAwareSerializerSnapshotWrapper<>(typeSerializer.snapshotConfiguration())
+                .getTtlAwareSerializerSnapshot();
     }
 
     public static boolean isSerializerTtlEnabled(TypeSerializer<?> 
typeSerializer) {
-        TypeSerializer<?> wrappedTypeSerializer = 
wrapTtlAwareSerializer(typeSerializer);
-        boolean ttlSerializer =
-                wrappedTypeSerializer instanceof TtlAwareSerializer
-                        && ((TtlAwareSerializer<?>) 
wrappedTypeSerializer).isTtlEnabled();
-        boolean ttlListSerializer =
-                wrappedTypeSerializer instanceof ListSerializer
-                        && ((ListSerializer<?>) 
wrappedTypeSerializer).getElementSerializer()
-                                instanceof TtlAwareSerializer
-                        && ((TtlAwareSerializer<?>)
-                                        ((ListSerializer<?>) 
wrappedTypeSerializer)
-                                                .getElementSerializer())
-                                .isTtlEnabled();
-        boolean ttlMapSerializer =
-                wrappedTypeSerializer instanceof MapSerializer
-                        && ((MapSerializer<?, ?>) 
wrappedTypeSerializer).getValueSerializer()
-                                instanceof TtlAwareSerializer
-                        && ((TtlAwareSerializer<?>)
-                                        ((MapSerializer<?, ?>) 
wrappedTypeSerializer)
-                                                .getValueSerializer())
-                                .isTtlEnabled();
-        return ttlSerializer || ttlListSerializer || ttlMapSerializer;
-    }
-
-    public static TypeSerializer<?> wrapTtlAwareSerializer(TypeSerializer<?> 
typeSerializer) {
+        return wrapTtlAwareSerializer(typeSerializer).isTtlEnabled();
+    }
+
+    public static boolean needTtlStateMigration(
+            TypeSerializer<?> previousSerializer, TypeSerializer<?> 
newSerializer) {
+        return TtlAwareSerializer.isSerializerTtlEnabled(previousSerializer)
+                != TtlAwareSerializer.isSerializerTtlEnabled(newSerializer);
+    }
+
+    public static TtlAwareSerializer<?, ?> wrapTtlAwareSerializer(
+            TypeSerializer<?> typeSerializer) {
         if (typeSerializer instanceof TtlAwareSerializer) {
-            return typeSerializer;
+            return (TtlAwareSerializer<?, ?>) typeSerializer;
         }
 
         if (typeSerializer instanceof ListSerializer) {
-            return ((ListSerializer<?>) typeSerializer).getElementSerializer()
-                            instanceof TtlAwareSerializer
-                    ? typeSerializer
-                    : new ListSerializer<>(
-                            new TtlAwareSerializer<>(
-                                    ((ListSerializer<?>) 
typeSerializer).getElementSerializer()));
+            return new TtlAwareListSerializer<>((ListSerializer<?>) 
typeSerializer);
         }
 
         if (typeSerializer instanceof MapSerializer) {
-            return ((MapSerializer<?, ?>) typeSerializer).getValueSerializer()
-                            instanceof TtlAwareSerializer
-                    ? typeSerializer
-                    : new MapSerializer<>(
-                            ((MapSerializer<?, ?>) 
typeSerializer).getKeySerializer(),
-                            new TtlAwareSerializer<>(
-                                    ((MapSerializer<?, ?>) 
typeSerializer).getValueSerializer()));
+            return new TtlAwareMapSerializer<>((MapSerializer<?, ?>) 
typeSerializer);
         }
 
         return new TtlAwareSerializer<>(typeSerializer);
     }
+
+    /**
+     * The list version of {@link TtlAwareSerializer}.
+     *
+     * @param <T>
+     */
+    public static class TtlAwareListSerializer<T>
+            extends TtlAwareSerializer<List<T>, ListSerializer<T>> {
+
+        public TtlAwareListSerializer(ListSerializer<T> typeSerializer) {
+            super(typeSerializer);
+        }
+
+        // 
------------------------------------------------------------------------
+        //  ListSerializer specific properties
+        // 
------------------------------------------------------------------------
+
+        /**
+         * Gets the serializer for the elements of the list.
+         *
+         * @return The serializer for the elements of the list
+         */
+        @SuppressWarnings("unchecked")
+        public TtlAwareSerializer<T, TypeSerializer<T>> getElementSerializer() 
{
+            return (TtlAwareSerializer<T, TypeSerializer<T>>)
+                    TtlAwareSerializer.wrapTtlAwareSerializer(
+                            
getOriginalTypeSerializer().getElementSerializer());
+        }
+    }
+
+    /** The map version of {@link TtlAwareSerializer}. */
+    public static class TtlAwareMapSerializer<K, V>
+            extends TtlAwareSerializer<Map<K, V>, MapSerializer<K, V>> {
+
+        public TtlAwareMapSerializer(MapSerializer<K, V> typeSerializer) {
+            super(typeSerializer);
+        }
+
+        // 
------------------------------------------------------------------------
+        //  MapSerializer specific properties
+        // 
------------------------------------------------------------------------
+
+        @SuppressWarnings("unchecked")
+        public TtlAwareSerializer<K, TypeSerializer<K>> getKeySerializer() {
+            return (TtlAwareSerializer<K, TypeSerializer<K>>)
+                    TtlAwareSerializer.wrapTtlAwareSerializer(
+                            getOriginalTypeSerializer().getKeySerializer());
+        }
+
+        @SuppressWarnings("unchecked")
+        public TtlAwareSerializer<V, TypeSerializer<V>> getValueSerializer() {
+            return (TtlAwareSerializer<V, TypeSerializer<V>>)
+                    TtlAwareSerializer.wrapTtlAwareSerializer(
+                            getOriginalTypeSerializer().getValueSerializer());
+        }
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshot.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshot.java
index bed9402f310..f3092988e39 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshot.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshot.java
@@ -25,6 +25,9 @@ import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 
 import java.io.IOException;
+import java.util.Objects;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
 
 /**
  * A {@link TypeSerializerSnapshot} for TtlAwareSerializer. This class wraps a 
{@link
@@ -52,11 +55,19 @@ public class TtlAwareSerializerSnapshot<T> implements 
TypeSerializerSnapshot<T>
 
     public TtlAwareSerializerSnapshot(
             TypeSerializerSnapshot<T> typeSerializerSnapshot, boolean 
isTtlEnabled) {
+        checkArgument(
+                !(typeSerializerSnapshot instanceof 
TtlAwareSerializerSnapshot),
+                typeSerializerSnapshot
+                        + " is already instance of TtlAwareSerializerSnapshot, 
should not be wrapped repeatedly.");
         this.typeSerializerSnapshot = typeSerializerSnapshot;
         this.isTtlEnabled = isTtlEnabled;
     }
 
     public TtlAwareSerializerSnapshot(TypeSerializerSnapshot<T> 
typeSerializerSnapshot) {
+        checkArgument(
+                !(typeSerializerSnapshot instanceof 
TtlAwareSerializerSnapshot),
+                typeSerializerSnapshot
+                        + " is already instance of TtlAwareSerializerSnapshot, 
should not be wrapped repeatedly.");
         this.typeSerializerSnapshot = typeSerializerSnapshot;
         this.isTtlEnabled = typeSerializerSnapshot instanceof 
TtlStateFactory.TtlSerializerSnapshot;
     }
@@ -145,4 +156,22 @@ public class TtlAwareSerializerSnapshot<T> implements 
TypeSerializerSnapshot<T>
             return TypeSerializerSchemaCompatibility.incompatible();
         }
     }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        TtlAwareSerializerSnapshot<?> that = (TtlAwareSerializerSnapshot<?>) o;
+        return isTtlEnabled == that.isTtlEnabled
+                && Objects.equals(typeSerializerSnapshot, 
that.typeSerializerSnapshot);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(isTtlEnabled, typeSerializerSnapshot);
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshotWrapper.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshotWrapper.java
index b903118139a..c24f1f75b28 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshotWrapper.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshotWrapper.java
@@ -18,86 +18,28 @@
 
 package org.apache.flink.runtime.state.ttl;
 
-import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.CompositeTypeSerializerUtil;
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
 import org.apache.flink.api.common.typeutils.base.ListSerializerSnapshot;
 import org.apache.flink.api.common.typeutils.base.MapSerializerSnapshot;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
-import org.apache.flink.util.Preconditions;
 
-import javax.annotation.Nonnull;
-
-import java.util.Map;
-import java.util.function.Supplier;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-/**
- * Wrap the TypeSerializerSnapshot restored from {@link StateMetaInfoSnapshot} 
to
- * TtlAwareSerializerSnapshot
- */
+/** Wrap the TypeSerializerSnapshot restored from {@link 
TypeSerializerSnapshot} */
 public class TtlAwareSerializerSnapshotWrapper<T> {
 
-    private final StateDescriptor.Type stateType;
     private final TypeSerializerSnapshot<T> typeSerializerSnapshot;
 
-    private final Map<StateDescriptor.Type, 
Supplier<TypeSerializerSnapshot<T>>>
-            ttlAwareSerializerSnapshotFactories;
-
-    @SuppressWarnings("unchecked")
-    public TtlAwareSerializerSnapshotWrapper(@Nonnull StateMetaInfoSnapshot 
snapshot) {
-        this.stateType =
-                StateDescriptor.Type.valueOf(
-                        snapshot.getOption(
-                                
StateMetaInfoSnapshot.CommonOptionsKeys.KEYED_STATE_TYPE));
-        this.typeSerializerSnapshot =
-                (TypeSerializerSnapshot<T>)
-                        Preconditions.checkNotNull(
-                                snapshot.getTypeSerializerSnapshot(
-                                        
StateMetaInfoSnapshot.CommonSerializerKeys
-                                                .VALUE_SERIALIZER));
-        this.ttlAwareSerializerSnapshotFactories = 
createTtlAwareSerializerSnapshotFactories();
-    }
-
-    @VisibleForTesting
-    public TtlAwareSerializerSnapshotWrapper(
-            StateDescriptor.Type stateType, TypeSerializerSnapshot<T> 
typeSerializerSnapshot) {
-        this.stateType = stateType;
+    public TtlAwareSerializerSnapshotWrapper(TypeSerializerSnapshot<T> 
typeSerializerSnapshot) {
         this.typeSerializerSnapshot = typeSerializerSnapshot;
-        this.ttlAwareSerializerSnapshotFactories = 
createTtlAwareSerializerSnapshotFactories();
-    }
-
-    private Map<StateDescriptor.Type, Supplier<TypeSerializerSnapshot<T>>>
-            createTtlAwareSerializerSnapshotFactories() {
-        return Stream.of(
-                        Tuple2.of(
-                                StateDescriptor.Type.VALUE,
-                                (Supplier<TypeSerializerSnapshot<T>>)
-                                        this::wrapValueSerializerSnapshot),
-                        Tuple2.of(
-                                StateDescriptor.Type.LIST,
-                                (Supplier<TypeSerializerSnapshot<T>>)
-                                        this::wrapListSerializerSnapshot),
-                        Tuple2.of(
-                                StateDescriptor.Type.MAP,
-                                (Supplier<TypeSerializerSnapshot<T>>)
-                                        this::wrapMapSerializerSnapshot),
-                        Tuple2.of(
-                                StateDescriptor.Type.REDUCING,
-                                (Supplier<TypeSerializerSnapshot<T>>)
-                                        this::wrapValueSerializerSnapshot),
-                        Tuple2.of(
-                                StateDescriptor.Type.AGGREGATING,
-                                (Supplier<TypeSerializerSnapshot<T>>)
-                                        this::wrapValueSerializerSnapshot))
-                .collect(Collectors.toMap(t -> t.f0, t -> t.f1));
     }
 
     public TypeSerializerSnapshot<T> getTtlAwareSerializerSnapshot() {
-        return ttlAwareSerializerSnapshotFactories.get(stateType).get();
+        if (typeSerializerSnapshot instanceof ListSerializerSnapshot) {
+            return wrapListSerializerSnapshot();
+        } else if (typeSerializerSnapshot instanceof MapSerializerSnapshot) {
+            return wrapMapSerializerSnapshot();
+        } else {
+            return wrapValueSerializerSnapshot();
+        }
     }
 
     private TypeSerializerSnapshot<T> wrapValueSerializerSnapshot() {
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshotWrapperTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshotWrapperTest.java
index 769909223f5..ee14bf7484e 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshotWrapperTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerSnapshotWrapperTest.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.runtime.state.ttl;
 
-import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.common.typeutils.base.ListSerializer;
@@ -41,42 +40,27 @@ public class TtlAwareSerializerSnapshotWrapperTest {
         TypeSerializerSnapshot<Integer> intSerializerSnapshot =
                 IntSerializer.INSTANCE.snapshotConfiguration();
         TypeSerializerSnapshot<Integer> serializerSnapshot =
-                (new TtlAwareSerializerSnapshotWrapper<>(
-                                StateDescriptor.Type.VALUE, 
intSerializerSnapshot))
+                new TtlAwareSerializerSnapshotWrapper<>(intSerializerSnapshot)
                         .getTtlAwareSerializerSnapshot();
         
assertThat(serializerSnapshot).isInstanceOf(TtlAwareSerializerSnapshot.class);
         assertThat(
-                        ((TtlAwareSerializer<Integer>) 
serializerSnapshot.restoreSerializer())
+                        ((TtlAwareSerializer<Integer, IntSerializer>)
+                                        serializerSnapshot.restoreSerializer())
                                 .getOriginalTypeSerializer())
                 .isInstanceOf(IntSerializer.class);
     }
 
     @Test
-    public void testReducingStateTtlAwareSerializerSnapshot() {
+    public void testRestoreValueSerializer() {
         TypeSerializerSnapshot<Integer> intSerializerSnapshot =
                 IntSerializer.INSTANCE.snapshotConfiguration();
         TypeSerializerSnapshot<Integer> serializerSnapshot =
-                (new TtlAwareSerializerSnapshotWrapper<>(
-                                StateDescriptor.Type.REDUCING, 
intSerializerSnapshot))
+                (new 
TtlAwareSerializerSnapshotWrapper<>(intSerializerSnapshot))
                         .getTtlAwareSerializerSnapshot();
-        
assertThat(serializerSnapshot).isInstanceOf(TtlAwareSerializerSnapshot.class);
-        assertThat(
-                        ((TtlAwareSerializer<Integer>) 
serializerSnapshot.restoreSerializer())
-                                .getOriginalTypeSerializer())
-                .isInstanceOf(IntSerializer.class);
-    }
-
-    @Test
-    public void testAggregatingStateTtlAwareSerializerSnapshot() {
-        TypeSerializerSnapshot<Integer> intSerializerSnapshot =
-                IntSerializer.INSTANCE.snapshotConfiguration();
-        TypeSerializerSnapshot<Integer> serializerSnapshot =
-                (new TtlAwareSerializerSnapshotWrapper<>(
-                                StateDescriptor.Type.AGGREGATING, 
intSerializerSnapshot))
-                        .getTtlAwareSerializerSnapshot();
-        
assertThat(serializerSnapshot).isInstanceOf(TtlAwareSerializerSnapshot.class);
+        
assertThat(serializerSnapshot.restoreSerializer()).isInstanceOf(TtlAwareSerializer.class);
         assertThat(
-                        ((TtlAwareSerializer<Integer>) 
serializerSnapshot.restoreSerializer())
+                        ((TtlAwareSerializer<Integer, IntSerializer>)
+                                        serializerSnapshot.restoreSerializer())
                                 .getOriginalTypeSerializer())
                 .isInstanceOf(IntSerializer.class);
     }
@@ -87,8 +71,7 @@ public class TtlAwareSerializerSnapshotWrapperTest {
         TypeSerializerSnapshot<List<Integer>> listTypeSerializerSnapshot =
                 listSerializer.snapshotConfiguration();
         TypeSerializerSnapshot<List<Integer>> serializerSnapshot =
-                (new TtlAwareSerializerSnapshotWrapper<>(
-                                StateDescriptor.Type.LIST, 
listTypeSerializerSnapshot))
+                (new 
TtlAwareSerializerSnapshotWrapper<>(listTypeSerializerSnapshot))
                         .getTtlAwareSerializerSnapshot();
 
         
assertThat(serializerSnapshot).isInstanceOf(ListSerializerSnapshot.class);
@@ -98,6 +81,27 @@ public class TtlAwareSerializerSnapshotWrapperTest {
                 .isInstanceOf(TtlAwareSerializerSnapshot.class);
     }
 
+    @Test
+    @SuppressWarnings("rawtypes")
+    public void testRestoreListSerializer() {
+        ListSerializer<Integer> listSerializer = new 
ListSerializer<>(IntSerializer.INSTANCE);
+        TypeSerializerSnapshot<List<Integer>> listTypeSerializerSnapshot =
+                listSerializer.snapshotConfiguration();
+        TypeSerializerSnapshot<List<Integer>> serializerSnapshot =
+                (new 
TtlAwareSerializerSnapshotWrapper<>(listTypeSerializerSnapshot))
+                        .getTtlAwareSerializerSnapshot();
+
+        
assertThat(serializerSnapshot.restoreSerializer()).isInstanceOf(ListSerializer.class);
+        assertThat(((ListSerializer) 
serializerSnapshot.restoreSerializer()).getElementSerializer())
+                .isInstanceOf(TtlAwareSerializer.class);
+        assertThat(
+                        ((TtlAwareSerializer)
+                                        ((ListSerializer) 
serializerSnapshot.restoreSerializer())
+                                                .getElementSerializer())
+                                .getOriginalTypeSerializer())
+                .isInstanceOf(IntSerializer.class);
+    }
+
     @Test
     public void testMapStateTtlAwareSerializerSnapshot() {
         MapSerializer<String, String> mapSerializer =
@@ -105,8 +109,7 @@ public class TtlAwareSerializerSnapshotWrapperTest {
         TypeSerializerSnapshot<Map<String, String>> mapSerializerSnapshot =
                 mapSerializer.snapshotConfiguration();
         TypeSerializerSnapshot<Map<String, String>> serializerSnapshot =
-                (new TtlAwareSerializerSnapshotWrapper<>(
-                                StateDescriptor.Type.MAP, 
mapSerializerSnapshot))
+                (new 
TtlAwareSerializerSnapshotWrapper<>(mapSerializerSnapshot))
                         .getTtlAwareSerializerSnapshot();
 
         
assertThat(serializerSnapshot).isInstanceOf(MapSerializerSnapshot.class);
@@ -115,4 +118,26 @@ public class TtlAwareSerializerSnapshotWrapperTest {
                                 .getValueSerializerSnapshot())
                 .isInstanceOf(TtlAwareSerializerSnapshot.class);
     }
+
+    @Test
+    @SuppressWarnings("rawtypes")
+    public void testRestoreMapSerializer() {
+        MapSerializer<String, String> mapSerializer =
+                new MapSerializer<>(StringSerializer.INSTANCE, 
StringSerializer.INSTANCE);
+        TypeSerializerSnapshot<Map<String, String>> mapSerializerSnapshot =
+                mapSerializer.snapshotConfiguration();
+        TypeSerializerSnapshot<Map<String, String>> serializerSnapshot =
+                (new 
TtlAwareSerializerSnapshotWrapper<>(mapSerializerSnapshot))
+                        .getTtlAwareSerializerSnapshot();
+
+        
assertThat(serializerSnapshot.restoreSerializer()).isInstanceOf(MapSerializer.class);
+        assertThat(((MapSerializer) 
serializerSnapshot.restoreSerializer()).getValueSerializer())
+                .isInstanceOf(TtlAwareSerializer.class);
+        assertThat(
+                        ((TtlAwareSerializer)
+                                        ((MapSerializer) 
serializerSnapshot.restoreSerializer())
+                                                .getValueSerializer())
+                                .getOriginalTypeSerializer())
+                .isInstanceOf(StringSerializer.class);
+    }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerTest.java
index c3dba42401a..90c7f70432c 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerTest.java
@@ -18,11 +18,12 @@
 
 package org.apache.flink.runtime.state.ttl;
 
-import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.common.typeutils.base.ListSerializer;
+import org.apache.flink.api.common.typeutils.base.ListSerializerSnapshot;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializerSnapshot;
 
 import org.junit.jupiter.api.Test;
 
@@ -51,6 +52,13 @@ class TtlAwareSerializerTest {
         
assertThat(TtlAwareSerializer.isSerializerTtlEnabled(intTtlSerializer)).isTrue();
         
assertThat(TtlAwareSerializer.isSerializerTtlEnabled(listTtlSerializer)).isTrue();
         
assertThat(TtlAwareSerializer.isSerializerTtlEnabled(mapTtlSerializer)).isTrue();
+
+        assertThat(TtlAwareSerializer.needTtlStateMigration(intSerializer, 
intTtlSerializer))
+                .isTrue();
+        assertThat(TtlAwareSerializer.needTtlStateMigration(listSerializer, 
listTtlSerializer))
+                .isTrue();
+        assertThat(TtlAwareSerializer.needTtlStateMigration(mapSerializer, 
mapTtlSerializer))
+                .isTrue();
     }
 
     @Test
@@ -60,27 +68,20 @@ class TtlAwareSerializerTest {
         MapSerializer<Integer, Integer> mapSerializer =
                 new MapSerializer<>(intSerializer, intSerializer);
 
-        TypeSerializer<?> intTtlAwareSerializer =
+        TtlAwareSerializer<?, ?> intTtlAwareSerializer =
                 TtlAwareSerializer.wrapTtlAwareSerializer(intSerializer);
-        ListSerializer<?> listTtlAwareSerializer =
-                (ListSerializer<?>) 
TtlAwareSerializer.wrapTtlAwareSerializer(listSerializer);
-        MapSerializer<?, ?> mapTtlAwareSerializer =
-                (MapSerializer<?, ?>) 
TtlAwareSerializer.wrapTtlAwareSerializer(mapSerializer);
-
-        
assertThat(intTtlAwareSerializer).isInstanceOf(TtlAwareSerializer.class);
-        assertThat(((TtlAwareSerializer<?>) 
intTtlAwareSerializer).isTtlEnabled()).isFalse();
-        assertThat(listTtlAwareSerializer.getElementSerializer())
-                .isInstanceOf(TtlAwareSerializer.class);
-        assertThat(
-                        ((TtlAwareSerializer<?>) 
listTtlAwareSerializer.getElementSerializer())
-                                .isTtlEnabled())
-                .isFalse();
-        assertThat(mapTtlAwareSerializer.getValueSerializer())
-                .isInstanceOf(TtlAwareSerializer.class);
-        assertThat(
-                        ((TtlAwareSerializer<?>) 
mapTtlAwareSerializer.getValueSerializer())
-                                .isTtlEnabled())
-                .isFalse();
+        TtlAwareSerializer<?, ?> listTtlAwareSerializer =
+                TtlAwareSerializer.wrapTtlAwareSerializer(listSerializer);
+        TtlAwareSerializer<?, ?> mapTtlAwareSerializer =
+                TtlAwareSerializer.wrapTtlAwareSerializer(mapSerializer);
+
+        assertThat(intTtlAwareSerializer.isTtlEnabled()).isFalse();
+        assertThat(listTtlAwareSerializer)
+                .isInstanceOf(TtlAwareSerializer.TtlAwareListSerializer.class);
+        assertThat((listTtlAwareSerializer).isTtlEnabled()).isFalse();
+        assertThat(mapTtlAwareSerializer)
+                .isInstanceOf(TtlAwareSerializer.TtlAwareMapSerializer.class);
+        assertThat(mapTtlAwareSerializer.isTtlEnabled()).isFalse();
     }
 
     @Test
@@ -93,26 +94,57 @@ class TtlAwareSerializerTest {
         MapSerializer<Integer, TtlValue<Integer>> mapTtlSerializer =
                 new MapSerializer<>(IntSerializer.INSTANCE, intTtlSerializer);
 
-        TypeSerializer<?> intTtlAwareSerializer =
+        TtlAwareSerializer<?, ?> intTtlAwareSerializer =
                 TtlAwareSerializer.wrapTtlAwareSerializer(intTtlSerializer);
-        ListSerializer<?> listTtlAwareSerializer =
-                (ListSerializer<?>) TtlAwareSerializer.wrapTtlAwareSerializer(listTtlSerializer);
-        MapSerializer<?, ?> mapTtlAwareSerializer =
-                (MapSerializer<?, ?>) TtlAwareSerializer.wrapTtlAwareSerializer(mapTtlSerializer);
-
-        assertThat(intTtlAwareSerializer).isInstanceOf(TtlAwareSerializer.class);
-        assertThat(((TtlAwareSerializer<?>) intTtlAwareSerializer).isTtlEnabled()).isTrue();
-        assertThat(listTtlAwareSerializer.getElementSerializer())
-                .isInstanceOf(TtlAwareSerializer.class);
+        TtlAwareSerializer<?, ?> listTtlAwareSerializer =
+                TtlAwareSerializer.wrapTtlAwareSerializer(listTtlSerializer);
+        TtlAwareSerializer<?, ?> mapTtlAwareSerializer =
+                TtlAwareSerializer.wrapTtlAwareSerializer(mapTtlSerializer);
+
+        assertThat(intTtlAwareSerializer.isTtlEnabled()).isTrue();
+        assertThat(listTtlAwareSerializer)
+                .isInstanceOf(TtlAwareSerializer.TtlAwareListSerializer.class);
+        assertThat(listTtlAwareSerializer.isTtlEnabled()).isTrue();
+        assertThat(mapTtlAwareSerializer)
+                .isInstanceOf(TtlAwareSerializer.TtlAwareMapSerializer.class);
+        assertThat(mapTtlAwareSerializer.isTtlEnabled()).isTrue();
+    }
+
+    @Test
+    @SuppressWarnings("rawtypes")
+    void testSnapshotConfiguration() {
+        TtlAwareSerializer<?, ?> intTtlAwareSerializer =
+                TtlAwareSerializer.wrapTtlAwareSerializer(IntSerializer.INSTANCE);
+        TtlAwareSerializer.TtlAwareListSerializer<?> listTtlAwareSerializer =
+                (TtlAwareSerializer.TtlAwareListSerializer<?>)
+                        TtlAwareSerializer.wrapTtlAwareSerializer(
+                                new ListSerializer<>(IntSerializer.INSTANCE));
+        TtlAwareSerializer.TtlAwareMapSerializer<?, ?> mapTtlAwareSerializer =
+                (TtlAwareSerializer.TtlAwareMapSerializer<?, ?>)
+                        TtlAwareSerializer.wrapTtlAwareSerializer(
+                                new MapSerializer<>(
+                                        IntSerializer.INSTANCE, IntSerializer.INSTANCE));
+
+        assertThat(intTtlAwareSerializer.snapshotConfiguration())
+                .isInstanceOf(TtlAwareSerializerSnapshot.class);
         assertThat(
-                        ((TtlAwareSerializer<?>) listTtlAwareSerializer.getElementSerializer())
-                                .isTtlEnabled())
-                .isTrue();
-        assertThat(mapTtlAwareSerializer.getValueSerializer())
-                .isInstanceOf(TtlAwareSerializer.class);
+                        ((TtlAwareSerializerSnapshot<?>)
+                                        intTtlAwareSerializer.snapshotConfiguration())
+                                .getOrinalTypeSerializerSnapshot())
+                .isInstanceOf(IntSerializer.IntSerializerSnapshot.class);
+
+        assertThat(listTtlAwareSerializer.snapshotConfiguration())
+                .isInstanceOf(ListSerializerSnapshot.class);
         assertThat(
-                        ((TtlAwareSerializer<?>) mapTtlAwareSerializer.getValueSerializer())
-                                .isTtlEnabled())
-                .isTrue();
+                        ((ListSerializerSnapshot) listTtlAwareSerializer.snapshotConfiguration())
+                                .getElementSerializerSnapshot())
+                .isInstanceOf(TtlAwareSerializerSnapshot.class);
+
+        assertThat(mapTtlAwareSerializer.snapshotConfiguration())
+                .isInstanceOf(MapSerializerSnapshot.class);
+        assertThat(
+                        ((MapSerializerSnapshot) mapTtlAwareSerializer.snapshotConfiguration())
+                                .getValueSerializerSnapshot())
+                .isInstanceOf(TtlAwareSerializerSnapshot.class);
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerUpgradeTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerUpgradeTest.java
index fff29aa4417..d8b4234252c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerUpgradeTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlAwareSerializerUpgradeTest.java
@@ -155,8 +155,8 @@ public class TtlAwareSerializerUpgradeTest
             TypeSerializer<T> writeSerializer,
             Condition<T> testDataCondition)
             throws IOException {
-        TtlAwareSerializer<T> reader = (TtlAwareSerializer<T>) readSerializer;
-        TtlAwareSerializer<T> writer = (TtlAwareSerializer<T>) writeSerializer;
+        TtlAwareSerializer<T, ?> reader = (TtlAwareSerializer<T, ?>) readSerializer;
+        TtlAwareSerializer<T, ?> writer = (TtlAwareSerializer<T, ?>) writeSerializer;
 
         DataOutputSerializer migratedOut = new DataOutputSerializer(INITIAL_OUTPUT_BUFFER_SIZE);
         writer.migrateValueFromPriorSerializer(
diff --git a/flink-tests/src/test/java/org/apache/flink/test/completeness/TypeSerializerTestCoverageTest.java b/flink-tests/src/test/java/org/apache/flink/test/completeness/TypeSerializerTestCoverageTest.java
index 59f0d256f13..86b35018fa4 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/completeness/TypeSerializerTestCoverageTest.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/completeness/TypeSerializerTestCoverageTest.java
@@ -147,6 +147,8 @@ public class TypeSerializerTestCoverageTest extends TestLogger {
                         CoGroupedStreams.UnionSerializer.class.getName(),
                         TtlStateFactory.TtlSerializer.class.getName(),
                         TtlAwareSerializer.class.getName(),
+                        TtlAwareSerializer.TtlAwareListSerializer.class.getName(),
+                        TtlAwareSerializer.TtlAwareMapSerializer.class.getName(),
                         org.apache.flink.runtime.state.v2.ttl.TtlStateFactory.TtlSerializer.class
                                 .getName(),
                         TimeWindow.Serializer.class.getName(),
@@ -211,6 +213,8 @@ public class TypeSerializerTestCoverageTest extends TestLogger {
                         UnloadableDummyTypeSerializer.class.getName(),
                         TimeWindow.Serializer.class.getName(),
                         CoGroupedStreams.UnionSerializer.class.getName(),
+                        TtlAwareSerializer.TtlAwareListSerializer.class.getName(),
+                        TtlAwareSerializer.TtlAwareMapSerializer.class.getName(),
                         InternalTimersSnapshotReaderWriters.LegacyTimerSerializer.class.getName(),
                         TwoPhaseCommitSinkFunction.StateSerializer.class.getName(),
                         GlobalWindow.Serializer.class.getName(),


Reply via email to