This is an automated email from the ASF dual-hosted git repository.

arvid pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-connector-kafka.git


The following commit(s) were added to refs/heads/main by this push:
     new cb5c5c07 [FLINK-38453] Add full splits to KafkaSourceEnumState
cb5c5c07 is described below

commit cb5c5c07318ba602c6c63cb116774a12c52fc478
Author: Arvid Heise <[email protected]>
AuthorDate: Tue Sep 30 10:11:14 2025 +0200

    [FLINK-38453] Add full splits to KafkaSourceEnumState
    
    The KafkaSourceEnumerator's state contains only the TopicPartitions but
    not the offsets, so it does not hold the full split state, contrary to
    the design intent.
    
    There are a couple of issues with that approach. It implicitly assumes
    that splits are fully assigned to readers before the first checkpoint.
    Otherwise, the enumerator invokes the offset initializer again on
    recovery from such a checkpoint, leading to inconsistencies (with LATEST,
    some partitions may be initialized during the first attempt and others
    during the second attempt).
    
    Through the addSplitsBack callback, these scenarios can also occur later
    in BATCH mode, which leads to duplicate rows (in the case of EARLIEST or
    SPECIFIC-OFFSETS) or to data loss (in the case of LATEST). Finally,
    KafkaSource cannot be used safely as part of a HybridSource because the
    offset initializer cannot even be recreated on recovery.
    
    All cases are solved by also retaining the offset in the enumerator
    state. To that end, this commit merges the async discovery phases so that
    splits are initialized from the partitions immediately. Any subsequent
    checkpoint then contains the proper start offset.
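    
    A minimal sketch of the resulting state shape (the EnumStateSketch class
    name, topic name, partition, and offset are made up for illustration and
    are not part of the commit): the V3 enumerator state stores full
    KafkaPartitionSplits, including their starting offsets, instead of bare
    TopicPartitions.
    
        import java.util.Set;
        
        import org.apache.flink.connector.kafka.source.enumerator.AssignmentStatus;
        import org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumState;
        import org.apache.flink.connector.kafka.source.enumerator.SplitAndAssignmentStatus;
        import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
        
        import org.apache.kafka.common.TopicPartition;
        
        class EnumStateSketch {
            public static void main(String[] args) {
                // A split carries its concrete starting offset (1234L is made up), so a
                // recovered enumerator does not re-run the OffsetsInitializer for it.
                KafkaPartitionSplit split =
                        new KafkaPartitionSplit(new TopicPartition("orders", 0), 1234L);
                // The checkpointed state keeps each split together with its assignment status.
                KafkaSourceEnumState state =
                        new KafkaSourceEnumState(
                                Set.of(new SplitAndAssignmentStatus(split, AssignmentStatus.UNASSIGNED)),
                                /* initialDiscoveryFinished= */ true);
                // Splits restored from a pre-V3 checkpoint instead carry the MIGRATED sentinel
                // and are resolved to concrete offsets by the enumerator on recovery.
                System.out.println(state.unassignedSplits());
            }
        }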
---
 .../c0d94764-76a0-4c50-b617-70b1754c4612           |  15 +-
 .../enumerator/DynamicKafkaSourceEnumerator.java   |  13 +-
 .../kafka/source/enumerator/AssignmentStatus.java  |   7 +-
 .../source/enumerator/KafkaSourceEnumState.java    |  57 ++--
 .../enumerator/KafkaSourceEnumStateSerializer.java | 164 +++++++---
 .../source/enumerator/KafkaSourceEnumerator.java   | 168 +++++++----
 ...ntStatus.java => SplitAndAssignmentStatus.java} |  16 +-
 .../kafka/source/split/KafkaPartitionSplit.java    |   6 +-
 .../DynamicKafkaSourceEnumStateSerializerTest.java |  35 +--
 .../DynamicKafkaSourceEnumeratorTest.java          |  23 +-
 .../KafkaSourceEnumStateSerializerTest.java        |  93 ++++--
 ...torTest.java => KafkaSourceEnumeratorTest.java} | 329 +++++++++++++++------
 12 files changed, 618 insertions(+), 308 deletions(-)

diff --git 
a/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
 
b/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
index 236fade3..e496d80c 100644
--- 
a/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
+++ 
b/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
@@ -23,11 +23,11 @@ Method 
<org.apache.flink.connector.kafka.dynamic.source.reader.DynamicKafkaSourc
 Method 
<org.apache.flink.connector.kafka.dynamic.source.reader.DynamicKafkaSourceReader.syncAvailabilityHelperWithReaders()>
 calls method 
<org.apache.flink.streaming.runtime.io.MultipleFuturesAvailabilityHelper.anyOf(int,
 java.util.concurrent.CompletableFuture)> in (DynamicKafkaSourceReader.java:500)
 Method 
<org.apache.flink.connector.kafka.sink.ExactlyOnceKafkaWriter.getProducerPool()>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(ExactlyOnceKafkaWriter.java:0)
 Method 
<org.apache.flink.connector.kafka.sink.ExactlyOnceKafkaWriter.getTransactionalIdPrefix()>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(ExactlyOnceKafkaWriter.java:0)
-Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 calls method <org.apache.flink.api.dag.Transformation.getCoLocationGroupKey()> 
in (KafkaSink.java:178)
-Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 calls method <org.apache.flink.api.dag.Transformation.getInputs()> in 
(KafkaSink.java:181)
-Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 calls method <org.apache.flink.api.dag.Transformation.getOutputType()> in 
(KafkaSink.java:177)
-Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 calls method 
<org.apache.flink.api.dag.Transformation.setCoLocationGroupKey(java.lang.String)>
 in (KafkaSink.java:180)
-Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 checks instanceof 
<org.apache.flink.streaming.api.connector.sink2.CommittableMessageTypeInfo> in 
(KafkaSink.java:177)
+Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 calls method <org.apache.flink.api.dag.Transformation.getCoLocationGroupKey()> 
in (KafkaSink.java:183)
+Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 calls method <org.apache.flink.api.dag.Transformation.getInputs()> in 
(KafkaSink.java:186)
+Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 calls method <org.apache.flink.api.dag.Transformation.getOutputType()> in 
(KafkaSink.java:182)
+Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 calls method 
<org.apache.flink.api.dag.Transformation.setCoLocationGroupKey(java.lang.String)>
 in (KafkaSink.java:185)
+Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 checks instanceof 
<org.apache.flink.streaming.api.connector.sink2.CommittableMessageTypeInfo> in 
(KafkaSink.java:182)
 Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.addPostCommitTopology(org.apache.flink.streaming.api.datastream.DataStream)>
 has generic parameter type 
<org.apache.flink.streaming.api.datastream.DataStream<org.apache.flink.streaming.api.connector.sink2.CommittableMessage<org.apache.flink.connector.kafka.sink.KafkaCommittable>>>
 with type argument depending on 
<org.apache.flink.streaming.api.connector.sink2.CommittableMessage> in 
(KafkaSink.java:0)
 Method 
<org.apache.flink.connector.kafka.sink.KafkaSink.getKafkaProducerConfig()> is 
annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSink.java:0)
 Method 
<org.apache.flink.connector.kafka.sink.KafkaSinkBuilder.setRecordSerializer(org.apache.flink.connector.kafka.sink.KafkaRecordSerializationSchema)>
 calls method <org.apache.flink.api.java.ClosureCleaner.clean(java.lang.Object, 
org.apache.flink.api.common.ExecutionConfig$ClosureCleanerLevel, boolean)> in 
(KafkaSinkBuilder.java:154)
@@ -39,9 +39,12 @@ Method 
<org.apache.flink.connector.kafka.source.KafkaSource.createReader(org.apa
 Method 
<org.apache.flink.connector.kafka.source.KafkaSource.getConfiguration()> is 
annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSource.java:0)
 Method 
<org.apache.flink.connector.kafka.source.KafkaSource.getKafkaSubscriber()> is 
annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSource.java:0)
 Method 
<org.apache.flink.connector.kafka.source.KafkaSource.getStoppingOffsetsInitializer()>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSource.java:0)
-Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer.serializeTopicPartitions(java.util.Collection)>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumStateSerializer.java:0)
+Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer.serializeV1(java.util.Collection)>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumStateSerializer.java:0)
+Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer.serializeV2(java.util.Collection,
 boolean)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumStateSerializer.java:0)
+Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer.serializeV3(org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumState)>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumStateSerializer.java:0)
 Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.deepCopyProperties(java.util.Properties,
 java.util.Properties)> is annotated with 
<org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumerator.java:0)
 Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.getPartitionChange(java.util.Set)>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumerator.java:0)
+Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.getPendingPartitionSplitAssignment()>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumerator.java:0)
 Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.getSplitOwner(org.apache.kafka.common.TopicPartition,
 int)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumerator.java:0)
 Method 
<org.apache.flink.connector.kafka.source.reader.KafkaPartitionSplitReader.consumer()>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaPartitionSplitReader.java:0)
 Method 
<org.apache.flink.connector.kafka.source.reader.KafkaPartitionSplitReader.setConsumerClientRack(java.util.Properties,
 java.lang.String)> is annotated with 
<org.apache.flink.annotation.VisibleForTesting> in 
(KafkaPartitionSplitReader.java:0)
diff --git 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumerator.java
 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumerator.java
index ff7cc21d..7643e62b 100644
--- 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumerator.java
+++ 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumerator.java
@@ -35,14 +35,13 @@ import 
org.apache.flink.connector.kafka.dynamic.source.split.DynamicKafkaSourceS
 import org.apache.flink.connector.kafka.source.KafkaPropertiesUtil;
 import org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumState;
 import 
org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator;
-import 
org.apache.flink.connector.kafka.source.enumerator.TopicPartitionAndAssignmentStatus;
+import 
org.apache.flink.connector.kafka.source.enumerator.SplitAndAssignmentStatus;
 import 
org.apache.flink.connector.kafka.source.enumerator.initializer.OffsetsInitializer;
 import 
org.apache.flink.connector.kafka.source.enumerator.subscriber.KafkaSubscriber;
 import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
 import org.apache.flink.util.Preconditions;
 
 import org.apache.kafka.common.KafkaException;
-import org.apache.kafka.common.TopicPartition;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -174,8 +173,8 @@ public class DynamicKafkaSourceEnumerator
                 
dynamicKafkaSourceEnumState.getClusterEnumeratorStates().entrySet()) {
             this.latestClusterTopicsMap.put(
                     clusterEnumState.getKey(),
-                    clusterEnumState.getValue().assignedPartitions().stream()
-                            .map(TopicPartition::topic)
+                    clusterEnumState.getValue().assignedSplits().stream()
+                            .map(KafkaPartitionSplit::getTopic)
                             .collect(Collectors.toSet()));
 
             createEnumeratorWithAssignedTopicPartitions(
@@ -291,9 +290,9 @@ public class DynamicKafkaSourceEnumerator
                 final Set<String> activeTopics = 
activeClusterTopics.getValue();
 
                 // filter out removed topics
-                Set<TopicPartitionAndAssignmentStatus> partitions =
-                        kafkaSourceEnumState.partitions().stream()
-                                .filter(tp -> 
activeTopics.contains(tp.topicPartition().topic()))
+                Set<SplitAndAssignmentStatus> partitions =
+                        kafkaSourceEnumState.splits().stream()
+                                .filter(tp -> 
activeTopics.contains(tp.split().getTopic()))
                                 .collect(Collectors.toSet());
 
                 newKafkaSourceEnumState =
diff --git 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/AssignmentStatus.java
 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/AssignmentStatus.java
index b7d11538..e8f9600b 100644
--- 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/AssignmentStatus.java
+++ 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/AssignmentStatus.java
@@ -26,11 +26,8 @@ public enum AssignmentStatus {
 
     /** Partitions that have been assigned to readers. */
     ASSIGNED(0),
-    /**
-     * The partitions that have been discovered during initialization but not 
assigned to readers
-     * yet.
-     */
-    UNASSIGNED_INITIAL(1);
+    /** The partitions that have been discovered but not assigned to readers 
yet. */
+    UNASSIGNED(1);
     private final int statusCode;
 
     AssignmentStatus(int statusCode) {
diff --git 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumState.java
 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumState.java
index 66ceeeb8..649bd584 100644
--- 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumState.java
+++ 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumState.java
@@ -19,9 +19,9 @@
 package org.apache.flink.connector.kafka.source.enumerator;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
 
-import org.apache.kafka.common.TopicPartition;
-
+import java.util.Collection;
 import java.util.HashSet;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -29,8 +29,8 @@ import java.util.stream.Collectors;
 /** The state of Kafka source enumerator. */
 @Internal
 public class KafkaSourceEnumState {
-    /** Partitions with status: ASSIGNED or UNASSIGNED_INITIAL. */
-    private final Set<TopicPartitionAndAssignmentStatus> partitions;
+    /** Splits with status: ASSIGNED or UNASSIGNED. */
+    private final Set<SplitAndAssignmentStatus> splits;
     /**
      * this flag will be marked as true if initial partitions are discovered 
after enumerator
      * starts.
@@ -38,57 +38,54 @@ public class KafkaSourceEnumState {
     private final boolean initialDiscoveryFinished;
 
     public KafkaSourceEnumState(
-            Set<TopicPartitionAndAssignmentStatus> partitions, boolean 
initialDiscoveryFinished) {
-        this.partitions = partitions;
+            Set<SplitAndAssignmentStatus> splits, boolean 
initialDiscoveryFinished) {
+        this.splits = splits;
         this.initialDiscoveryFinished = initialDiscoveryFinished;
     }
 
     public KafkaSourceEnumState(
-            Set<TopicPartition> assignPartitions,
-            Set<TopicPartition> unassignedInitialPartitions,
+            Collection<KafkaPartitionSplit> assignedSplits,
+            Collection<KafkaPartitionSplit> unassignedSplits,
             boolean initialDiscoveryFinished) {
-        this.partitions = new HashSet<>();
-        partitions.addAll(
-                assignPartitions.stream()
+        this.splits = new HashSet<>();
+        splits.addAll(
+                assignedSplits.stream()
                         .map(
                                 topicPartition ->
-                                        new TopicPartitionAndAssignmentStatus(
+                                        new SplitAndAssignmentStatus(
                                                 topicPartition, 
AssignmentStatus.ASSIGNED))
                         .collect(Collectors.toSet()));
-        partitions.addAll(
-                unassignedInitialPartitions.stream()
+        splits.addAll(
+                unassignedSplits.stream()
                         .map(
                                 topicPartition ->
-                                        new TopicPartitionAndAssignmentStatus(
-                                                topicPartition,
-                                                
AssignmentStatus.UNASSIGNED_INITIAL))
+                                        new SplitAndAssignmentStatus(
+                                                topicPartition, 
AssignmentStatus.UNASSIGNED))
                         .collect(Collectors.toSet()));
         this.initialDiscoveryFinished = initialDiscoveryFinished;
     }
 
-    public Set<TopicPartitionAndAssignmentStatus> partitions() {
-        return partitions;
+    public Set<SplitAndAssignmentStatus> splits() {
+        return splits;
     }
 
-    public Set<TopicPartition> assignedPartitions() {
-        return filterPartitionsByAssignmentStatus(AssignmentStatus.ASSIGNED);
+    public Collection<KafkaPartitionSplit> assignedSplits() {
+        return filterByAssignmentStatus(AssignmentStatus.ASSIGNED);
     }
 
-    public Set<TopicPartition> unassignedInitialPartitions() {
-        return 
filterPartitionsByAssignmentStatus(AssignmentStatus.UNASSIGNED_INITIAL);
+    public Collection<KafkaPartitionSplit> unassignedSplits() {
+        return filterByAssignmentStatus(AssignmentStatus.UNASSIGNED);
     }
 
     public boolean initialDiscoveryFinished() {
         return initialDiscoveryFinished;
     }
 
-    private Set<TopicPartition> filterPartitionsByAssignmentStatus(
+    private Collection<KafkaPartitionSplit> filterByAssignmentStatus(
             AssignmentStatus assignmentStatus) {
-        return partitions.stream()
-                .filter(
-                        partitionWithStatus ->
-                                
partitionWithStatus.assignmentStatus().equals(assignmentStatus))
-                .map(TopicPartitionAndAssignmentStatus::topicPartition)
-                .collect(Collectors.toSet());
+        return splits.stream()
+                .filter(split -> 
split.assignmentStatus().equals(assignmentStatus))
+                .map(SplitAndAssignmentStatus::split)
+                .collect(Collectors.toList());
     }
 }
diff --git 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
index f8dc17de..99176cfc 100644
--- 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
+++ 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
@@ -37,6 +37,8 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 
+import static 
org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit.MIGRATED;
+
 /**
  * The {@link org.apache.flink.core.io.SimpleVersionedSerializer Serializer} 
for the enumerator
  * state of Kafka source.
@@ -58,7 +60,12 @@ public class KafkaSourceEnumStateSerializer
      */
     private static final int VERSION_2 = 2;
 
-    private static final int CURRENT_VERSION = VERSION_2;
+    private static final int VERSION_3 = 3;
+
+    private static final int CURRENT_VERSION = VERSION_3;
+
+    private static final KafkaPartitionSplitSerializer SPLIT_SERIALIZER =
+            new KafkaPartitionSplitSerializer();
 
     @Override
     public int getVersion() {
@@ -67,15 +74,22 @@ public class KafkaSourceEnumStateSerializer
 
     @Override
     public byte[] serialize(KafkaSourceEnumState enumState) throws IOException 
{
-        Set<TopicPartitionAndAssignmentStatus> partitions = 
enumState.partitions();
+        return serializeV3(enumState);
+    }
+
+    @VisibleForTesting
+    static byte[] serializeV3(KafkaSourceEnumState enumState) throws 
IOException {
+        Set<SplitAndAssignmentStatus> splits = enumState.splits();
         boolean initialDiscoveryFinished = 
enumState.initialDiscoveryFinished();
         try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                 DataOutputStream out = new DataOutputStream(baos)) {
-            out.writeInt(partitions.size());
-            for (TopicPartitionAndAssignmentStatus 
topicPartitionAndAssignmentStatus : partitions) {
-                
out.writeUTF(topicPartitionAndAssignmentStatus.topicPartition().topic());
-                
out.writeInt(topicPartitionAndAssignmentStatus.topicPartition().partition());
-                
out.writeInt(topicPartitionAndAssignmentStatus.assignmentStatus().getStatusCode());
+            out.writeInt(splits.size());
+            out.writeInt(SPLIT_SERIALIZER.getVersion());
+            for (SplitAndAssignmentStatus split : splits) {
+                final byte[] splitBytes = 
SPLIT_SERIALIZER.serialize(split.split());
+                out.writeInt(splitBytes.length);
+                out.write(splitBytes);
+                out.writeInt(split.assignmentStatus().getStatusCode());
             }
             out.writeBoolean(initialDiscoveryFinished);
             out.flush();
@@ -86,22 +100,14 @@ public class KafkaSourceEnumStateSerializer
     @Override
     public KafkaSourceEnumState deserialize(int version, byte[] serialized) 
throws IOException {
         switch (version) {
-            case CURRENT_VERSION:
-                return 
deserializeTopicPartitionAndAssignmentStatus(serialized);
+            case VERSION_3:
+                return deserializeVersion3(serialized);
+            case VERSION_2:
+                return deserializeVersion2(serialized);
             case VERSION_1:
-                return deserializeAssignedTopicPartitions(serialized);
+                return deserializeVersion1(serialized);
             case VERSION_0:
-                Map<Integer, Set<KafkaPartitionSplit>> 
currentPartitionAssignment =
-                        SerdeUtils.deserializeSplitAssignments(
-                                serialized, new 
KafkaPartitionSplitSerializer(), HashSet::new);
-                Set<TopicPartition> currentAssignedSplits = new HashSet<>();
-                currentPartitionAssignment.forEach(
-                        (reader, splits) ->
-                                splits.forEach(
-                                        split ->
-                                                currentAssignedSplits.add(
-                                                        
split.getTopicPartition())));
-                return new KafkaSourceEnumState(currentAssignedSplits, new 
HashSet<>(), true);
+                return deserializeVersion0(serialized);
             default:
                 throw new IOException(
                         String.format(
@@ -111,19 +117,78 @@ public class KafkaSourceEnumStateSerializer
         }
     }
 
-    private static KafkaSourceEnumState deserializeAssignedTopicPartitions(
-            byte[] serializedTopicPartitions) throws IOException {
+    private static KafkaSourceEnumState deserializeVersion3(byte[] serialized) 
throws IOException {
+
+        final KafkaPartitionSplitSerializer splitSerializer = new 
KafkaPartitionSplitSerializer();
+
+        try (ByteArrayInputStream bais = new ByteArrayInputStream(serialized);
+                DataInputStream in = new DataInputStream(bais)) {
+
+            final int numPartitions = in.readInt();
+            final int splitVersion = in.readInt();
+            Set<SplitAndAssignmentStatus> partitions = new 
HashSet<>(numPartitions);
+
+            for (int i = 0; i < numPartitions; i++) {
+                final KafkaPartitionSplit split =
+                        splitSerializer.deserialize(splitVersion, 
in.readNBytes(in.readInt()));
+                final int statusCode = in.readInt();
+                partitions.add(
+                        new SplitAndAssignmentStatus(
+                                split, 
AssignmentStatus.ofStatusCode(statusCode)));
+            }
+            final boolean initialDiscoveryFinished = in.readBoolean();
+            if (in.available() > 0) {
+                throw new IOException("Unexpected trailing bytes in serialized 
topic partitions");
+            }
+
+            return new KafkaSourceEnumState(partitions, 
initialDiscoveryFinished);
+        }
+    }
+
+    private static KafkaSourceEnumState deserializeVersion0(byte[] serialized) 
throws IOException {
+        Map<Integer, Set<KafkaPartitionSplit>> currentPartitionAssignment =
+                SerdeUtils.deserializeSplitAssignments(
+                        serialized, new KafkaPartitionSplitSerializer(), 
HashSet::new);
+        Set<KafkaPartitionSplit> currentAssignedSplits = new HashSet<>();
+        for (Map.Entry<Integer, Set<KafkaPartitionSplit>> entry :
+                currentPartitionAssignment.entrySet()) {
+            currentAssignedSplits.addAll(entry.getValue());
+        }
+        return new KafkaSourceEnumState(currentAssignedSplits, new 
HashSet<>(), true);
+    }
+
+    @VisibleForTesting
+    static byte[] serializeV1(Collection<KafkaPartitionSplit> splits) throws 
IOException {
+        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+                DataOutputStream out = new DataOutputStream(baos)) {
+
+            out.writeInt(splits.size());
+            for (KafkaPartitionSplit split : splits) {
+                final TopicPartition tp = split.getTopicPartition();
+                out.writeUTF(tp.topic());
+                out.writeInt(tp.partition());
+            }
+            out.flush();
+
+            return baos.toByteArray();
+        }
+    }
+
+    private static KafkaSourceEnumState deserializeVersion1(byte[] 
serializedTopicPartitions)
+            throws IOException {
         try (ByteArrayInputStream bais = new 
ByteArrayInputStream(serializedTopicPartitions);
                 DataInputStream in = new DataInputStream(bais)) {
 
             final int numPartitions = in.readInt();
-            Set<TopicPartitionAndAssignmentStatus> partitions = new 
HashSet<>(numPartitions);
+            Set<SplitAndAssignmentStatus> partitions = new 
HashSet<>(numPartitions);
             for (int i = 0; i < numPartitions; i++) {
                 final String topic = in.readUTF();
                 final int partition = in.readInt();
                 partitions.add(
-                        new TopicPartitionAndAssignmentStatus(
-                                new TopicPartition(topic, partition), 
AssignmentStatus.ASSIGNED));
+                        new SplitAndAssignmentStatus(
+                                new KafkaPartitionSplit(
+                                        new TopicPartition(topic, partition), 
MIGRATED),
+                                AssignmentStatus.ASSIGNED));
             }
             if (in.available() > 0) {
                 throw new IOException("Unexpected trailing bytes in serialized 
topic partitions");
@@ -132,22 +197,42 @@ public class KafkaSourceEnumStateSerializer
         }
     }
 
-    private static KafkaSourceEnumState 
deserializeTopicPartitionAndAssignmentStatus(
-            byte[] serialized) throws IOException {
+    @VisibleForTesting
+    static byte[] serializeV2(
+            Collection<SplitAndAssignmentStatus> splits, boolean 
initialDiscoveryFinished)
+            throws IOException {
+        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+                DataOutputStream out = new DataOutputStream(baos)) {
+            out.writeInt(splits.size());
+            for (SplitAndAssignmentStatus splitAndAssignmentStatus : splits) {
+                final TopicPartition topicPartition =
+                        splitAndAssignmentStatus.split().getTopicPartition();
+                out.writeUTF(topicPartition.topic());
+                out.writeInt(topicPartition.partition());
+                
out.writeInt(splitAndAssignmentStatus.assignmentStatus().getStatusCode());
+            }
+            out.writeBoolean(initialDiscoveryFinished);
+            out.flush();
+            return baos.toByteArray();
+        }
+    }
+
+    private static KafkaSourceEnumState deserializeVersion2(byte[] serialized) 
throws IOException {
 
         try (ByteArrayInputStream bais = new ByteArrayInputStream(serialized);
                 DataInputStream in = new DataInputStream(bais)) {
 
             final int numPartitions = in.readInt();
-            Set<TopicPartitionAndAssignmentStatus> partitions = new 
HashSet<>(numPartitions);
+            Set<SplitAndAssignmentStatus> partitions = new 
HashSet<>(numPartitions);
 
             for (int i = 0; i < numPartitions; i++) {
                 final String topic = in.readUTF();
                 final int partition = in.readInt();
                 final int statusCode = in.readInt();
                 partitions.add(
-                        new TopicPartitionAndAssignmentStatus(
-                                new TopicPartition(topic, partition),
+                        new SplitAndAssignmentStatus(
+                                new KafkaPartitionSplit(
+                                        new TopicPartition(topic, partition), 
MIGRATED),
                                 AssignmentStatus.ofStatusCode(statusCode)));
             }
             final boolean initialDiscoveryFinished = in.readBoolean();
@@ -158,21 +243,4 @@ public class KafkaSourceEnumStateSerializer
             return new KafkaSourceEnumState(partitions, 
initialDiscoveryFinished);
         }
     }
-
-    @VisibleForTesting
-    public static byte[] serializeTopicPartitions(Collection<TopicPartition> 
topicPartitions)
-            throws IOException {
-        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
-                DataOutputStream out = new DataOutputStream(baos)) {
-
-            out.writeInt(topicPartitions.size());
-            for (TopicPartition tp : topicPartitions) {
-                out.writeUTF(tp.topic());
-                out.writeInt(tp.partition());
-            }
-            out.flush();
-
-            return baos.toByteArray();
-        }
-    }
 }
diff --git 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
index f3058193..e65e9a57 100644
--- 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
+++ 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
@@ -56,6 +56,9 @@ import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.apache.flink.util.Preconditions.checkState;
 
 /** The enumerator class for Kafka source. */
 @Internal
@@ -72,13 +75,12 @@ public class KafkaSourceEnumerator
     private final Boundedness boundedness;
 
     /** Partitions that have been assigned to readers. */
-    private final Set<TopicPartition> assignedPartitions;
+    private final Map<TopicPartition, KafkaPartitionSplit> assignedSplits;
 
     /**
-     * The partitions that have been discovered during initialization but not 
assigned to readers
-     * yet.
+     * The splits that have been discovered during initialization but not 
assigned to readers yet.
      */
-    private final Set<TopicPartition> unassignedInitialPartitions;
+    private final Map<TopicPartition, KafkaPartitionSplit> unassignedSplits;
 
     /**
      * The discovered and initialized partition splits that are waiting for 
owner reader to be
@@ -96,7 +98,8 @@ public class KafkaSourceEnumerator
     // initializing partition discovery has finished.
     private boolean noMoreNewPartitionSplits = false;
     // this flag will be marked as true if initial partitions are discovered 
after enumerator starts
-    private boolean initialDiscoveryFinished;
+    // the flag is read and set in main thread but also read in worker thread
+    private volatile boolean initialDiscoveryFinished;
 
     public KafkaSourceEnumerator(
             KafkaSubscriber subscriber,
@@ -131,7 +134,10 @@ public class KafkaSourceEnumerator
         this.context = context;
         this.boundedness = boundedness;
 
-        this.assignedPartitions = new 
HashSet<>(kafkaSourceEnumState.assignedPartitions());
+        Map<AssignmentStatus, List<KafkaPartitionSplit>> splits =
+                initializeMigratedSplits(kafkaSourceEnumState.splits());
+        this.assignedSplits = 
indexByPartition(splits.get(AssignmentStatus.ASSIGNED));
+        this.unassignedSplits = 
indexByPartition(splits.get(AssignmentStatus.UNASSIGNED));
         this.pendingPartitionSplitAssignment = new HashMap<>();
         this.partitionDiscoveryIntervalMs =
                 KafkaSourceOptions.getOption(
@@ -139,11 +145,73 @@ public class KafkaSourceEnumerator
                         KafkaSourceOptions.PARTITION_DISCOVERY_INTERVAL_MS,
                         Long::parseLong);
         this.consumerGroupId = 
properties.getProperty(ConsumerConfig.GROUP_ID_CONFIG);
-        this.unassignedInitialPartitions =
-                new 
HashSet<>(kafkaSourceEnumState.unassignedInitialPartitions());
         this.initialDiscoveryFinished = 
kafkaSourceEnumState.initialDiscoveryFinished();
     }
 
+    /**
+     * Initialize migrated splits to splits with concrete starting offsets. 
This method ensures that
+     * the costly offset resolution is performed only when there are splits 
that have been
+     * checkpointed with previous enumerator versions.
+     *
+     * <p>Note that this method is deliberately performed in the main thread 
to avoid a checkpoint
+     * of the splits without starting offset.
+     */
+    private Map<AssignmentStatus, List<KafkaPartitionSplit>> 
initializeMigratedSplits(
+            Set<SplitAndAssignmentStatus> splits) {
+        final Set<TopicPartition> migratedPartitions =
+                splits.stream()
+                        .filter(
+                                splitStatus ->
+                                        splitStatus.split().getStartingOffset()
+                                                == 
KafkaPartitionSplit.MIGRATED)
+                        .map(splitStatus -> 
splitStatus.split().getTopicPartition())
+                        .collect(Collectors.toSet());
+
+        if (migratedPartitions.isEmpty()) {
+            return splitByAssignmentStatus(splits.stream());
+        }
+
+        final Map<TopicPartition, Long> startOffsets =
+                startingOffsetInitializer.getPartitionOffsets(
+                        migratedPartitions, getOffsetsRetriever());
+        return splitByAssignmentStatus(
+                splits.stream()
+                        .map(splitStatus -> resolveMigratedSplit(splitStatus, 
startOffsets)));
+    }
+
+    private static Map<AssignmentStatus, List<KafkaPartitionSplit>> 
splitByAssignmentStatus(
+            Stream<SplitAndAssignmentStatus> stream) {
+        return stream.collect(
+                Collectors.groupingBy(
+                        SplitAndAssignmentStatus::assignmentStatus,
+                        Collectors.mapping(SplitAndAssignmentStatus::split, 
Collectors.toList())));
+    }
+
+    private static SplitAndAssignmentStatus resolveMigratedSplit(
+            SplitAndAssignmentStatus splitStatus, Map<TopicPartition, Long> 
startOffsets) {
+        final KafkaPartitionSplit split = splitStatus.split();
+        if (split.getStartingOffset() != KafkaPartitionSplit.MIGRATED) {
+            return splitStatus;
+        }
+        final Long startOffset = startOffsets.get(split.getTopicPartition());
+        checkState(
+                startOffset != null,
+                "Cannot find starting offset for migrated partition %s",
+                split.getTopicPartition());
+        return new SplitAndAssignmentStatus(
+                new KafkaPartitionSplit(split.getTopicPartition(), 
startOffset),
+                splitStatus.assignmentStatus());
+    }
+
+    private Map<TopicPartition, KafkaPartitionSplit> indexByPartition(
+            List<KafkaPartitionSplit> splits) {
+        if (splits == null) {
+            return new HashMap<>();
+        }
+        return splits.stream()
+                
.collect(Collectors.toMap(KafkaPartitionSplit::getTopicPartition, split -> 
split));
+    }
+
     /**
      * Start the enumerator.
      *
@@ -153,9 +221,7 @@ public class KafkaSourceEnumerator
      * <p>The invoking chain of partition discovery would be:
      *
      * <ol>
-     *   <li>{@link #getSubscribedTopicPartitions} in worker thread
-     *   <li>{@link #checkPartitionChanges} in coordinator thread
-     *   <li>{@link #initializePartitionSplits} in worker thread
+     *   <li>{@link #findNewPartitionSplits} in worker thread
      *   <li>{@link #handlePartitionSplitChanges} in coordinator thread
      * </ol>
      */
@@ -169,8 +235,8 @@ public class KafkaSourceEnumerator
                     consumerGroupId,
                     partitionDiscoveryIntervalMs);
             context.callAsync(
-                    this::getSubscribedTopicPartitions,
-                    this::checkPartitionChanges,
+                    this::findNewPartitionSplits,
+                    this::handlePartitionSplitChanges,
                     0,
                     partitionDiscoveryIntervalMs);
         } else {
@@ -178,7 +244,7 @@ public class KafkaSourceEnumerator
                     "Starting the KafkaSourceEnumerator for consumer group {} "
                             + "without periodic partition discovery.",
                     consumerGroupId);
-            context.callAsync(this::getSubscribedTopicPartitions, 
this::checkPartitionChanges);
+            context.callAsync(this::findNewPartitionSplits, 
this::handlePartitionSplitChanges);
         }
     }
 
@@ -189,6 +255,9 @@ public class KafkaSourceEnumerator
 
     @Override
     public void addSplitsBack(List<KafkaPartitionSplit> splits, int subtaskId) 
{
+        for (KafkaPartitionSplit split : splits) {
+            unassignedSplits.put(split.getTopicPartition(), split);
+        }
         addPartitionSplitChangeToPendingAssignments(splits);
 
         // If the failed subtask has already restarted, we need to assign 
pending splits to it
@@ -209,7 +278,7 @@ public class KafkaSourceEnumerator
     @Override
     public KafkaSourceEnumState snapshotState(long checkpointId) throws 
Exception {
         return new KafkaSourceEnumState(
-                assignedPartitions, unassignedInitialPartitions, 
initialDiscoveryFinished);
+                assignedSplits.values(), unassignedSplits.values(), 
initialDiscoveryFinished);
     }
 
     @Override
@@ -229,38 +298,16 @@ public class KafkaSourceEnumerator
      *
      * @return Set of subscribed {@link TopicPartition}s
      */
-    private Set<TopicPartition> getSubscribedTopicPartitions() {
-        return subscriber.getSubscribedTopicPartitions(adminClient);
-    }
-
-    /**
-     * Check if there's any partition changes within subscribed topic 
partitions fetched by worker
-     * thread, and invoke {@link 
KafkaSourceEnumerator#initializePartitionSplits(PartitionChange)}
-     * in worker thread to initialize splits for new partitions.
-     *
-     * <p>NOTE: This method should only be invoked in the coordinator executor 
thread.
-     *
-     * @param fetchedPartitions Map from topic name to its description
-     * @param t Exception in worker thread
-     */
-    private void checkPartitionChanges(Set<TopicPartition> fetchedPartitions, 
Throwable t) {
-        if (t != null) {
-            throw new FlinkRuntimeException(
-                    "Failed to list subscribed topic partitions due to ", t);
-        }
-
-        if (!initialDiscoveryFinished) {
-            unassignedInitialPartitions.addAll(fetchedPartitions);
-            initialDiscoveryFinished = true;
-        }
+    private PartitionSplitChange findNewPartitionSplits() {
+        final Set<TopicPartition> fetchedPartitions =
+                subscriber.getSubscribedTopicPartitions(adminClient);
 
         final PartitionChange partitionChange = 
getPartitionChange(fetchedPartitions);
         if (partitionChange.isEmpty()) {
-            return;
+            return null;
         }
-        context.callAsync(
-                () -> initializePartitionSplits(partitionChange),
-                this::handlePartitionSplitChanges);
+
+        return initializePartitionSplits(partitionChange);
     }
 
     /**
@@ -290,13 +337,14 @@ public class KafkaSourceEnumerator
         OffsetsInitializer.PartitionOffsetsRetriever offsetsRetriever = 
getOffsetsRetriever();
         // initial partitions use OffsetsInitializer specified by the user 
while new partitions use
         // EARLIEST
-        Map<TopicPartition, Long> startingOffsets = new HashMap<>();
-        startingOffsets.putAll(
-                newDiscoveryOffsetsInitializer.getPartitionOffsets(
-                        newPartitions, offsetsRetriever));
-        startingOffsets.putAll(
-                startingOffsetInitializer.getPartitionOffsets(
-                        unassignedInitialPartitions, offsetsRetriever));
+        final OffsetsInitializer initializer;
+        if (!initialDiscoveryFinished) {
+            initializer = startingOffsetInitializer;
+        } else {
+            initializer = newDiscoveryOffsetsInitializer;
+        }
+        Map<TopicPartition, Long> startingOffsets =
+                initializer.getPartitionOffsets(newPartitions, 
offsetsRetriever);
 
         Map<TopicPartition, Long> stoppingOffsets =
                 stoppingOffsetInitializer.getPartitionOffsets(newPartitions, 
offsetsRetriever);
@@ -322,14 +370,21 @@ public class KafkaSourceEnumerator
      * @param t Exception in worker thread
      */
     private void handlePartitionSplitChanges(
-            PartitionSplitChange partitionSplitChange, Throwable t) {
+            @Nullable PartitionSplitChange partitionSplitChange, Throwable t) {
         if (t != null) {
             throw new FlinkRuntimeException("Failed to initialize partition 
splits due to ", t);
         }
+        initialDiscoveryFinished = true;
         if (partitionDiscoveryIntervalMs <= 0) {
             LOG.debug("Partition discovery is disabled.");
             noMoreNewPartitionSplits = true;
         }
+        if (partitionSplitChange == null) {
+            return;
+        }
+        for (KafkaPartitionSplit split : 
partitionSplitChange.newPartitionSplits) {
+            unassignedSplits.put(split.getTopicPartition(), split);
+        }
         // TODO: Handle removed partitions.
         
addPartitionSplitChangeToPendingAssignments(partitionSplitChange.newPartitionSplits);
         assignPendingPartitionSplits(context.registeredReaders().keySet());
@@ -373,8 +428,8 @@ public class KafkaSourceEnumerator
                 // Mark pending partitions as already assigned
                 pendingAssignmentForReader.forEach(
                         split -> {
-                            assignedPartitions.add(split.getTopicPartition());
-                            
unassignedInitialPartitions.remove(split.getTopicPartition());
+                            assignedSplits.put(split.getTopicPartition(), 
split);
+                            unassignedSplits.remove(split.getTopicPartition());
                         });
             }
         }
@@ -414,7 +469,7 @@ public class KafkaSourceEnumerator
                     }
                 };
 
-        assignedPartitions.forEach(dedupOrMarkAsRemoved);
+        assignedSplits.keySet().forEach(dedupOrMarkAsRemoved);
         pendingPartitionSplitAssignment.forEach(
                 (reader, splits) ->
                         splits.forEach(
@@ -446,6 +501,11 @@ public class KafkaSourceEnumerator
         return new PartitionOffsetsRetrieverImpl(adminClient, groupId);
     }
 
+    @VisibleForTesting
+    Map<Integer, Set<KafkaPartitionSplit>> 
getPendingPartitionSplitAssignment() {
+        return pendingPartitionSplitAssignment;
+    }
+
     /**
      * Returns the index of the target subtask that a specific Kafka partition 
should be assigned
      * to.
diff --git 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/TopicPartitionAndAssignmentStatus.java
 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/SplitAndAssignmentStatus.java
similarity index 75%
rename from 
flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/TopicPartitionAndAssignmentStatus.java
rename to 
flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/SplitAndAssignmentStatus.java
index 2caed99b..a7763fb5 100644
--- 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/TopicPartitionAndAssignmentStatus.java
+++ 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/SplitAndAssignmentStatus.java
@@ -19,23 +19,21 @@
 package org.apache.flink.connector.kafka.source.enumerator;
 
 import org.apache.flink.annotation.Internal;
-
-import org.apache.kafka.common.TopicPartition;
+import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
 
 /** Kafka partition with assign status. */
 @Internal
-public class TopicPartitionAndAssignmentStatus {
-    private final TopicPartition topicPartition;
+public class SplitAndAssignmentStatus {
+    private final KafkaPartitionSplit split;
     private final AssignmentStatus assignmentStatus;
 
-    public TopicPartitionAndAssignmentStatus(
-            TopicPartition topicPartition, AssignmentStatus assignStatus) {
-        this.topicPartition = topicPartition;
+    public SplitAndAssignmentStatus(KafkaPartitionSplit split, 
AssignmentStatus assignStatus) {
+        this.split = split;
         this.assignmentStatus = assignStatus;
     }
 
-    public TopicPartition topicPartition() {
-        return topicPartition;
+    public KafkaPartitionSplit split() {
+        return split;
     }
 
     public AssignmentStatus assignmentStatus() {
diff --git 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
index 7c04600d..52cb3b98 100644
--- 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
+++ 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
@@ -41,10 +41,14 @@ public class KafkaPartitionSplit implements SourceSplit {
     public static final long EARLIEST_OFFSET = -2;
     // Indicating the split should consume from the last committed offset.
     public static final long COMMITTED_OFFSET = -3;
+    // Used to indicate the split has been migrated from an earlier enumerator 
state; offset needs
+    // to be initialized on recovery
+    public static final long MIGRATED = Long.MIN_VALUE;
 
     // Valid special starting offsets
     public static final Set<Long> VALID_STARTING_OFFSET_MARKERS =
-            new HashSet<>(Arrays.asList(EARLIEST_OFFSET, LATEST_OFFSET, 
COMMITTED_OFFSET));
+            new HashSet<>(
+                    Arrays.asList(EARLIEST_OFFSET, LATEST_OFFSET, 
COMMITTED_OFFSET, MIGRATED));
     public static final Set<Long> VALID_STOPPING_OFFSET_MARKERS =
             new HashSet<>(Arrays.asList(LATEST_OFFSET, COMMITTED_OFFSET, 
NO_STOPPING_OFFSET));
 
diff --git 
a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumStateSerializerTest.java
 
b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumStateSerializerTest.java
index 66caec4c..251309bc 100644
--- 
a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumStateSerializerTest.java
+++ 
b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumStateSerializerTest.java
@@ -22,7 +22,8 @@ import 
org.apache.flink.connector.kafka.dynamic.metadata.ClusterMetadata;
 import org.apache.flink.connector.kafka.dynamic.metadata.KafkaStream;
 import org.apache.flink.connector.kafka.source.enumerator.AssignmentStatus;
 import org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumState;
-import 
org.apache.flink.connector.kafka.source.enumerator.TopicPartitionAndAssignmentStatus;
+import 
org.apache.flink.connector.kafka.source.enumerator.SplitAndAssignmentStatus;
+import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
 
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
@@ -33,6 +34,8 @@ import org.junit.jupiter.api.Test;
 import java.util.Properties;
 import java.util.Set;
 
+import static 
org.apache.flink.connector.kafka.source.enumerator.AssignmentStatus.ASSIGNED;
+import static 
org.apache.flink.connector.kafka.source.enumerator.AssignmentStatus.UNASSIGNED;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /**
@@ -81,28 +84,16 @@ public class DynamicKafkaSourceEnumStateSerializerTest {
                                 "cluster0",
                                 new KafkaSourceEnumState(
                                         ImmutableSet.of(
-                                                new 
TopicPartitionAndAssignmentStatus(
-                                                        new 
TopicPartition("topic0", 0),
-                                                        
AssignmentStatus.ASSIGNED),
-                                                new 
TopicPartitionAndAssignmentStatus(
-                                                        new 
TopicPartition("topic1", 1),
-                                                        
AssignmentStatus.UNASSIGNED_INITIAL)),
+                                                getSplitAssignment("topic0", 
0, ASSIGNED),
+                                                getSplitAssignment("topic1", 
1, UNASSIGNED)),
                                         true),
                                 "cluster1",
                                 new KafkaSourceEnumState(
                                         ImmutableSet.of(
-                                                new 
TopicPartitionAndAssignmentStatus(
-                                                        new 
TopicPartition("topic2", 0),
-                                                        
AssignmentStatus.UNASSIGNED_INITIAL),
-                                                new 
TopicPartitionAndAssignmentStatus(
-                                                        new 
TopicPartition("topic3", 1),
-                                                        
AssignmentStatus.UNASSIGNED_INITIAL),
-                                                new 
TopicPartitionAndAssignmentStatus(
-                                                        new 
TopicPartition("topic4", 2),
-                                                        
AssignmentStatus.UNASSIGNED_INITIAL),
-                                                new 
TopicPartitionAndAssignmentStatus(
-                                                        new 
TopicPartition("topic5", 3),
-                                                        
AssignmentStatus.UNASSIGNED_INITIAL)),
+                                                getSplitAssignment("topic2", 
0, UNASSIGNED),
+                                                getSplitAssignment("topic3", 
1, UNASSIGNED),
+                                                getSplitAssignment("topic4", 
2, UNASSIGNED),
+                                                getSplitAssignment("topic5", 
3, UNASSIGNED)),
                                         false)));
 
         DynamicKafkaSourceEnumState dynamicKafkaSourceEnumStateAfterSerde =
@@ -115,4 +106,10 @@ public class DynamicKafkaSourceEnumStateSerializerTest {
                 .usingRecursiveComparison()
                 .isEqualTo(dynamicKafkaSourceEnumStateAfterSerde);
     }
+
+    private static SplitAndAssignmentStatus getSplitAssignment(
+            String topic, int partition, AssignmentStatus assignStatus) {
+        return new SplitAndAssignmentStatus(
+                new KafkaPartitionSplit(new TopicPartition(topic, partition), 
0), assignStatus);
+    }
 }
diff --git 
a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumeratorTest.java
 
b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumeratorTest.java
index 86133345..f974b6ff 100644
--- 
a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumeratorTest.java
+++ 
b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/dynamic/source/enumerator/DynamicKafkaSourceEnumeratorTest.java
@@ -33,7 +33,6 @@ import 
org.apache.flink.connector.kafka.dynamic.source.enumerator.subscriber.Kaf
 import 
org.apache.flink.connector.kafka.dynamic.source.split.DynamicKafkaSourceSplit;
 import org.apache.flink.connector.kafka.source.KafkaSourceOptions;
 import org.apache.flink.connector.kafka.source.enumerator.AssignmentStatus;
-import 
org.apache.flink.connector.kafka.source.enumerator.TopicPartitionAndAssignmentStatus;
 import 
org.apache.flink.connector.kafka.source.enumerator.initializer.NoStoppingOffsetsInitializer;
 import 
org.apache.flink.connector.kafka.source.enumerator.initializer.OffsetsInitializer;
 import org.apache.flink.connector.kafka.testutils.MockKafkaMetadataService;
@@ -429,7 +428,7 @@ public class DynamicKafkaSourceEnumeratorTest {
             assertThat(
                             
stateBeforeSplitAssignment.getClusterEnumeratorStates().values()
                                     .stream()
-                                    .map(subState -> 
subState.assignedPartitions().stream())
+                                    .map(subState -> 
subState.assignedSplits().stream())
                                     .count())
                     .as("no readers registered, so state should be empty")
                     .isZero();
@@ -458,7 +457,7 @@ public class DynamicKafkaSourceEnumeratorTest {
 
             assertThat(
                             
stateAfterSplitAssignment.getClusterEnumeratorStates().values().stream()
-                                    .flatMap(enumState -> 
enumState.assignedPartitions().stream())
+                                    .flatMap(enumState -> 
enumState.assignedSplits().stream())
                                     .count())
                     .isEqualTo(
                             NUM_SPLITS_PER_CLUSTER
@@ -514,15 +513,13 @@ public class DynamicKafkaSourceEnumeratorTest {
 
             assertThat(getFilteredTopicPartitions(initialState, TOPIC, AssignmentStatus.ASSIGNED))
                     .hasSize(2);
-            assertThat(
-                            getFilteredTopicPartitions(
-                                    initialState, TOPIC, AssignmentStatus.UNASSIGNED_INITIAL))
+            assertThat(getFilteredTopicPartitions(initialState, TOPIC, AssignmentStatus.UNASSIGNED))
                     .hasSize(1);
             assertThat(getFilteredTopicPartitions(initialState, topic2, AssignmentStatus.ASSIGNED))
                     .hasSize(2);
             assertThat(
                             getFilteredTopicPartitions(
-                                    initialState, topic2, AssignmentStatus.UNASSIGNED_INITIAL))
+                                    initialState, topic2, AssignmentStatus.UNASSIGNED))
                     .hasSize(1);
 
             // mock metadata change
@@ -540,13 +537,13 @@ public class DynamicKafkaSourceEnumeratorTest {
                     .hasSize(3);
             assertThat(
                             getFilteredTopicPartitions(
-                                    migratedState, TOPIC, AssignmentStatus.UNASSIGNED_INITIAL))
+                                    migratedState, TOPIC, AssignmentStatus.UNASSIGNED))
                     .isEmpty();
             assertThat(getFilteredTopicPartitions(migratedState, topic2, AssignmentStatus.ASSIGNED))
                     .isEmpty();
             assertThat(
                             getFilteredTopicPartitions(
-                                    migratedState, topic2, AssignmentStatus.UNASSIGNED_INITIAL))
+                                    migratedState, topic2, AssignmentStatus.UNASSIGNED))
                     .isEmpty();
         }
     }
@@ -955,12 +952,14 @@ public class DynamicKafkaSourceEnumeratorTest {
     private List<TopicPartition> getFilteredTopicPartitions(
             DynamicKafkaSourceEnumState state, String topic, AssignmentStatus assignmentStatus) {
         return state.getClusterEnumeratorStates().values().stream()
-                .flatMap(s -> s.partitions().stream())
+                .flatMap(s -> s.splits().stream())
                 .filter(
                         partition ->
-                                partition.topicPartition().topic().equals(topic)
+                                partition.split().getTopic().equals(topic)
                                         && partition.assignmentStatus() == assignmentStatus)
-                .map(TopicPartitionAndAssignmentStatus::topicPartition)
+                .map(
+                        splitAndAssignmentStatus ->
+                                splitAndAssignmentStatus.split().getTopicPartition())
                 .collect(Collectors.toList());
     }
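
The assertions above also show the accessor migration this change implies for callers: per-cluster enumerator state now exposes splits (with their starting offsets) instead of bare partitions, and the topic partition is recovered from the split where that is all a caller needs. A small sketch of that adaptation, assuming assignedSplits() returns a collection of KafkaPartitionSplit as the tests suggest; the helper class name is illustrative.

    import java.util.List;
    import java.util.stream.Collectors;

    import org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumState;
    import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
    import org.apache.kafka.common.TopicPartition;

    // Illustrative helper: recovers the old partition-only view from the split-based state.
    final class EnumStatePartitions {
        static List<TopicPartition> assignedTopicPartitions(KafkaSourceEnumState state) {
            return state.assignedSplits().stream()
                    .map(KafkaPartitionSplit::getTopicPartition)
                    .collect(Collectors.toList());
        }
    }
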
 
diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java
index 6c172e4a..5207687f 100644
--- a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java
+++ b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java
@@ -29,8 +29,10 @@ import java.io.IOException;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -46,8 +48,8 @@ public class KafkaSourceEnumStateSerializerTest {
     public void testEnumStateSerde() throws IOException {
         final KafkaSourceEnumState state =
                 new KafkaSourceEnumState(
-                        constructTopicPartitions(0),
-                        constructTopicPartitions(NUM_PARTITIONS_PER_TOPIC),
+                        constructTopicSplits(0),
+                        constructTopicSplits(NUM_PARTITIONS_PER_TOPIC),
                         true);
         final KafkaSourceEnumStateSerializer serializer = new 
KafkaSourceEnumStateSerializer();
 
@@ -56,26 +58,35 @@ public class KafkaSourceEnumStateSerializerTest {
         final KafkaSourceEnumState restoredState =
                 serializer.deserialize(serializer.getVersion(), bytes);
 
-        
assertThat(restoredState.assignedPartitions()).isEqualTo(state.assignedPartitions());
-        assertThat(restoredState.unassignedInitialPartitions())
-                .isEqualTo(state.unassignedInitialPartitions());
+        assertThat(restoredState.assignedSplits())
+                .containsExactlyInAnyOrderElementsOf(state.assignedSplits());
+        assertThat(restoredState.unassignedSplits())
+                .containsExactlyInAnyOrderElementsOf(state.unassignedSplits());
         assertThat(restoredState.initialDiscoveryFinished()).isTrue();
     }
 
     @Test
     public void testBackwardCompatibility() throws IOException {
 
-        final Set<TopicPartition> topicPartitions = 
constructTopicPartitions(0);
-        final Map<Integer, Set<KafkaPartitionSplit>> splitAssignments =
-                toSplitAssignments(topicPartitions);
+        final Set<KafkaPartitionSplit> splits = constructTopicSplits(0);
+        final Map<Integer, Collection<KafkaPartitionSplit>> splitAssignments =
+                toSplitAssignments(splits);
+        final List<SplitAndAssignmentStatus> splitAndAssignmentStatuses =
+                splits.stream()
+                        .map(
+                                split ->
+                                        new SplitAndAssignmentStatus(
+                                                split, 
getAssignmentStatus(split)))
+                        .collect(Collectors.toList());
 
         // Create bytes in the way of KafkaEnumStateSerializer version 0 doing 
serialization
         final byte[] bytesV0 =
                 SerdeUtils.serializeSplitAssignments(
                         splitAssignments, new KafkaPartitionSplitSerializer());
         // Create bytes in the way of KafkaEnumStateSerializer version 1 doing 
serialization
-        final byte[] bytesV1 =
-                KafkaSourceEnumStateSerializer.serializeTopicPartitions(topicPartitions);
+        final byte[] bytesV1 = KafkaSourceEnumStateSerializer.serializeV1(splits);
+        final byte[] bytesV2 =
+                KafkaSourceEnumStateSerializer.serializeV2(splitAndAssignmentStatuses, false);
 
         // Deserialize above bytes with KafkaEnumStateSerializer version 2 to 
check backward
         // compatibility
@@ -83,46 +94,72 @@ public class KafkaSourceEnumStateSerializerTest {
                 new KafkaSourceEnumStateSerializer().deserialize(0, bytesV0);
         final KafkaSourceEnumState kafkaSourceEnumStateV1 =
                 new KafkaSourceEnumStateSerializer().deserialize(1, bytesV1);
+        final KafkaSourceEnumState kafkaSourceEnumStateV2 =
+                new KafkaSourceEnumStateSerializer().deserialize(2, bytesV2);
 
-        
assertThat(kafkaSourceEnumStateV0.assignedPartitions()).isEqualTo(topicPartitions);
-        
assertThat(kafkaSourceEnumStateV0.unassignedInitialPartitions()).isEmpty();
+        assertThat(kafkaSourceEnumStateV0.assignedSplits())
+                .containsExactlyInAnyOrderElementsOf(splits);
+        assertThat(kafkaSourceEnumStateV0.unassignedSplits()).isEmpty();
         assertThat(kafkaSourceEnumStateV0.initialDiscoveryFinished()).isTrue();
 
-        
assertThat(kafkaSourceEnumStateV1.assignedPartitions()).isEqualTo(topicPartitions);
-        
assertThat(kafkaSourceEnumStateV1.unassignedInitialPartitions()).isEmpty();
+        assertThat(kafkaSourceEnumStateV1.assignedSplits())
+                .containsExactlyInAnyOrderElementsOf(splits);
+        assertThat(kafkaSourceEnumStateV1.unassignedSplits()).isEmpty();
         assertThat(kafkaSourceEnumStateV1.initialDiscoveryFinished()).isTrue();
+
+        final Map<AssignmentStatus, Set<KafkaPartitionSplit>> splitsByStatus =
+                splitAndAssignmentStatuses.stream()
+                        .collect(
+                                Collectors.groupingBy(
+                                        
SplitAndAssignmentStatus::assignmentStatus,
+                                        Collectors.mapping(
+                                                
SplitAndAssignmentStatus::split,
+                                                Collectors.toSet())));
+        assertThat(kafkaSourceEnumStateV2.assignedSplits())
+                
.containsExactlyInAnyOrderElementsOf(splitsByStatus.get(AssignmentStatus.ASSIGNED));
+        assertThat(kafkaSourceEnumStateV2.unassignedSplits())
+                .containsExactlyInAnyOrderElementsOf(
+                        splitsByStatus.get(AssignmentStatus.UNASSIGNED));
+        
assertThat(kafkaSourceEnumStateV2.initialDiscoveryFinished()).isFalse();
+    }
+
+    private static AssignmentStatus getAssignmentStatus(KafkaPartitionSplit split) {
+        return AssignmentStatus.values()[
+                Math.abs(split.hashCode()) % AssignmentStatus.values().length];
     }
 
-    private Set<TopicPartition> constructTopicPartitions(int startPartition) {
+    private Set<KafkaPartitionSplit> constructTopicSplits(int startPartition) {
         // Create topic partitions for readers.
         // Reader i will be assigned with NUM_PARTITIONS_PER_TOPIC splits, 
with topic name
         // "topic-{i}" and
         // NUM_PARTITIONS_PER_TOPIC partitions. The starting partition number 
is startPartition
         // Totally NUM_READERS * NUM_PARTITIONS_PER_TOPIC partitions will be 
created.
-        Set<TopicPartition> topicPartitions = new HashSet<>();
+        Set<KafkaPartitionSplit> topicPartitions = new HashSet<>();
         for (int readerId = 0; readerId < NUM_READERS; readerId++) {
             for (int partition = startPartition;
                     partition < startPartition + NUM_PARTITIONS_PER_TOPIC;
                     partition++) {
-                topicPartitions.add(new TopicPartition(TOPIC_PREFIX + 
readerId, partition));
+                topicPartitions.add(
+                        new KafkaPartitionSplit(
+                                new TopicPartition(TOPIC_PREFIX + readerId, 
partition),
+                                KafkaPartitionSplit.MIGRATED));
             }
         }
         return topicPartitions;
     }
 
-    private Map<Integer, Set<KafkaPartitionSplit>> toSplitAssignments(
-            Collection<TopicPartition> topicPartitions) {
+    private Map<Integer, Collection<KafkaPartitionSplit>> toSplitAssignments(
+            Collection<KafkaPartitionSplit> splits) {
         // Assign splits to readers according to topic name. For example, 
topic "topic-5" will be
         // assigned to reader with ID=5
-        Map<Integer, Set<KafkaPartitionSplit>> splitAssignments = new 
HashMap<>();
-        topicPartitions.forEach(
-                (tp) ->
-                        splitAssignments
-                                .computeIfAbsent(
-                                        Integer.valueOf(
-                                                
tp.topic().substring(TOPIC_PREFIX.length())),
-                                        HashSet::new)
-                                .add(new KafkaPartitionSplit(tp, 
STARTING_OFFSET)));
+        Map<Integer, Collection<KafkaPartitionSplit>> splitAssignments = new 
HashMap<>();
+        for (KafkaPartitionSplit split : splits) {
+            splitAssignments
+                    .computeIfAbsent(
+                            
Integer.valueOf(split.getTopic().substring(TOPIC_PREFIX.length())),
+                            HashSet::new)
+                    .add(split);
+        }
         return splitAssignments;
     }
 }
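
As a usage-level summary of the compatibility contract exercised above: bytes written by serializer versions 0 and 1 come back with every split reported as assigned and initialDiscoveryFinished() set to true, while the current version round-trips the assignment status and the discovery flag. A short sketch of the round trip, assuming only the serialize/deserialize/getVersion methods already used in this test; the helper class is illustrative.

    import java.io.IOException;

    import org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumState;
    import org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer;

    final class EnumStateRoundTrip {
        // Serializes with the current version and reads the bytes back.
        static KafkaSourceEnumState roundTrip(KafkaSourceEnumState state) throws IOException {
            KafkaSourceEnumStateSerializer serializer = new KafkaSourceEnumStateSerializer();
            byte[] bytes = serializer.serialize(state);
            // For bytes from an old checkpoint, the stored version (0 or 1) is passed instead of
            // serializer.getVersion(); those splits are then restored as assigned.
            return serializer.deserialize(serializer.getVersion(), bytes);
        }
    }
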
diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaEnumeratorTest.java b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumeratorTest.java
similarity index 71%
rename from flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaEnumeratorTest.java
rename to flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumeratorTest.java
index 8b308af1..3e64e62c 100644
--- a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaEnumeratorTest.java
+++ b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumeratorTest.java
@@ -30,14 +30,20 @@ import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
 import org.apache.flink.connector.kafka.testutils.KafkaSourceTestEnv;
 import org.apache.flink.mock.Whitebox;
 
+import com.google.common.collect.Iterables;
 import org.apache.kafka.clients.admin.AdminClient;
 import org.apache.kafka.clients.admin.NewTopic;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.StringDeserializer;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
+import org.assertj.core.api.SoftAssertions;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -50,13 +56,15 @@ import java.util.Map;
 import java.util.Properties;
 import java.util.Set;
 import java.util.StringJoiner;
+import java.util.concurrent.TimeUnit;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
+import static org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit.MIGRATED;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Unit tests for {@link KafkaSourceEnumerator}. */
-public class KafkaEnumeratorTest {
+public class KafkaSourceEnumeratorTest {
     private static final int NUM_SUBTASKS = 3;
     private static final String DYNAMIC_TOPIC_NAME = "dynamic_topic";
     private static final int NUM_PARTITIONS_DYNAMIC_TOPIC = 4;
@@ -74,15 +82,30 @@ public class KafkaEnumeratorTest {
     private static final boolean DISABLE_PERIODIC_PARTITION_DISCOVERY = false;
     private static final boolean INCLUDE_DYNAMIC_TOPIC = true;
     private static final boolean EXCLUDE_DYNAMIC_TOPIC = false;
+    private static KafkaSourceEnumerator.PartitionOffsetsRetrieverImpl retriever;
+    private static final Map<TopicPartition, Long> specificOffsets = new HashMap<>();
 
-    @BeforeClass
+    @BeforeAll
     public static void setup() throws Throwable {
         KafkaSourceTestEnv.setup();
+        retriever =
+                new KafkaSourceEnumerator.PartitionOffsetsRetrieverImpl(
+                        KafkaSourceTestEnv.getAdminClient(), 
KafkaSourceTestEnv.GROUP_ID);
         KafkaSourceTestEnv.setupTopic(TOPIC1, true, true, 
KafkaSourceTestEnv::getRecordsForTopic);
         KafkaSourceTestEnv.setupTopic(TOPIC2, true, true, 
KafkaSourceTestEnv::getRecordsForTopic);
+
+        for (Map.Entry<TopicPartition, Long> partitionEnd :
+                retriever
+                        
.endOffsets(KafkaSourceTestEnv.getPartitionsForTopics(PRE_EXISTING_TOPICS))
+                        .entrySet()) {
+            specificOffsets.put(
+                    partitionEnd.getKey(),
+                    partitionEnd.getValue() / 
(partitionEnd.getKey().partition() + 1));
+        }
+        assertThat(specificOffsets).hasSize(2 * 
KafkaSourceTestEnv.NUM_PARTITIONS);
     }
 
-    @AfterClass
+    @AfterAll
     public static void tearDown() throws Exception {
         KafkaSourceTestEnv.tearDown();
     }
@@ -249,7 +272,8 @@ public class KafkaEnumeratorTest {
         }
     }
 
-    @Test(timeout = 30000L)
+    @Test
+    @Timeout(value = 30, unit = TimeUnit.SECONDS)
     public void testDiscoverPartitionsPeriodically() throws Throwable {
         try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
                         new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
@@ -261,7 +285,7 @@ public class KafkaEnumeratorTest {
                                 OffsetsInitializer.latest());
                 AdminClient adminClient = KafkaSourceTestEnv.getAdminClient()) 
{
 
-            startEnumeratorAndRegisterReaders(context, enumerator);
+            startEnumeratorAndRegisterReaders(context, enumerator, 
OffsetsInitializer.latest());
 
             // invoke partition discovery callable again and there should be 
no new assignments.
             runPeriodicPartitionDiscovery(context);
@@ -289,11 +313,13 @@ public class KafkaEnumeratorTest {
                     break;
                 }
             }
+            // partitions discovered later are initialized with EARLIEST
             verifyLastReadersAssignments(
                     context,
                     Arrays.asList(READER0, READER1),
                     Collections.singleton(DYNAMIC_TOPIC_NAME),
-                    3);
+                    3,
+                    OffsetsInitializer.earliest());
 
             // new partitions use EARLIEST_OFFSET, while initial partitions 
use LATEST_OFFSET
             List<KafkaPartitionSplit> initialPartitionAssign =
@@ -316,39 +342,123 @@ public class KafkaEnumeratorTest {
         }
     }
 
+    /**
+     * Ensures that migrated splits are immediately initialized with {@link OffsetsInitializer},
+     * such that an early {@link
+     * org.apache.flink.api.connector.source.SplitEnumerator#snapshotState(long)} doesn't see the
+     * special value.
+     */
     @Test
-    public void testAddSplitsBack() throws Throwable {
+    public void shouldEagerlyInitializeSplitOffsetsOnMigration() throws Throwable {
+        final TopicPartition assigned1 = new TopicPartition(TOPIC1, 0);
+        final TopicPartition assigned2 = new TopicPartition(TOPIC1, 1);
+        final TopicPartition unassigned1 = new TopicPartition(TOPIC2, 0);
+        final TopicPartition unassigned2 = new TopicPartition(TOPIC2, 1);
+
+        final long migratedOffset1 = 11L;
+        final long migratedOffset2 = 22L;
+        final OffsetsInitializer offsetsInitializer =
+                new OffsetsInitializer() {
+                    @Override
+                    public Map<TopicPartition, Long> getPartitionOffsets(
+                            Collection<TopicPartition> partitions,
+                            PartitionOffsetsRetriever 
partitionOffsetsRetriever) {
+                        return Map.of(assigned1, migratedOffset1, unassigned2, 
migratedOffset2);
+                    }
+
+                    @Override
+                    public OffsetResetStrategy getAutoOffsetResetStrategy() {
+                        return null;
+                    }
+                };
         try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
                         new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
                 KafkaSourceEnumerator enumerator =
-                        createEnumerator(context, 
ENABLE_PERIODIC_PARTITION_DISCOVERY)) {
+                        createEnumerator(
+                                context,
+                                offsetsInitializer,
+                                PRE_EXISTING_TOPICS,
+                                List.of(
+                                        new KafkaPartitionSplit(assigned1, 
MIGRATED),
+                                        new KafkaPartitionSplit(assigned2, 2)),
+                                List.of(
+                                        new KafkaPartitionSplit(unassigned1, 
1),
+                                        new KafkaPartitionSplit(unassigned2, 
MIGRATED)),
+                                false,
+                                new Properties())) {
+            final KafkaSourceEnumState state = enumerator.snapshotState(1L);
+            assertThat(state.assignedSplits())
+                    .containsExactlyInAnyOrder(
+                            new KafkaPartitionSplit(assigned1, 
migratedOffset1),
+                            new KafkaPartitionSplit(assigned2, 2));
+            assertThat(state.unassignedSplits())
+                    .containsExactlyInAnyOrder(
+                            new KafkaPartitionSplit(unassigned1, 1),
+                            new KafkaPartitionSplit(unassigned2, 
migratedOffset2));
+        }
+    }
 
-            startEnumeratorAndRegisterReaders(context, enumerator);
+    @ParameterizedTest
+    @EnumSource(StandardOffsetsInitializer.class)
+    public void testAddSplitsBack(StandardOffsetsInitializer 
offsetsInitializer) throws Throwable {
+        try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
+                        new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
+                KafkaSourceEnumerator enumerator =
+                        createEnumerator(
+                                context,
+                                ENABLE_PERIODIC_PARTITION_DISCOVERY,
+                                true,
+                                offsetsInitializer.getOffsetsInitializer())) {
+
+            startEnumeratorAndRegisterReaders(
+                    context, enumerator, 
offsetsInitializer.getOffsetsInitializer());
+
+            // READER2 not yet assigned
+            final Set<KafkaPartitionSplit> unassignedSplits =
+                    
enumerator.getPendingPartitionSplitAssignment().get(READER2);
+            assertThat(enumerator.snapshotState(1L).unassignedSplits())
+                    .containsExactlyInAnyOrderElementsOf(unassignedSplits);
 
             // Simulate a reader failure.
             context.unregisterReader(READER0);
-            enumerator.addSplitsBack(
-                    
context.getSplitsAssignmentSequence().get(0).assignment().get(READER0),
-                    READER0);
+            final List<KafkaPartitionSplit> assignedSplits =
+                    
context.getSplitsAssignmentSequence().get(0).assignment().get(READER0);
+            final List<KafkaPartitionSplit> advancedSplits =
+                    assignedSplits.stream()
+                            .map(
+                                    split ->
+                                            new KafkaPartitionSplit(
+                                                    split.getTopicPartition(),
+                                                    split.getStartingOffset() 
+ 1))
+                            .collect(Collectors.toList());
+            enumerator.addSplitsBack(advancedSplits, READER0);
             assertThat(context.getSplitsAssignmentSequence())
                     .as("The added back splits should have not been assigned")
                     .hasSize(2);
 
+            assertThat(enumerator.snapshotState(2L).unassignedSplits())
+                    .containsExactlyInAnyOrderElementsOf(
+                            Iterables.concat(
+                                    advancedSplits, unassignedSplits)); // 
READER0 + READER2
+
             // Simulate a reader recovery.
             registerReader(context, enumerator, READER0);
-            verifyLastReadersAssignments(
-                    context, Collections.singleton(READER0), 
PRE_EXISTING_TOPICS, 3);
+            verifyAssignments(
+                    Map.of(READER0, advancedSplits),
+                    context.getSplitsAssignmentSequence().get(2).assignment());
+            assertThat(enumerator.snapshotState(3L).unassignedSplits())
+                    .containsExactlyInAnyOrderElementsOf(unassignedSplits);
         }
     }
 
     @Test
     public void testWorkWithPreexistingAssignments() throws Throwable {
-        Set<TopicPartition> preexistingAssignments;
+        Collection<KafkaPartitionSplit> preexistingAssignments;
         try (MockSplitEnumeratorContext<KafkaPartitionSplit> context1 =
                         new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
                 KafkaSourceEnumerator enumerator =
                         createEnumerator(context1, 
ENABLE_PERIODIC_PARTITION_DISCOVERY)) {
-            startEnumeratorAndRegisterReaders(context1, enumerator);
+            startEnumeratorAndRegisterReaders(context1, enumerator, 
OffsetsInitializer.earliest());
             preexistingAssignments =
                     
asEnumState(context1.getSplitsAssignmentSequence().get(0).assignment());
         }
@@ -409,56 +519,50 @@ public class KafkaEnumeratorTest {
         }
     }
 
-    @Test
-    public void testSnapshotState() throws Throwable {
+    @ParameterizedTest
+    @EnumSource(StandardOffsetsInitializer.class)
+    public void testSnapshotState(StandardOffsetsInitializer 
offsetsInitializer) throws Throwable {
         try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
                         new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
-                KafkaSourceEnumerator enumerator = createEnumerator(context, 
false)) {
+                KafkaSourceEnumerator enumerator =
+                        createEnumerator(
+                                context, false, true, 
offsetsInitializer.getOffsetsInitializer())) {
             enumerator.start();
 
             // Step1: Before first discovery, so the state should be empty
             final KafkaSourceEnumState state1 = enumerator.snapshotState(1L);
-            assertThat(state1.assignedPartitions()).isEmpty();
-            assertThat(state1.unassignedInitialPartitions()).isEmpty();
+            assertThat(state1.assignedSplits()).isEmpty();
+            assertThat(state1.unassignedSplits()).isEmpty();
             assertThat(state1.initialDiscoveryFinished()).isFalse();
 
             registerReader(context, enumerator, READER0);
             registerReader(context, enumerator, READER1);
 
-            // Step2: First partition discovery after start, but no 
assignments to readers
-            context.runNextOneTimeCallable();
-            final KafkaSourceEnumState state2 = enumerator.snapshotState(2L);
-            assertThat(state2.assignedPartitions()).isEmpty();
-            assertThat(state2.unassignedInitialPartitions()).isNotEmpty();
-            assertThat(state2.initialDiscoveryFinished()).isTrue();
-
-            // Step3: Assign partials partitions to reader0 and reader1
+            // Step2: Assign partial partitions to reader0 and reader1
             context.runNextOneTimeCallable();
 
             // The state should contain splits assigned to READER0 and 
READER1, but no READER2
             // register.
             // Thus, both assignedPartitions and unassignedInitialPartitions 
are not empty.
-            final KafkaSourceEnumState state3 = enumerator.snapshotState(3L);
+            final KafkaSourceEnumState state2 = enumerator.snapshotState(2L);
             verifySplitAssignmentWithPartitions(
                     getExpectedAssignments(
-                            new HashSet<>(Arrays.asList(READER0, READER1)), 
PRE_EXISTING_TOPICS),
-                    state3.assignedPartitions());
-            assertThat(state3.unassignedInitialPartitions()).isNotEmpty();
-            assertThat(state3.initialDiscoveryFinished()).isTrue();
-            // total partitions of state2 and state3  are equal
-            // state2 only includes unassignedInitialPartitions
-            // state3 includes unassignedInitialPartitions + assignedPartitions
-            Set<TopicPartition> allPartitionOfState3 = new HashSet<>();
-            allPartitionOfState3.addAll(state3.unassignedInitialPartitions());
-            allPartitionOfState3.addAll(state3.assignedPartitions());
-            
assertThat(state2.unassignedInitialPartitions()).isEqualTo(allPartitionOfState3);
-
-            // Step4: register READER2, then all partitions are assigned
+                            new HashSet<>(Arrays.asList(READER0, READER1)),
+                            PRE_EXISTING_TOPICS,
+                            offsetsInitializer.getOffsetsInitializer()),
+                    state2.assignedSplits());
+            assertThat(state2.assignedSplits()).isNotEmpty();
+            assertThat(state2.unassignedSplits()).isNotEmpty();
+            assertThat(state2.initialDiscoveryFinished()).isTrue();
+
+            // Step3: register READER2, then all partitions are assigned
             registerReader(context, enumerator, READER2);
-            final KafkaSourceEnumState state4 = enumerator.snapshotState(4L);
-            
assertThat(state4.assignedPartitions()).isEqualTo(allPartitionOfState3);
-            assertThat(state4.unassignedInitialPartitions()).isEmpty();
-            assertThat(state4.initialDiscoveryFinished()).isTrue();
+            final KafkaSourceEnumState state3 = enumerator.snapshotState(3L);
+            assertThat(state3.assignedSplits())
+                    .containsExactlyInAnyOrderElementsOf(
+                            Iterables.concat(state2.assignedSplits(), 
state2.unassignedSplits()));
+            assertThat(state3.unassignedSplits()).isEmpty();
+            assertThat(state3.initialDiscoveryFinished()).isTrue();
         }
     }
 
@@ -530,7 +634,8 @@ public class KafkaEnumeratorTest {
 
     private void startEnumeratorAndRegisterReaders(
             MockSplitEnumeratorContext<KafkaPartitionSplit> context,
-            KafkaSourceEnumerator enumerator)
+            KafkaSourceEnumerator enumerator,
+            OffsetsInitializer offsetsInitializer)
             throws Throwable {
         // Start the enumerator and it should schedule a one time task to 
discover and assign
         // partitions.
@@ -543,12 +648,20 @@ public class KafkaEnumeratorTest {
         // Run the partition discover callable and check the partition 
assignment.
         runPeriodicPartitionDiscovery(context);
         verifyLastReadersAssignments(
-                context, Collections.singleton(READER0), PRE_EXISTING_TOPICS, 
1);
+                context,
+                Collections.singleton(READER0),
+                PRE_EXISTING_TOPICS,
+                1,
+                offsetsInitializer);
 
         // Register reader 1 after first partition discovery.
         registerReader(context, enumerator, READER1);
         verifyLastReadersAssignments(
-                context, Collections.singleton(READER1), PRE_EXISTING_TOPICS, 
2);
+                context,
+                Collections.singleton(READER1),
+                PRE_EXISTING_TOPICS,
+                2,
+                offsetsInitializer);
     }
 
     // ----------------------------------------
@@ -619,8 +732,8 @@ public class KafkaEnumeratorTest {
             MockSplitEnumeratorContext<KafkaPartitionSplit> enumContext,
             OffsetsInitializer startingOffsetsInitializer,
             Collection<String> topicsToSubscribe,
-            Set<TopicPartition> assignedPartitions,
-            Set<TopicPartition> unassignedInitialPartitions,
+            Collection<KafkaPartitionSplit> assignedSplits,
+            Collection<KafkaPartitionSplit> unassignedInitialSplits,
             boolean initialDiscoveryFinished,
             Properties overrideProperties) {
         // Use a TopicPatternSubscriber so that no exception if a subscribed 
topic hasn't been
@@ -644,11 +757,29 @@ public class KafkaEnumeratorTest {
                 enumContext,
                 Boundedness.CONTINUOUS_UNBOUNDED,
                 new KafkaSourceEnumState(
-                        assignedPartitions, unassignedInitialPartitions, 
initialDiscoveryFinished));
+                        assignedSplits, unassignedInitialSplits, 
initialDiscoveryFinished));
     }
 
     // ---------------------
 
+    /** The standard {@link OffsetsInitializer}s used for parameterized tests. */
+    enum StandardOffsetsInitializer {
+        EARLIEST_OFFSETS(OffsetsInitializer.earliest()),
+        LATEST_OFFSETS(OffsetsInitializer.latest()),
+        SPECIFIC_OFFSETS(OffsetsInitializer.offsets(specificOffsets, 
OffsetResetStrategy.NONE)),
+        COMMITTED_OFFSETS(OffsetsInitializer.committedOffsets());
+
+        private final OffsetsInitializer offsetsInitializer;
+
+        StandardOffsetsInitializer(OffsetsInitializer offsetsInitializer) {
+            this.offsetsInitializer = offsetsInitializer;
+        }
+
+        public OffsetsInitializer getOffsetsInitializer() {
+            return offsetsInitializer;
+        }
+    }
+
     private void registerReader(
             MockSplitEnumeratorContext<KafkaPartitionSplit> context,
             KafkaSourceEnumerator enumerator,
@@ -662,63 +793,84 @@ public class KafkaEnumeratorTest {
             Collection<Integer> readers,
             Set<String> topics,
             int expectedAssignmentSeqSize) {
+        verifyLastReadersAssignments(
+                context, readers, topics, expectedAssignmentSeqSize, 
OffsetsInitializer.earliest());
+    }
+
+    private void verifyLastReadersAssignments(
+            MockSplitEnumeratorContext<KafkaPartitionSplit> context,
+            Collection<Integer> readers,
+            Set<String> topics,
+            int expectedAssignmentSeqSize,
+            OffsetsInitializer offsetsInitializer) {
         verifyAssignments(
-                getExpectedAssignments(new HashSet<>(readers), topics),
+                getExpectedAssignments(new HashSet<>(readers), topics, 
offsetsInitializer),
                 context.getSplitsAssignmentSequence()
                         .get(expectedAssignmentSeqSize - 1)
                         .assignment());
     }
 
     private void verifyAssignments(
-            Map<Integer, Set<TopicPartition>> expectedAssignments,
+            Map<Integer, Collection<KafkaPartitionSplit>> expectedAssignments,
             Map<Integer, List<KafkaPartitionSplit>> actualAssignments) {
-        actualAssignments.forEach(
-                (reader, splits) -> {
-                    Set<TopicPartition> expectedAssignmentsForReader =
-                            expectedAssignments.get(reader);
-                    assertThat(expectedAssignmentsForReader).isNotNull();
-                    
assertThat(splits.size()).isEqualTo(expectedAssignmentsForReader.size());
-                    for (KafkaPartitionSplit split : splits) {
-                        assertThat(expectedAssignmentsForReader)
-                                .contains(split.getTopicPartition());
+        
assertThat(actualAssignments).containsOnlyKeys(expectedAssignments.keySet());
+        SoftAssertions.assertSoftly(
+                softly -> {
+                    for (Map.Entry<Integer, List<KafkaPartitionSplit>> actual :
+                            actualAssignments.entrySet()) {
+                        softly.assertThat(actual.getValue())
+                                .as("Assignment for reader %s", 
actual.getKey())
+                                .containsExactlyInAnyOrderElementsOf(
+                                        
expectedAssignments.get(actual.getKey()));
                     }
                 });
     }
 
-    private Map<Integer, Set<TopicPartition>> getExpectedAssignments(
-            Set<Integer> readers, Set<String> topics) {
-        Map<Integer, Set<TopicPartition>> expectedAssignments = new 
HashMap<>();
-        Set<TopicPartition> allPartitions = new HashSet<>();
+    private Map<Integer, Collection<KafkaPartitionSplit>> 
getExpectedAssignments(
+            Set<Integer> readers,
+            Set<String> topics,
+            OffsetsInitializer startingOffsetsInitializer) {
+        Map<Integer, Collection<KafkaPartitionSplit>> expectedAssignments = 
new HashMap<>();
+        Set<KafkaPartitionSplit> allPartitions = new HashSet<>();
 
         if (topics.contains(DYNAMIC_TOPIC_NAME)) {
             for (int i = 0; i < NUM_PARTITIONS_DYNAMIC_TOPIC; i++) {
-                allPartitions.add(new TopicPartition(DYNAMIC_TOPIC_NAME, i));
+                TopicPartition tp = new TopicPartition(DYNAMIC_TOPIC_NAME, i);
+                allPartitions.add(createSplit(tp, startingOffsetsInitializer));
             }
         }
 
         for (TopicPartition tp : 
KafkaSourceTestEnv.getPartitionsForTopics(PRE_EXISTING_TOPICS)) {
             if (topics.contains(tp.topic())) {
-                allPartitions.add(tp);
+                allPartitions.add(createSplit(tp, startingOffsetsInitializer));
             }
         }
 
-        for (TopicPartition tp : allPartitions) {
-            int ownerReader = KafkaSourceEnumerator.getSplitOwner(tp, 
NUM_SUBTASKS);
+        for (KafkaPartitionSplit split : allPartitions) {
+            int ownerReader =
+                    
KafkaSourceEnumerator.getSplitOwner(split.getTopicPartition(), NUM_SUBTASKS);
             if (readers.contains(ownerReader)) {
-                expectedAssignments.computeIfAbsent(ownerReader, r -> new 
HashSet<>()).add(tp);
+                expectedAssignments.computeIfAbsent(ownerReader, r -> new 
HashSet<>()).add(split);
             }
         }
         return expectedAssignments;
     }
 
+    private static KafkaPartitionSplit createSplit(
+            TopicPartition tp, OffsetsInitializer startingOffsetsInitializer) {
+        return new KafkaPartitionSplit(
+                tp, 
startingOffsetsInitializer.getPartitionOffsets(List.of(tp), retriever).get(tp));
+    }
+
     private void verifySplitAssignmentWithPartitions(
-            Map<Integer, Set<TopicPartition>> expectedAssignment,
-            Set<TopicPartition> actualTopicPartitions) {
-        final Set<TopicPartition> allTopicPartitionsFromAssignment = new 
HashSet<>();
-        expectedAssignment.forEach(
-                (reader, topicPartitions) ->
-                        
allTopicPartitionsFromAssignment.addAll(topicPartitions));
-        
assertThat(actualTopicPartitions).isEqualTo(allTopicPartitionsFromAssignment);
+            Map<Integer, Collection<KafkaPartitionSplit>> expectedAssignment,
+            Collection<KafkaPartitionSplit> actualTopicPartitions) {
+        final Set<KafkaPartitionSplit> allTopicPartitionsFromAssignment =
+                expectedAssignment.values().stream()
+                        .flatMap(Collection::stream)
+                        .collect(Collectors.toSet());
+        assertThat(actualTopicPartitions)
+                
.containsExactlyInAnyOrderElementsOf(allTopicPartitionsFromAssignment);
     }
 
     /** get all assigned partition splits of topics. */
@@ -740,12 +892,11 @@ public class KafkaEnumeratorTest {
         return allSplits;
     }
 
-    private Set<TopicPartition> asEnumState(Map<Integer, 
List<KafkaPartitionSplit>> assignments) {
-        Set<TopicPartition> enumState = new HashSet<>();
-        assignments.forEach(
-                (reader, assignment) ->
-                        assignment.forEach(split -> 
enumState.add(split.getTopicPartition())));
-        return enumState;
+    private Collection<KafkaPartitionSplit> asEnumState(
+            Map<Integer, List<KafkaPartitionSplit>> assignments) {
+        return assignments.values().stream()
+                .flatMap(Collection::stream)
+                .collect(Collectors.toList());
     }
 
     private void runOneTimePartitionDiscovery(
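
To make the behaviour documented on shouldEagerlyInitializeSplitOffsetsOnMigration above concrete: a split restored with the MIGRATED sentinel carries no usable starting offset, so the enumerator is expected to run the OffsetsInitializer for it right away, and any later snapshot should only contain real offsets. Below is a minimal sketch of that invariant check, using only API that already appears in this test (KafkaPartitionSplit.MIGRATED, getStartingOffset, getTopicPartition); the helper class itself is illustrative.

    import java.util.Collection;

    import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;

    // Illustrative assertion helper: no split in a snapshot may still carry the MIGRATED sentinel.
    final class MigrationChecks {
        static void assertNoMigratedOffsets(Collection<KafkaPartitionSplit> splits) {
            for (KafkaPartitionSplit split : splits) {
                if (split.getStartingOffset() == KafkaPartitionSplit.MIGRATED) {
                    throw new AssertionError(
                            "Split " + split.getTopicPartition() + " still has the MIGRATED sentinel offset");
                }
            }
        }
    }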

