This is an automated email from the ASF dual-hosted git repository.

arvid pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-connector-kafka.git

commit 6afe553535df960e6c84069749b69b9096ce23eb
Author: Arvid Heise <[email protected]>
AuthorDate: Tue Nov 11 22:39:00 2025 +0100

    [FLINK-38681] Revise threading model in enumerator
    
    FLINK-38453 broke the threading model of the KafkaSourceEnumerator by
    accessing some of its data structures from the worker thread without
    synchronization. The migration path also triggered an NPE.
    
    This commit partially reverts the multi-threading changes and makes the
    migration path more robust.
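    
    As context, a minimal sketch of the two-hop callAsync pattern this revision
    restores, assuming only the SplitEnumeratorContext API; the class name
    DiscoverySketch and its helper methods are illustrative and not part of this
    patch. Blocking Kafka calls run in the worker thread, while all reads and
    writes of enumerator state happen in callbacks on the coordinator (main)
    thread:
    
    import org.apache.flink.api.connector.source.SplitEnumeratorContext;
    import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
    
    import org.apache.kafka.common.TopicPartition;
    
    import java.util.HashMap;
    import java.util.Map;
    import java.util.Set;
    
    final class DiscoverySketch {
        private final SplitEnumeratorContext<KafkaPartitionSplit> context;
        // Only ever touched from the coordinator thread.
        private final Map<TopicPartition, KafkaPartitionSplit> unassignedSplits = new HashMap<>();
    
        DiscoverySketch(SplitEnumeratorContext<KafkaPartitionSplit> context) {
            this.context = context;
        }
    
        void start(long discoveryIntervalMs) {
            // Hop 1: list partitions in the worker thread, hand the result to the coordinator thread.
            context.callAsync(this::listPartitions, this::onPartitionsListed, 0, discoveryIntervalMs);
        }
    
        private Set<TopicPartition> listPartitions() {
            // Worker thread: the blocking AdminClient/subscriber call would go here.
            return Set.of();
        }
    
        private void onPartitionsListed(Set<TopicPartition> partitions, Throwable t) {
            if (t != null) {
                throw new RuntimeException("Partition discovery failed", t);
            }
            // Coordinator thread: safe to compare against enumerator state here.
            // Hop 2: resolve starting offsets back in the worker thread.
            context.callAsync(() -> resolveStartingOffsets(partitions), this::onOffsetsResolved);
        }
    
        private Map<TopicPartition, Long> resolveStartingOffsets(Set<TopicPartition> partitions) {
            // Worker thread: blocking offset lookup; no enumerator state is touched here.
            return new HashMap<>();
        }
    
        private void onOffsetsResolved(Map<TopicPartition, Long> offsets, Throwable t) {
            if (t != null) {
                throw new RuntimeException("Offset initialization failed", t);
            }
            // Coordinator thread: only now mutate enumerator state and build splits.
            offsets.forEach(
                    (tp, offset) -> unassignedSplits.put(tp, new KafkaPartitionSplit(tp, offset)));
        }
    }
    
    The patch follows this pattern and additionally re-injects unassigned splits
    that still carry the MIGRATED sentinel into the discovery loop (see
    getPartitionChange) instead of resolving their offsets eagerly in the
    constructor.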
---
 .../86dfd459-67a9-4b26-9b5c-0b0bbf22681a           |  48 ++++
 .../c0d94764-76a0-4c50-b617-70b1754c4612           |   2 +-
 .../enumerator/KafkaSourceEnumStateSerializer.java |   3 +-
 .../source/enumerator/KafkaSourceEnumerator.java   | 283 ++++++++++++---------
 .../kafka/source/split/KafkaPartitionSplit.java    |   8 +-
 .../kafka/source/KafkaSourceMigrationITCase.java   | 236 +++++++++++++++++
 .../KafkaSourceEnumStateSerializerTest.java        |   1 -
 .../enumerator/KafkaSourceEnumeratorTest.java      | 169 +++++++-----
 .../initializer/OffsetsInitializerTest.java        |   1 -
 .../savepoint-3c7c0a-07c49f841952/_metadata        | Bin 0 -> 2951 bytes
 .../savepoint-de9fd4-35f289091a1b/_metadata        | Bin 0 -> 3053 bytes
 .../savepoint-246fa1-85f387ecce0c/_metadata        | Bin 0 -> 3257 bytes
 12 files changed, 555 insertions(+), 196 deletions(-)

diff --git 
a/flink-connector-kafka/archunit-violations/86dfd459-67a9-4b26-9b5c-0b0bbf22681a
 
b/flink-connector-kafka/archunit-violations/86dfd459-67a9-4b26-9b5c-0b0bbf22681a
index 018d66cc..0e4b3bfc 100644
--- 
a/flink-connector-kafka/archunit-violations/86dfd459-67a9-4b26-9b5c-0b0bbf22681a
+++ 
b/flink-connector-kafka/archunit-violations/86dfd459-67a9-4b26-9b5c-0b0bbf22681a
@@ -4,3 +4,51 @@ org.apache.flink.connector.kafka.source.KafkaSourceITCase does 
not satisfy: only
 * reside in a package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class InternalMiniClusterExtension\
 * reside outside of package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class MiniClusterExtension\
  or contain any fields that are public, static, and of type 
MiniClusterWithClientResource and final and annotated with @ClassRule or 
contain any fields that is of type MiniClusterWithClientResource and public and 
final and not static and annotated with @Rule
+org.apache.flink.connector.kafka.source.KafkaSourceLegacyITCase does not 
satisfy: only one of the following predicates match:\
+* reside in a package 'org.apache.flink.runtime.*' and contain any fields that 
are static, final, and of type InternalMiniClusterExtension and annotated with 
@RegisterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and contain any 
fields that are static, final, and of type MiniClusterExtension and annotated 
with @RegisterExtension or are , and of type MiniClusterTestEnvironment and 
annotated with @TestEnv\
+* reside in a package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class InternalMiniClusterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class MiniClusterExtension\
+ or contain any fields that are public, static, and of type 
MiniClusterWithClientResource and final and annotated with @ClassRule or 
contain any fields that is of type MiniClusterWithClientResource and public and 
final and not static and annotated with @Rule
+org.apache.flink.streaming.connectors.kafka.FlinkKafkaInternalProducerITCase 
does not satisfy: only one of the following predicates match:\
+* reside in a package 'org.apache.flink.runtime.*' and contain any fields that 
are static, final, and of type InternalMiniClusterExtension and annotated with 
@RegisterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and contain any 
fields that are static, final, and of type MiniClusterExtension and annotated 
with @RegisterExtension or are , and of type MiniClusterTestEnvironment and 
annotated with @TestEnv\
+* reside in a package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class InternalMiniClusterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class MiniClusterExtension\
+ or contain any fields that are public, static, and of type 
MiniClusterWithClientResource and final and annotated with @ClassRule or 
contain any fields that is of type MiniClusterWithClientResource and public and 
final and not static and annotated with @Rule
+org.apache.flink.streaming.connectors.kafka.FlinkKafkaProducerITCase does not 
satisfy: only one of the following predicates match:\
+* reside in a package 'org.apache.flink.runtime.*' and contain any fields that 
are static, final, and of type InternalMiniClusterExtension and annotated with 
@RegisterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and contain any 
fields that are static, final, and of type MiniClusterExtension and annotated 
with @RegisterExtension or are , and of type MiniClusterTestEnvironment and 
annotated with @TestEnv\
+* reside in a package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class InternalMiniClusterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class MiniClusterExtension\
+ or contain any fields that are public, static, and of type 
MiniClusterWithClientResource and final and annotated with @ClassRule or 
contain any fields that is of type MiniClusterWithClientResource and public and 
final and not static and annotated with @Rule
+org.apache.flink.streaming.connectors.kafka.KafkaITCase does not satisfy: only 
one of the following predicates match:\
+* reside in a package 'org.apache.flink.runtime.*' and contain any fields that 
are static, final, and of type InternalMiniClusterExtension and annotated with 
@RegisterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and contain any 
fields that are static, final, and of type MiniClusterExtension and annotated 
with @RegisterExtension or are , and of type MiniClusterTestEnvironment and 
annotated with @TestEnv\
+* reside in a package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class InternalMiniClusterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class MiniClusterExtension\
+ or contain any fields that are public, static, and of type 
MiniClusterWithClientResource and final and annotated with @ClassRule or 
contain any fields that is of type MiniClusterWithClientResource and public and 
final and not static and annotated with @Rule
+org.apache.flink.streaming.connectors.kafka.KafkaProducerAtLeastOnceITCase 
does not satisfy: only one of the following predicates match:\
+* reside in a package 'org.apache.flink.runtime.*' and contain any fields that 
are static, final, and of type InternalMiniClusterExtension and annotated with 
@RegisterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and contain any 
fields that are static, final, and of type MiniClusterExtension and annotated 
with @RegisterExtension or are , and of type MiniClusterTestEnvironment and 
annotated with @TestEnv\
+* reside in a package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class InternalMiniClusterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class MiniClusterExtension\
+ or contain any fields that are public, static, and of type 
MiniClusterWithClientResource and final and annotated with @ClassRule or 
contain any fields that is of type MiniClusterWithClientResource and public and 
final and not static and annotated with @Rule
+org.apache.flink.streaming.connectors.kafka.KafkaProducerExactlyOnceITCase 
does not satisfy: only one of the following predicates match:\
+* reside in a package 'org.apache.flink.runtime.*' and contain any fields that 
are static, final, and of type InternalMiniClusterExtension and annotated with 
@RegisterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and contain any 
fields that are static, final, and of type MiniClusterExtension and annotated 
with @RegisterExtension or are , and of type MiniClusterTestEnvironment and 
annotated with @TestEnv\
+* reside in a package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class InternalMiniClusterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class MiniClusterExtension\
+ or contain any fields that are public, static, and of type 
MiniClusterWithClientResource and final and annotated with @ClassRule or 
contain any fields that is of type MiniClusterWithClientResource and public and 
final and not static and annotated with @Rule
+org.apache.flink.streaming.connectors.kafka.shuffle.KafkaShuffleExactlyOnceITCase
 does not satisfy: only one of the following predicates match:\
+* reside in a package 'org.apache.flink.runtime.*' and contain any fields that 
are static, final, and of type InternalMiniClusterExtension and annotated with 
@RegisterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and contain any 
fields that are static, final, and of type MiniClusterExtension and annotated 
with @RegisterExtension or are , and of type MiniClusterTestEnvironment and 
annotated with @TestEnv\
+* reside in a package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class InternalMiniClusterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class MiniClusterExtension\
+ or contain any fields that are public, static, and of type 
MiniClusterWithClientResource and final and annotated with @ClassRule or 
contain any fields that is of type MiniClusterWithClientResource and public and 
final and not static and annotated with @Rule
+org.apache.flink.streaming.connectors.kafka.shuffle.KafkaShuffleITCase does 
not satisfy: only one of the following predicates match:\
+* reside in a package 'org.apache.flink.runtime.*' and contain any fields that 
are static, final, and of type InternalMiniClusterExtension and annotated with 
@RegisterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and contain any 
fields that are static, final, and of type MiniClusterExtension and annotated 
with @RegisterExtension or are , and of type MiniClusterTestEnvironment and 
annotated with @TestEnv\
+* reside in a package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class InternalMiniClusterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and is annotated with 
@ExtendWith with class MiniClusterExtension\
+ or contain any fields that are public, static, and of type 
MiniClusterWithClientResource and final and annotated with @ClassRule or 
contain any fields that is of type MiniClusterWithClientResource and public and 
final and not static and annotated with @Rule
diff --git 
a/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
 
b/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
index e496d80c..afb16263 100644
--- 
a/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
+++ 
b/flink-connector-kafka/archunit-violations/c0d94764-76a0-4c50-b617-70b1754c4612
@@ -43,7 +43,7 @@ Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateS
 Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer.serializeV2(java.util.Collection,
 boolean)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumStateSerializer.java:0)
 Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumStateSerializer.serializeV3(org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumState)>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumStateSerializer.java:0)
 Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.deepCopyProperties(java.util.Properties,
 java.util.Properties)> is annotated with 
<org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumerator.java:0)
-Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.getPartitionChange(java.util.Set)>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumerator.java:0)
+Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.getPartitionChange(java.util.Set,
 boolean)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumerator.java:0)
 Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.getPendingPartitionSplitAssignment()>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumerator.java:0)
 Method 
<org.apache.flink.connector.kafka.source.enumerator.KafkaSourceEnumerator.getSplitOwner(org.apache.kafka.common.TopicPartition,
 int)> is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaSourceEnumerator.java:0)
 Method 
<org.apache.flink.connector.kafka.source.reader.KafkaPartitionSplitReader.consumer()>
 is annotated with <org.apache.flink.annotation.VisibleForTesting> in 
(KafkaPartitionSplitReader.java:0)
diff --git 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
index 99176cfc..a79db37c 100644
--- 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
+++ 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializer.java
@@ -229,11 +229,12 @@ public class KafkaSourceEnumStateSerializer
                 final String topic = in.readUTF();
                 final int partition = in.readInt();
                 final int statusCode = in.readInt();
+                final AssignmentStatus assignStatus = 
AssignmentStatus.ofStatusCode(statusCode);
                 partitions.add(
                         new SplitAndAssignmentStatus(
                                 new KafkaPartitionSplit(
                                         new TopicPartition(topic, partition), 
MIGRATED),
-                                AssignmentStatus.ofStatusCode(statusCode)));
+                                assignStatus));
             }
             final boolean initialDiscoveryFinished = in.readBoolean();
             if (in.available() > 0) {
diff --git 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
index e65e9a57..c811f176 100644
--- 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
+++ 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumerator.java
@@ -43,7 +43,6 @@ import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
 
-import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
@@ -51,16 +50,46 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Properties;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
-import java.util.stream.Stream;
 
-import static org.apache.flink.util.Preconditions.checkState;
+import static 
org.apache.flink.util.CollectionUtil.newLinkedHashMapWithExpectedSize;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
-/** The enumerator class for Kafka source. */
+/**
+ * The enumerator class for Kafka source.
+ *
+ * <p>A core part of the enumerator is handling discovered splits. The 
following lifecycle applies
+ * to splits:
+ *
+ * <ol>
+ *   <li>{@link #getSubscribedTopicPartitions()} initially or periodically 
retrieves a list of topic
+ *       partitions in the worker thread.
+ *   <li>Partitions are consolidated in {@link #checkPartitionChanges(Set, 
Throwable)} in the main
+ *       thread.
+ *   <li>New partitions will result in new splits, which are initialized 
through {@link
+ *       #initializePartitionSplits(PartitionChange)} where start/stop offsets are resolved in the worker
+ *       thread. Offset resolution happens through the {@link 
PartitionOffsetsRetrieverImpl} which
+ *       communicates with the broker.
+ *   <li>The new, initialized splits are put into {@link #unassignedSplits} 
and {@link
+ *       #pendingPartitionSplitAssignment} in {@link
+ *       #handlePartitionSplitChanges(PartitionSplitChange, Throwable)} in the 
main thread.
+ *   <li>{@link #assignPendingPartitionSplits(Set)} eventually assigns the 
pending splits to readers
+ *       at which point they are removed from {@link #unassignedSplits} and 
{@link
+ *       #pendingPartitionSplitAssignment} and moved into {@link 
#assignedSplits} in the main
+ *       thread.
+ *   <li>Checkpointing is performed in the main thread on {@link 
#unassignedSplits} and {@link
+ *       #assignedSplits}. Information in {@link 
#pendingPartitionSplitAssignment} is ephemeral
+ *       because of FLINK-21817 (essentially, the actual assignment should be transient).
+ *   <li>In case of state migration, the start offset of {@link 
#unassignedSplits} may not be
+ *       initialized, so these partitions are reinjected into the discovery 
process during {@link
+ *       #checkPartitionChanges(Set, Throwable)}.
+ * </ol>
+ */
 @Internal
 public class KafkaSourceEnumerator
         implements SplitEnumerator<KafkaPartitionSplit, KafkaSourceEnumState> {
@@ -86,7 +115,8 @@ public class KafkaSourceEnumerator
      * The discovered and initialized partition splits that are waiting for 
owner reader to be
      * ready.
      */
-    private final Map<Integer, Set<KafkaPartitionSplit>> 
pendingPartitionSplitAssignment;
+    private final Map<Integer, Set<KafkaPartitionSplit>> 
pendingPartitionSplitAssignment =
+            new HashMap<>();
 
     /** The consumer group id used for this KafkaSource. */
     private final String consumerGroupId;
@@ -98,8 +128,7 @@ public class KafkaSourceEnumerator
     // initializing partition discovery has finished.
     private boolean noMoreNewPartitionSplits = false;
     // this flag will be marked as true if initial partitions are discovered 
after enumerator starts
-    // the flag is read and set in main thread but also read in worker thread
-    private volatile boolean initialDiscoveryFinished;
+    private boolean initialDiscoveryFinished;
 
     public KafkaSourceEnumerator(
             KafkaSubscriber subscriber,
@@ -134,11 +163,6 @@ public class KafkaSourceEnumerator
         this.context = context;
         this.boundedness = boundedness;
 
-        Map<AssignmentStatus, List<KafkaPartitionSplit>> splits =
-                initializeMigratedSplits(kafkaSourceEnumState.splits());
-        this.assignedSplits = 
indexByPartition(splits.get(AssignmentStatus.ASSIGNED));
-        this.unassignedSplits = 
indexByPartition(splits.get(AssignmentStatus.UNASSIGNED));
-        this.pendingPartitionSplitAssignment = new HashMap<>();
         this.partitionDiscoveryIntervalMs =
                 KafkaSourceOptions.getOption(
                         properties,
@@ -146,70 +170,14 @@ public class KafkaSourceEnumerator
                         Long::parseLong);
         this.consumerGroupId = 
properties.getProperty(ConsumerConfig.GROUP_ID_CONFIG);
         this.initialDiscoveryFinished = 
kafkaSourceEnumState.initialDiscoveryFinished();
+        this.assignedSplits = 
indexByPartition(kafkaSourceEnumState.assignedSplits());
+        this.unassignedSplits = 
indexByPartition(kafkaSourceEnumState.unassignedSplits());
     }
 
-    /**
-     * Initialize migrated splits to splits with concrete starting offsets. 
This method ensures that
-     * the costly offset resolution is performed only when there are splits 
that have been
-     * checkpointed with previous enumerator versions.
-     *
-     * <p>Note that this method is deliberately performed in the main thread 
to avoid a checkpoint
-     * of the splits without starting offset.
-     */
-    private Map<AssignmentStatus, List<KafkaPartitionSplit>> 
initializeMigratedSplits(
-            Set<SplitAndAssignmentStatus> splits) {
-        final Set<TopicPartition> migratedPartitions =
-                splits.stream()
-                        .filter(
-                                splitStatus ->
-                                        splitStatus.split().getStartingOffset()
-                                                == 
KafkaPartitionSplit.MIGRATED)
-                        .map(splitStatus -> 
splitStatus.split().getTopicPartition())
-                        .collect(Collectors.toSet());
-
-        if (migratedPartitions.isEmpty()) {
-            return splitByAssignmentStatus(splits.stream());
-        }
-
-        final Map<TopicPartition, Long> startOffsets =
-                startingOffsetInitializer.getPartitionOffsets(
-                        migratedPartitions, getOffsetsRetriever());
-        return splitByAssignmentStatus(
-                splits.stream()
-                        .map(splitStatus -> resolveMigratedSplit(splitStatus, 
startOffsets)));
-    }
-
-    private static Map<AssignmentStatus, List<KafkaPartitionSplit>> 
splitByAssignmentStatus(
-            Stream<SplitAndAssignmentStatus> stream) {
-        return stream.collect(
-                Collectors.groupingBy(
-                        SplitAndAssignmentStatus::assignmentStatus,
-                        Collectors.mapping(SplitAndAssignmentStatus::split, 
Collectors.toList())));
-    }
-
-    private static SplitAndAssignmentStatus resolveMigratedSplit(
-            SplitAndAssignmentStatus splitStatus, Map<TopicPartition, Long> 
startOffsets) {
-        final KafkaPartitionSplit split = splitStatus.split();
-        if (split.getStartingOffset() != KafkaPartitionSplit.MIGRATED) {
-            return splitStatus;
-        }
-        final Long startOffset = startOffsets.get(split.getTopicPartition());
-        checkState(
-                startOffset != null,
-                "Cannot find starting offset for migrated partition %s",
-                split.getTopicPartition());
-        return new SplitAndAssignmentStatus(
-                new KafkaPartitionSplit(split.getTopicPartition(), 
startOffset),
-                splitStatus.assignmentStatus());
-    }
-
-    private Map<TopicPartition, KafkaPartitionSplit> indexByPartition(
-            List<KafkaPartitionSplit> splits) {
-        if (splits == null) {
-            return new HashMap<>();
-        }
+    private static Map<TopicPartition, KafkaPartitionSplit> indexByPartition(
+            Collection<KafkaPartitionSplit> splits) {
         return splits.stream()
-                
.collect(Collectors.toMap(KafkaPartitionSplit::getTopicPartition, split -> 
split));
+                
.collect(Collectors.toMap(KafkaPartitionSplit::getTopicPartition, e -> e));
     }
 
     /**
@@ -217,17 +185,21 @@ public class KafkaSourceEnumerator
      *
      * <p>Depending on {@link #partitionDiscoveryIntervalMs}, the enumerator 
will trigger a one-time
      * partition discovery, or schedule a callable for discover partitions 
periodically.
-     *
-     * <p>The invoking chain of partition discovery would be:
-     *
-     * <ol>
-     *   <li>{@link #findNewPartitionSplits} in worker thread
-     *   <li>{@link #handlePartitionSplitChanges} in coordinator thread
-     * </ol>
      */
     @Override
     public void start() {
         adminClient = getKafkaAdminClient();
+
+        // Find splits where the start offset has been initialized but not yet 
assigned to readers.
+        // These splits must not be reinitialized to keep offsets consistent 
with first discovery.
+        final List<KafkaPartitionSplit> preinitializedSplits =
+                unassignedSplits.values().stream()
+                        .filter(split -> !split.isMigrated())
+                        .collect(Collectors.toList());
+        if (!preinitializedSplits.isEmpty()) {
+            addPartitionSplitChangeToPendingAssignments(preinitializedSplits);
+        }
+
         if (partitionDiscoveryIntervalMs > 0) {
             LOG.info(
                     "Starting the KafkaSourceEnumerator for consumer group {} "
@@ -235,8 +207,8 @@ public class KafkaSourceEnumerator
                     consumerGroupId,
                     partitionDiscoveryIntervalMs);
             context.callAsync(
-                    this::findNewPartitionSplits,
-                    this::handlePartitionSplitChanges,
+                    this::getSubscribedTopicPartitions,
+                    this::checkPartitionChanges,
                     0,
                     partitionDiscoveryIntervalMs);
         } else {
@@ -244,7 +216,7 @@ public class KafkaSourceEnumerator
                     "Starting the KafkaSourceEnumerator for consumer group {} "
                             + "without periodic partition discovery.",
                     consumerGroupId);
-            context.callAsync(this::findNewPartitionSplits, 
this::handlePartitionSplitChanges);
+            context.callAsync(this::getSubscribedTopicPartitions, 
this::checkPartitionChanges);
         }
     }
 
@@ -257,6 +229,7 @@ public class KafkaSourceEnumerator
     public void addSplitsBack(List<KafkaPartitionSplit> splits, int subtaskId) 
{
         for (KafkaPartitionSplit split : splits) {
             unassignedSplits.put(split.getTopicPartition(), split);
+            assignedSplits.remove(split.getTopicPartition());
         }
         addPartitionSplitChangeToPendingAssignments(splits);
 
@@ -298,16 +271,34 @@ public class KafkaSourceEnumerator
      *
      * @return Set of subscribed {@link TopicPartition}s
      */
-    private PartitionSplitChange findNewPartitionSplits() {
-        final Set<TopicPartition> fetchedPartitions =
-                subscriber.getSubscribedTopicPartitions(adminClient);
+    private Set<TopicPartition> getSubscribedTopicPartitions() {
+        return subscriber.getSubscribedTopicPartitions(adminClient);
+    }
 
-        final PartitionChange partitionChange = 
getPartitionChange(fetchedPartitions);
-        if (partitionChange.isEmpty()) {
-            return null;
+    /**
+     * Check if there are any partition changes within the subscribed topic partitions fetched by
+     * the worker thread, and invoke {@link KafkaSourceEnumerator#initializePartitionSplits(PartitionChange)}
+     * in the worker thread to initialize splits for new partitions.
+     *
+     * <p>NOTE: This method should only be invoked in the coordinator executor 
thread.
+     *
+     * @param fetchedPartitions Set of subscribed {@link TopicPartition}s fetched by the worker thread
+     * @param t Exception in worker thread
+     */
+    private void checkPartitionChanges(Set<TopicPartition> fetchedPartitions, 
Throwable t) {
+        if (t != null) {
+            throw new FlinkRuntimeException(
+                    "Failed to list subscribed topic partitions due to ", t);
         }
 
-        return initializePartitionSplits(partitionChange);
+        final PartitionChange partitionChange =
+                getPartitionChange(fetchedPartitions, 
!initialDiscoveryFinished);
+        if (partitionChange.isEmpty()) {
+            return;
+        }
+        context.callAsync(
+                () -> initializePartitionSplits(partitionChange),
+                this::handlePartitionSplitChanges);
     }
 
     /**
@@ -331,27 +322,32 @@ public class KafkaSourceEnumerator
      *     partitions
      */
     private PartitionSplitChange initializePartitionSplits(PartitionChange 
partitionChange) {
-        Set<TopicPartition> newPartitions =
-                
Collections.unmodifiableSet(partitionChange.getNewPartitions());
+        Set<TopicPartition> newPartitions = partitionChange.getNewPartitions();
+        Set<TopicPartition> initialPartitions = 
partitionChange.getInitialPartitions();
 
-        OffsetsInitializer.PartitionOffsetsRetriever offsetsRetriever = 
getOffsetsRetriever();
         // initial partitions use OffsetsInitializer specified by the user 
while new partitions use
         // EARLIEST
-        final OffsetsInitializer initializer;
-        if (!initialDiscoveryFinished) {
-            initializer = startingOffsetInitializer;
-        } else {
-            initializer = newDiscoveryOffsetsInitializer;
-        }
         Map<TopicPartition, Long> startingOffsets =
-                initializer.getPartitionOffsets(newPartitions, 
offsetsRetriever);
-
+                newLinkedHashMapWithExpectedSize(newPartitions.size() + 
initialPartitions.size());
         Map<TopicPartition, Long> stoppingOffsets =
-                stoppingOffsetInitializer.getPartitionOffsets(newPartitions, 
offsetsRetriever);
-
-        Set<KafkaPartitionSplit> partitionSplits = new 
HashSet<>(newPartitions.size());
-        for (TopicPartition tp : newPartitions) {
-            Long startingOffset = startingOffsets.get(tp);
+                newLinkedHashMapWithExpectedSize(newPartitions.size() + 
initialPartitions.size());
+        if (!newPartitions.isEmpty()) {
+            initOffsets(
+                    newPartitions,
+                    newDiscoveryOffsetsInitializer,
+                    startingOffsets,
+                    stoppingOffsets);
+        }
+        if (!initialPartitions.isEmpty()) {
+            initOffsets(
+                    initialPartitions, startingOffsetInitializer, 
startingOffsets, stoppingOffsets);
+        }
+
+        Set<KafkaPartitionSplit> partitionSplits =
+                new HashSet<>(newPartitions.size() + initialPartitions.size());
+        for (Entry<TopicPartition, Long> tpAndStartingOffset : 
startingOffsets.entrySet()) {
+            TopicPartition tp = tpAndStartingOffset.getKey();
+            long startingOffset = tpAndStartingOffset.getValue();
             long stoppingOffset =
                     stoppingOffsets.getOrDefault(tp, 
KafkaPartitionSplit.NO_STOPPING_OFFSET);
             partitionSplits.add(new KafkaPartitionSplit(tp, startingOffset, 
stoppingOffset));
@@ -359,6 +355,18 @@ public class KafkaSourceEnumerator
         return new PartitionSplitChange(partitionSplits, 
partitionChange.getRemovedPartitions());
     }
 
+    private void initOffsets(
+            Set<TopicPartition> partitions,
+            OffsetsInitializer startOffsetInitializer,
+            Map<TopicPartition, Long> startingOffsets,
+            Map<TopicPartition, Long> stoppingOffsets) {
+        OffsetsInitializer.PartitionOffsetsRetriever offsetsRetriever = 
getOffsetsRetriever();
+        startingOffsets.putAll(
+                startOffsetInitializer.getPartitionOffsets(partitions, 
offsetsRetriever));
+        stoppingOffsets.putAll(
+                stoppingOffsetInitializer.getPartitionOffsets(partitions, 
offsetsRetriever));
+    }
+
     /**
      * Mark partition splits initialized by {@link
      * KafkaSourceEnumerator#initializePartitionSplits(PartitionChange)} as 
pending and try to
@@ -370,7 +378,7 @@ public class KafkaSourceEnumerator
      * @param t Exception in worker thread
      */
     private void handlePartitionSplitChanges(
-            @Nullable PartitionSplitChange partitionSplitChange, Throwable t) {
+            PartitionSplitChange partitionSplitChange, Throwable t) {
         if (t != null) {
             throw new FlinkRuntimeException("Failed to initialize partition 
splits due to ", t);
         }
@@ -379,12 +387,10 @@ public class KafkaSourceEnumerator
             LOG.debug("Partition discovery is disabled.");
             noMoreNewPartitionSplits = true;
         }
-        if (partitionSplitChange == null) {
-            return;
-        }
         for (KafkaPartitionSplit split : 
partitionSplitChange.newPartitionSplits) {
             unassignedSplits.put(split.getTopicPartition(), split);
         }
+        LOG.info("Partition split changes: {}", partitionSplitChange);
         // TODO: Handle removed partitions.
         
addPartitionSplitChangeToPendingAssignments(partitionSplitChange.newPartitionSplits);
         assignPendingPartitionSplits(context.registeredReaders().keySet());
@@ -460,11 +466,13 @@ public class KafkaSourceEnumerator
     }
 
     @VisibleForTesting
-    PartitionChange getPartitionChange(Set<TopicPartition> fetchedPartitions) {
+    PartitionChange getPartitionChange(
+            Set<TopicPartition> fetchedPartitions, boolean initialDiscovery) {
         final Set<TopicPartition> removedPartitions = new HashSet<>();
+        Set<TopicPartition> newPartitions = new HashSet<>(fetchedPartitions);
         Consumer<TopicPartition> dedupOrMarkAsRemoved =
                 (tp) -> {
-                    if (!fetchedPartitions.remove(tp)) {
+                    if (!newPartitions.remove(tp)) {
                         removedPartitions.add(tp);
                     }
                 };
@@ -475,14 +483,25 @@ public class KafkaSourceEnumerator
                         splits.forEach(
                                 split -> 
dedupOrMarkAsRemoved.accept(split.getTopicPartition())));
 
-        if (!fetchedPartitions.isEmpty()) {
-            LOG.info("Discovered new partitions: {}", fetchedPartitions);
+        if (!newPartitions.isEmpty()) {
+            LOG.info("Discovered new partitions: {}", newPartitions);
         }
         if (!removedPartitions.isEmpty()) {
             LOG.info("Discovered removed partitions: {}", removedPartitions);
         }
 
-        return new PartitionChange(fetchedPartitions, removedPartitions);
+        Set<TopicPartition> initialPartitions = new HashSet<>();
+        if (initialDiscovery) {
+            initialPartitions.addAll(newPartitions);
+            newPartitions.clear();
+        }
+        // migration path, ensure that partitions without offset are properly 
initialized
+        for (KafkaPartitionSplit split : unassignedSplits.values()) {
+            if (split.isMigrated()) {
+                initialPartitions.add(split.getTopicPartition());
+            }
+        }
+        return new PartitionChange(initialPartitions, newPartitions, 
removedPartitions);
     }
 
     private AdminClient getKafkaAdminClient() {
@@ -546,14 +565,23 @@ public class KafkaSourceEnumerator
     /** A container class to hold the newly added partitions and removed 
partitions. */
     @VisibleForTesting
     static class PartitionChange {
+        private final Set<TopicPartition> initialPartitions;
         private final Set<TopicPartition> newPartitions;
         private final Set<TopicPartition> removedPartitions;
 
-        PartitionChange(Set<TopicPartition> newPartitions, Set<TopicPartition> 
removedPartitions) {
+        PartitionChange(
+                Set<TopicPartition> initialPartitions,
+                Set<TopicPartition> newPartitions,
+                Set<TopicPartition> removedPartitions) {
+            this.initialPartitions = initialPartitions;
             this.newPartitions = newPartitions;
             this.removedPartitions = removedPartitions;
         }
 
+        public Set<TopicPartition> getInitialPartitions() {
+            return initialPartitions;
+        }
+
         public Set<TopicPartition> getNewPartitions() {
             return newPartitions;
         }
@@ -563,7 +591,9 @@ public class KafkaSourceEnumerator
         }
 
         public boolean isEmpty() {
-            return newPartitions.isEmpty() && removedPartitions.isEmpty();
+            return initialPartitions.isEmpty()
+                    && newPartitions.isEmpty()
+                    && removedPartitions.isEmpty();
         }
     }
 
@@ -577,17 +607,27 @@ public class KafkaSourceEnumerator
             this.newPartitionSplits = 
Collections.unmodifiableSet(newPartitionSplits);
             this.removedPartitions = 
Collections.unmodifiableSet(removedPartitions);
         }
+
+        @Override
+        public String toString() {
+            return "PartitionSplitChange{"
+                    + "newPartitionSplits="
+                    + newPartitionSplits
+                    + ", removedPartitions="
+                    + removedPartitions
+                    + '}';
+        }
     }
 
     /** The implementation for offsets retriever with a consumer and an admin 
client. */
     @VisibleForTesting
     public static class PartitionOffsetsRetrieverImpl
-            implements OffsetsInitializer.PartitionOffsetsRetriever, 
AutoCloseable {
+            implements OffsetsInitializer.PartitionOffsetsRetriever {
         private final AdminClient adminClient;
         private final String groupId;
 
         public PartitionOffsetsRetrieverImpl(AdminClient adminClient, String 
groupId) {
-            this.adminClient = adminClient;
+            this.adminClient = checkNotNull(adminClient);
             this.groupId = groupId;
         }
 
@@ -715,10 +755,5 @@ public class KafkaSourceEnumerator
                                                     
entry.getValue().timestamp(),
                                                     
entry.getValue().leaderEpoch())));
         }
-
-        @Override
-        public void close() throws Exception {
-            adminClient.close(Duration.ZERO);
-        }
     }
 }
diff --git 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
index 52cb3b98..36acc0ef 100644
--- 
a/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
+++ 
b/flink-connector-kafka/src/main/java/org/apache/flink/connector/kafka/source/split/KafkaPartitionSplit.java
@@ -41,8 +41,8 @@ public class KafkaPartitionSplit implements SourceSplit {
     public static final long EARLIEST_OFFSET = -2;
     // Indicating the split should consume from the last committed offset.
     public static final long COMMITTED_OFFSET = -3;
-    // Used to indicate the split has been migrated from an earlier enumerator 
state; offset needs
-    // to be initialized on recovery
+    // Used to indicate the offset has not been initialized yet in the 
enumerator state; offset
+    // needs to be initialized on recovery
     public static final long MIGRATED = Long.MIN_VALUE;
 
     // Valid special starting offsets
@@ -123,6 +123,10 @@ public class KafkaPartitionSplit implements SourceSplit {
         return tp.toString();
     }
 
+    public boolean isMigrated() {
+        return startingOffset == MIGRATED;
+    }
+
     // ------------ private methods ---------------
 
     private static void verifyInitialOffset(
diff --git 
a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/KafkaSourceMigrationITCase.java
 
b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/KafkaSourceMigrationITCase.java
new file mode 100644
index 00000000..e368b0ac
--- /dev/null
+++ 
b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/KafkaSourceMigrationITCase.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.kafka.source;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.client.program.ClusterClient;
+import org.apache.flink.configuration.Configuration;
+import 
org.apache.flink.connector.kafka.source.reader.deserializer.KafkaRecordDeserializationSchema;
+import org.apache.flink.connector.kafka.testutils.KafkaSourceTestEnv;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.core.execution.SavepointFormatType;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.graph.StreamGraph;
+import org.apache.flink.test.junit5.InjectClusterClient;
+import org.apache.flink.test.junit5.InjectMiniCluster;
+import org.apache.flink.test.junit5.MiniClusterExtension;
+import org.apache.flink.testutils.junit.SharedObjectsExtension;
+import org.apache.flink.testutils.junit.SharedReference;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.testcontainers.junit.jupiter.Testcontainers;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Stream;
+
+import static 
org.apache.flink.configuration.StateRecoveryOptions.SAVEPOINT_PATH;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for creating and restoring savepoints for migration tests of the KafkaSource. */
+@Testcontainers
+public class KafkaSourceMigrationITCase extends TestLogger {
+    public static final String KAFKA_SOURCE_UID = "kafka-source-operator-uid";
+    // Directory to store the savepoints in src/test/resources
+    private static final Path KAFKA_SOURCE_SAVEPOINT_PATH =
+            
Path.of("src/test/resources/kafka-source-savepoint").toAbsolutePath();
+
+    @RegisterExtension
+    public static final MiniClusterExtension MINI_CLUSTER_RESOURCE =
+            new MiniClusterExtension(
+                    new MiniClusterResourceConfiguration.Builder()
+                            .setNumberTaskManagers(2)
+                            .setNumberSlotsPerTaskManager(3)
+                            .build());
+
+    public static final int NUM_RECORDS =
+            KafkaSourceTestEnv.NUM_PARTITIONS * 
KafkaSourceTestEnv.NUM_RECORDS_PER_PARTITION;
+    private static final String TOPIC = "topic";
+
+    @RegisterExtension
+    private static final SharedObjectsExtension SHARED_OBJECTS = 
SharedObjectsExtension.create();
+
+    @BeforeEach
+    void setupEnv() throws Throwable {
+        // restarting Kafka with each migration test because we use the same 
topic underneath
+        KafkaSourceTestEnv.setup();
+    }
+
+    @AfterEach
+    void removeEnv() throws Exception {
+        KafkaSourceTestEnv.tearDown();
+    }
+
+    static Stream<Arguments> getKafkaSourceSavepoint() throws IOException {
+        return Files.walk(KAFKA_SOURCE_SAVEPOINT_PATH)
+                .filter(
+                        f ->
+                                Files.isDirectory(f)
+                                        && 
f.getFileName().toString().startsWith("savepoint"))
+                // allow
+                .map(KAFKA_SOURCE_SAVEPOINT_PATH::relativize)
+                .map(Arguments::arguments);
+    }
+
+    @Disabled("Enable if you want to create savepoint of KafkaSource")
+    @Test
+    void createAndStoreSavepoint(
+            @InjectClusterClient ClusterClient<?> clusterClient,
+            @InjectMiniCluster MiniCluster miniCluster)
+            throws Throwable {
+
+        // this is the part that has been read already in the savepoint
+        final List<ProducerRecord<String, Integer>> writtenRecords = 
writeInitialData();
+
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(3);
+        env.getConfig().enableObjectReuse();
+
+        final KafkaSource<ConsumerRecord<byte[], byte[]>> source = 
createSource();
+        final int enumVersion = 
source.getEnumeratorCheckpointSerializer().getVersion();
+        final int splitVersion = source.getSplitSerializer().getVersion();
+        String testCase = String.format("enum%s-split%s", enumVersion, 
splitVersion);
+
+        Path savepointPath = KAFKA_SOURCE_SAVEPOINT_PATH.resolve(testCase);
+        Files.createDirectories(savepointPath);
+
+        final SharedReference<ConcurrentLinkedQueue<ConsumerRecord<byte[], 
byte[]>>> readRecords =
+                SHARED_OBJECTS.add(new ConcurrentLinkedQueue<>());
+        env.fromSource(source, WatermarkStrategy.noWatermarks(), 
"TestDataSource")
+                .uid(KAFKA_SOURCE_UID)
+                .map(r -> readRecords.get().add(r));
+
+        final JobClient jobClient = env.executeAsync();
+        final JobID jobID = jobClient.getJobID();
+
+        CommonTestUtils.waitForAllTaskRunning(miniCluster, 
jobClient.getJobID(), false);
+        CommonTestUtils.waitUntilCondition(
+                () -> readRecords.get().size() >= writtenRecords.size(), 100L, 
100);
+        CompletableFuture<String> savepointFuture =
+                clusterClient.stopWithSavepoint(
+                        jobID, false, savepointPath.toString(), 
SavepointFormatType.NATIVE);
+        savepointFuture.get(2, TimeUnit.MINUTES);
+
+        final long maxTS = getMaxTS(writtenRecords);
+        assertThat(readRecords.get()).hasSize(NUM_RECORDS).allMatch(r -> 
r.timestamp() <= maxTS);
+    }
+
+    private static List<ProducerRecord<String, Integer>> writeInitialData() 
throws Throwable {
+        KafkaSourceTestEnv.createTestTopic(TOPIC);
+        final List<ProducerRecord<String, Integer>> writtenRecords =
+                KafkaSourceTestEnv.getRecordsForTopic(TOPIC);
+        KafkaSourceTestEnv.produceToKafka(writtenRecords);
+        return writtenRecords;
+    }
+
+    private static long getMaxTS(List<ProducerRecord<String, Integer>> 
writtenRecords) {
+        return 
writtenRecords.stream().mapToLong(ProducerRecord::timestamp).max().orElseThrow();
+    }
+
+    private static KafkaSource<ConsumerRecord<byte[], byte[]>> createSource() {
+        return KafkaSource.<ConsumerRecord<byte[], byte[]>>builder()
+                
.setBootstrapServers(KafkaSourceTestEnv.brokerConnectionStrings)
+                .setTopics(TOPIC)
+                .setDeserializer(new ForwardingDeserializer())
+                .build();
+    }
+
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("getKafkaSourceSavepoint")
+    void testRestoreFromSavepointWithCurrentVersion(
+            Path savepointPath, @InjectMiniCluster MiniCluster miniCluster) 
throws Throwable {
+        // this is the part that has been read already in the savepoint
+        final List<ProducerRecord<String, Integer>> existingRecords = 
writeInitialData();
+        // the new data supposed to be read after resuming from the savepoint
+        final List<ProducerRecord<String, Integer>> writtenRecords =
+                KafkaSourceTestEnv.getRecordsForTopicWithoutTimestamp(TOPIC);
+        KafkaSourceTestEnv.produceToKafka(writtenRecords);
+
+        final Configuration configuration = new Configuration();
+        configuration.set(
+                SAVEPOINT_PATH, 
KAFKA_SOURCE_SAVEPOINT_PATH.resolve(savepointPath).toString());
+        StreamExecutionEnvironment env =
+                
StreamExecutionEnvironment.getExecutionEnvironment(configuration);
+        env.setParallelism(2);
+        env.getConfig().enableObjectReuse();
+
+        final KafkaSource<ConsumerRecord<byte[], byte[]>> source = 
createSource();
+
+        final SharedReference<ConcurrentLinkedQueue<ConsumerRecord<byte[], 
byte[]>>> readRecords =
+                SHARED_OBJECTS.add(new ConcurrentLinkedQueue<>());
+        env.fromSource(source, WatermarkStrategy.noWatermarks(), 
"TestDataSource")
+                .uid(KAFKA_SOURCE_UID)
+                .map(r -> readRecords.get().add(r));
+
+        StreamGraph streamGraph = env.getStreamGraph();
+
+        final JobClient jobClient = env.executeAsync(streamGraph);
+        CommonTestUtils.waitForAllTaskRunning(miniCluster, 
jobClient.getJobID(), false);
+        CommonTestUtils.waitUntilCondition(
+                () -> readRecords.get().size() >= writtenRecords.size(), 100L, 
100);
+
+        jobClient.cancel().get(10, TimeUnit.SECONDS);
+
+        // records of old run, all have artificial timestamp up to 
maxPreviousTS (=9000)
+        final long maxPreviousTS = getMaxTS(existingRecords);
+        // new records should all have an epoch timestamp assigned by the broker that is much
+        // larger than maxPreviousTS, so we can verify exactly-once by checking the timestamps
+        assertThat(readRecords.get())
+                .hasSize(writtenRecords.size()) // smaller size indicates 
deadline passed
+                .allMatch(r -> r.timestamp() > maxPreviousTS);
+    }
+
+    private static class ForwardingDeserializer
+            implements KafkaRecordDeserializationSchema<ConsumerRecord<byte[], 
byte[]>> {
+        @Override
+        public void deserialize(
+                ConsumerRecord<byte[], byte[]> record,
+                Collector<ConsumerRecord<byte[], byte[]>> out) {
+            out.collect(record);
+        }
+
+        @Override
+        public TypeInformation<ConsumerRecord<byte[], byte[]>> 
getProducedType() {
+            return TypeInformation.of(new TypeHint<ConsumerRecord<byte[], 
byte[]>>() {});
+        }
+    }
+}
diff --git 
a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java
 
b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java
index 5207687f..9e98d99f 100644
--- 
a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java
+++ 
b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumStateSerializerTest.java
@@ -42,7 +42,6 @@ public class KafkaSourceEnumStateSerializerTest {
     private static final int NUM_READERS = 10;
     private static final String TOPIC_PREFIX = "topic-";
     private static final int NUM_PARTITIONS_PER_TOPIC = 10;
-    private static final long STARTING_OFFSET = 
KafkaPartitionSplit.EARLIEST_OFFSET;
 
     @Test
     public void testEnumStateSerde() throws IOException {
diff --git 
a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumeratorTest.java
 
b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumeratorTest.java
index 3e64e62c..6d69541f 100644
--- 
a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumeratorTest.java
+++ 
b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/KafkaSourceEnumeratorTest.java
@@ -29,6 +29,8 @@ import 
org.apache.flink.connector.kafka.source.enumerator.subscriber.KafkaSubscr
 import org.apache.flink.connector.kafka.source.split.KafkaPartitionSplit;
 import org.apache.flink.connector.kafka.testutils.KafkaSourceTestEnv;
 import org.apache.flink.mock.Whitebox;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
 
 import com.google.common.collect.Iterables;
 import org.apache.kafka.clients.admin.AdminClient;
@@ -41,7 +43,6 @@ import org.assertj.core.api.SoftAssertions;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.Timeout;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.EnumSource;
 
@@ -56,7 +57,6 @@ import java.util.Map;
 import java.util.Properties;
 import java.util.Set;
 import java.util.StringJoiner;
-import java.util.concurrent.TimeUnit;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
@@ -252,8 +252,8 @@ public class KafkaSourceEnumeratorTest {
     public void 
testRunWithDiscoverPartitionsOnceWithZeroMsToCheckNoMoreSplit() throws 
Throwable {
         try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
                         new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
-                // Disable periodic partition discovery
-                KafkaSourceEnumerator enumerator = createEnumerator(context, 
false)) {
+                // set partitionDiscoveryIntervalMs = 0
+                KafkaSourceEnumerator enumerator = createEnumerator(context, 
0L)) {
 
             // Start the enumerator, and it should schedule a one time task to 
discover and assign
             // partitions.
@@ -273,7 +273,6 @@ public class KafkaSourceEnumeratorTest {
     }
 
     @Test
-    @Timeout(value = 30, unit = TimeUnit.SECONDS)
     public void testDiscoverPartitionsPeriodically() throws Throwable {
         try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
                         new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
@@ -342,14 +341,8 @@ public class KafkaSourceEnumeratorTest {
         }
     }
 
-    /**
-     * Ensures that migrated splits are immediately initialized with {@link 
OffsetsInitializer},
-     * such that an early {@link
-     * 
org.apache.flink.api.connector.source.SplitEnumerator#snapshotState(long)} 
doesn't see the
-     * special value.
-     */
     @Test
-    public void shouldEagerlyInitializeSplitOffsetsOnMigration() throws 
Throwable {
+    public void shouldLazilyInitializeSplitOffsetsOnMigration() throws 
Throwable {
         final TopicPartition assigned1 = new TopicPartition(TOPIC1, 0);
         final TopicPartition assigned2 = new TopicPartition(TOPIC1, 1);
         final TopicPartition unassigned1 = new TopicPartition(TOPIC2, 0);
@@ -363,7 +356,7 @@ public class KafkaSourceEnumeratorTest {
                     public Map<TopicPartition, Long> getPartitionOffsets(
                             Collection<TopicPartition> partitions,
                             PartitionOffsetsRetriever 
partitionOffsetsRetriever) {
-                        return Map.of(assigned1, migratedOffset1, unassigned2, 
migratedOffset2);
+                        return Map.of(unassigned2, migratedOffset2);
                     }
 
                     @Override
@@ -376,6 +369,7 @@ public class KafkaSourceEnumeratorTest {
                 KafkaSourceEnumerator enumerator =
                         createEnumerator(
                                 context,
+                                0,
                                 offsetsInitializer,
                                 PRE_EXISTING_TOPICS,
                                 List.of(
@@ -386,18 +380,47 @@ public class KafkaSourceEnumeratorTest {
                                         new KafkaPartitionSplit(unassigned2, 
MIGRATED)),
                                 false,
                                 new Properties())) {
-            final KafkaSourceEnumState state = enumerator.snapshotState(1L);
+            KafkaSourceEnumState state = snapshotWhenReady(enumerator);
             assertThat(state.assignedSplits())
                     .containsExactlyInAnyOrder(
-                            new KafkaPartitionSplit(assigned1, migratedOffset1),
+                            new KafkaPartitionSplit(assigned1, MIGRATED),
                             new KafkaPartitionSplit(assigned2, 2));
             assertThat(state.unassignedSplits())
+                    .containsExactlyInAnyOrder(
+                            new KafkaPartitionSplit(unassigned1, 1),
+                            new KafkaPartitionSplit(unassigned2, MIGRATED));
+
+            enumerator.start();
+
+            runOneTimePartitionDiscovery(context);
+            KafkaSourceEnumState state2 = snapshotWhenReady(enumerator);
+            // verify that only unassigned splits are migrated; assigned splits are tracked by
+            // the reader, so any initialization on the enumerator would be discarded
+            assertThat(state2.assignedSplits())
+                    .containsExactlyInAnyOrder(
+                            new KafkaPartitionSplit(assigned1, MIGRATED),
+                            new KafkaPartitionSplit(assigned2, 2));
+            assertThat(state2.unassignedSplits())
                     .containsExactlyInAnyOrder(
                             new KafkaPartitionSplit(unassigned1, 1),
                             new KafkaPartitionSplit(unassigned2, migratedOffset2));
         }
     }
 
+    private static KafkaSourceEnumState snapshotWhenReady(KafkaSourceEnumerator enumerator)
+            throws Exception {
+        while (true) {
+            try {
+                return enumerator.snapshotState(1L);
+            } catch (CheckpointException e) {
+                if (e.getCheckpointFailureReason()
+                        != CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY) {
+                    throw e;
+                }
+            }
+        }
+    }
+
     @ParameterizedTest
     @EnumSource(StandardOffsetsInitializer.class)
     public void testAddSplitsBack(StandardOffsetsInitializer offsetsInitializer) throws Throwable {
@@ -436,10 +459,12 @@ public class KafkaSourceEnumeratorTest {
                     .as("The added back splits should have not been assigned")
                     .hasSize(2);
 
-            assertThat(enumerator.snapshotState(2L).unassignedSplits())
+            final KafkaSourceEnumState state = enumerator.snapshotState(2L);
+            assertThat(state.unassignedSplits())
                     .containsExactlyInAnyOrderElementsOf(
                             Iterables.concat(
                                     advancedSplits, unassignedSplits)); // READER0 + READER2
+            assertThat(state.assignedSplits()).doesNotContainAnyElementsOf(advancedSplits);
 
             // Simulate a reader recovery.
             registerReader(context, enumerator, READER0);
@@ -468,6 +493,7 @@ public class KafkaSourceEnumeratorTest {
                 KafkaSourceEnumerator enumerator =
                         createEnumerator(
                                 context2,
+                                ENABLE_PERIODIC_PARTITION_DISCOVERY ? 1 : -1,
                                 OffsetsInitializer.earliest(),
                                 PRE_EXISTING_TOPICS,
                                 preexistingAssignments,
@@ -499,6 +525,7 @@ public class KafkaSourceEnumeratorTest {
                 KafkaSourceEnumerator enumerator =
                         createEnumerator(
                                 context,
+                                ENABLE_PERIODIC_PARTITION_DISCOVERY ? 1 : -1,
                                 OffsetsInitializer.earliest(),
                                 PRE_EXISTING_TOPICS,
                                 Collections.emptySet(),
@@ -538,31 +565,38 @@ public class KafkaSourceEnumeratorTest {
             registerReader(context, enumerator, READER0);
             registerReader(context, enumerator, READER1);
 
-            // Step2: Assign partials partitions to reader0 and reader1
+            // Step2: First partition discovery after start, but no assignments to readers
+            context.runNextOneTimeCallable();
+            final KafkaSourceEnumState state2 = enumerator.snapshotState(2L);
+            assertThat(state2.assignedSplits()).isEmpty();
+            assertThat(state2.unassignedSplits()).isEmpty();
+            assertThat(state2.initialDiscoveryFinished()).isFalse();
+
+            // Step3: Assign partial partitions to reader0 and reader1
             context.runNextOneTimeCallable();
 
             // The state should contain splits assigned to READER0 and READER1, but no READER2
             // register.
             // Thus, both assignedPartitions and unassignedInitialPartitions are not empty.
-            final KafkaSourceEnumState state2 = enumerator.snapshotState(2L);
+            final KafkaSourceEnumState state3 = enumerator.snapshotState(2L);
             verifySplitAssignmentWithPartitions(
                     getExpectedAssignments(
                             new HashSet<>(Arrays.asList(READER0, READER1)),
                             PRE_EXISTING_TOPICS,
                             offsetsInitializer.getOffsetsInitializer()),
-                    state2.assignedSplits());
-            assertThat(state2.assignedSplits()).isNotEmpty();
-            assertThat(state2.unassignedSplits()).isNotEmpty();
-            assertThat(state2.initialDiscoveryFinished()).isTrue();
+                    state3.assignedSplits());
+            assertThat(state3.assignedSplits()).isNotEmpty();
+            assertThat(state3.unassignedSplits()).isNotEmpty();
+            assertThat(state3.initialDiscoveryFinished()).isTrue();
 
             // Step3: register READER2, then all partitions are assigned
             registerReader(context, enumerator, READER2);
-            final KafkaSourceEnumState state3 = enumerator.snapshotState(3L);
-            assertThat(state3.assignedSplits())
+            final KafkaSourceEnumState state4 = enumerator.snapshotState(3L);
+            assertThat(state4.assignedSplits())
                     .containsExactlyInAnyOrderElementsOf(
-                            Iterables.concat(state2.assignedSplits(), state2.unassignedSplits()));
-            assertThat(state3.unassignedSplits()).isEmpty();
-            assertThat(state3.initialDiscoveryFinished()).isTrue();
+                            Iterables.concat(state3.assignedSplits(), state3.unassignedSplits()));
+            assertThat(state4.unassignedSplits()).isEmpty();
+            assertThat(state4.initialDiscoveryFinished()).isTrue();
         }
     }
 
@@ -585,7 +619,7 @@ public class KafkaSourceEnumeratorTest {
             Set<TopicPartition> fetchedPartitions = new HashSet<>();
             fetchedPartitions.add(newPartition);
             final KafkaSourceEnumerator.PartitionChange partitionChange =
-                    enumerator.getPartitionChange(fetchedPartitions);
+                    enumerator.getPartitionChange(fetchedPartitions, false);
 
             // Since enumerator never met DYNAMIC_TOPIC_NAME-0, it should be mark as a new partition
             Set<TopicPartition> expectedNewPartitions = Collections.singleton(newPartition);
@@ -598,38 +632,14 @@ public class KafkaSourceEnumeratorTest {
                 expectedRemovedPartitions.add(new TopicPartition(TOPIC2, i));
             }
 
+            // Since enumerator never met DYNAMIC_TOPIC_NAME-1, it should be marked as a new
+            // partition
             assertThat(partitionChange.getNewPartitions()).isEqualTo(expectedNewPartitions);
+            assertThat(partitionChange.getInitialPartitions()).isEqualTo(Set.of());
             assertThat(partitionChange.getRemovedPartitions()).isEqualTo(expectedRemovedPartitions);
         }
     }
 
-    @Test
-    public void testEnablePartitionDiscoveryByDefault() throws Throwable {
-        try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
-                        new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
-                KafkaSourceEnumerator enumerator = createEnumerator(context, new Properties())) {
-            enumerator.start();
-            long partitionDiscoveryIntervalMs =
-                    (long) Whitebox.getInternalState(enumerator, "partitionDiscoveryIntervalMs");
-            assertThat(partitionDiscoveryIntervalMs)
-                    .isEqualTo(KafkaSourceOptions.PARTITION_DISCOVERY_INTERVAL_MS.defaultValue());
-            assertThat(context.getPeriodicCallables()).isNotEmpty();
-        }
-    }
-
-    @Test
-    public void testDisablePartitionDiscovery() throws Throwable {
-        Properties props = new Properties();
-        props.setProperty(
-                KafkaSourceOptions.PARTITION_DISCOVERY_INTERVAL_MS.key(), String.valueOf(0));
-        try (MockSplitEnumeratorContext<KafkaPartitionSplit> context =
-                        new MockSplitEnumeratorContext<>(NUM_SUBTASKS);
-                KafkaSourceEnumerator enumerator = createEnumerator(context, props)) {
-            enumerator.start();
-            assertThat(context.getPeriodicCallables()).isEmpty();
-        }
-    }
-
     // -------------- some common startup sequence ---------------
 
     private void startEnumeratorAndRegisterReaders(
@@ -677,9 +687,13 @@ public class KafkaSourceEnumeratorTest {
     }
 
     private KafkaSourceEnumerator createEnumerator(
-            MockSplitEnumeratorContext<KafkaPartitionSplit> enumContext, Properties properties) {
+            MockSplitEnumeratorContext<KafkaPartitionSplit> enumContext,
+            long partitionDiscoveryIntervalMs) {
         return createEnumerator(
-                enumContext, properties, EXCLUDE_DYNAMIC_TOPIC, OffsetsInitializer.earliest());
+                enumContext,
+                partitionDiscoveryIntervalMs,
+                EXCLUDE_DYNAMIC_TOPIC,
+                OffsetsInitializer.earliest());
     }
 
     private KafkaSourceEnumerator createEnumerator(
@@ -691,23 +705,20 @@ public class KafkaSourceEnumeratorTest {
         if (includeDynamicTopic) {
             topics.add(DYNAMIC_TOPIC_NAME);
         }
-        Properties props = new Properties();
-        props.setProperty(
-                KafkaSourceOptions.PARTITION_DISCOVERY_INTERVAL_MS.key(),
-                enablePeriodicPartitionDiscovery ? "1" : "-1");
         return createEnumerator(
                 enumContext,
+                enablePeriodicPartitionDiscovery ? 1 : -1,
                 startingOffsetsInitializer,
                 topics,
                 Collections.emptySet(),
                 Collections.emptySet(),
                 false,
-                props);
+                new Properties());
     }
 
     private KafkaSourceEnumerator createEnumerator(
             MockSplitEnumeratorContext<KafkaPartitionSplit> enumContext,
-            Properties props,
+            long partitionDiscoveryIntervalMs,
             boolean includeDynamicTopic,
             OffsetsInitializer startingOffsetsInitializer) {
         List<String> topics = new ArrayList<>(PRE_EXISTING_TOPICS);
@@ -716,12 +727,13 @@ public class KafkaSourceEnumeratorTest {
         }
         return createEnumerator(
                 enumContext,
+                partitionDiscoveryIntervalMs,
                 startingOffsetsInitializer,
                 topics,
                 Collections.emptySet(),
                 Collections.emptySet(),
                 false,
-                props);
+                new Properties());
     }
 
     /**
@@ -730,12 +742,35 @@ public class KafkaSourceEnumeratorTest {
      */
     private KafkaSourceEnumerator createEnumerator(
             MockSplitEnumeratorContext<KafkaPartitionSplit> enumContext,
+            long partitionDiscoveryIntervalMs,
             OffsetsInitializer startingOffsetsInitializer,
             Collection<String> topicsToSubscribe,
             Collection<KafkaPartitionSplit> assignedSplits,
             Collection<KafkaPartitionSplit> unassignedInitialSplits,
             boolean initialDiscoveryFinished,
             Properties overrideProperties) {
+        return createEnumerator(
+                enumContext,
+                partitionDiscoveryIntervalMs,
+                topicsToSubscribe,
+                assignedSplits,
+                unassignedInitialSplits,
+                overrideProperties,
+                startingOffsetsInitializer);
+    }
+
+    /**
+     * Create the enumerator. For the purpose of the tests in this class we don't care about the
+     * subscriber and stopping initializer, so just use arbitrary settings.
+     */
+    private KafkaSourceEnumerator createEnumerator(
+            MockSplitEnumeratorContext<KafkaPartitionSplit> enumContext,
+            long partitionDiscoveryIntervalMs,
+            Collection<String> topicsToSubscribe,
+            Collection<KafkaPartitionSplit> assignedSplits,
+            Collection<KafkaPartitionSplit> unassignedInitialSplits,
+            Properties overrideProperties,
+            OffsetsInitializer startingOffsetsInitializer) {
         // Use a TopicPatternSubscriber so that no exception if a subscribed topic hasn't been
         // created yet.
         StringJoiner topicNameJoiner = new StringJoiner("|");
@@ -748,6 +783,9 @@ public class KafkaSourceEnumeratorTest {
         Properties props =
                 new Properties(KafkaSourceTestEnv.getConsumerProperties(StringDeserializer.class));
         KafkaSourceEnumerator.deepCopyProperties(overrideProperties, props);
+        props.setProperty(
+                KafkaSourceOptions.PARTITION_DISCOVERY_INTERVAL_MS.key(),
+                String.valueOf(partitionDiscoveryIntervalMs));
 
         return new KafkaSourceEnumerator(
                 subscriber,
@@ -756,8 +794,7 @@ public class KafkaSourceEnumeratorTest {
                 props,
                 enumContext,
                 Boundedness.CONTINUOUS_UNBOUNDED,
-                new KafkaSourceEnumState(
-                        assignedSplits, unassignedInitialSplits, initialDiscoveryFinished));
+                new KafkaSourceEnumState(assignedSplits, unassignedInitialSplits, false));
     }
 
     // ---------------------
diff --git a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/initializer/OffsetsInitializerTest.java b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/initializer/OffsetsInitializerTest.java
index 46dd61a6..19f8c47a 100644
--- a/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/initializer/OffsetsInitializerTest.java
+++ b/flink-connector-kafka/src/test/java/org/apache/flink/connector/kafka/source/enumerator/initializer/OffsetsInitializerTest.java
@@ -58,7 +58,6 @@ public class OffsetsInitializerTest {
 
     @AfterClass
     public static void tearDown() throws Exception {
-        retriever.close();
         KafkaSourceTestEnv.tearDown();
     }
 
diff --git a/flink-connector-kafka/src/test/resources/kafka-source-savepoint/enum1-split0/savepoint-3c7c0a-07c49f841952/_metadata b/flink-connector-kafka/src/test/resources/kafka-source-savepoint/enum1-split0/savepoint-3c7c0a-07c49f841952/_metadata
new file mode 100644
index 00000000..21276fd2
Binary files /dev/null and b/flink-connector-kafka/src/test/resources/kafka-source-savepoint/enum1-split0/savepoint-3c7c0a-07c49f841952/_metadata differ
diff --git a/flink-connector-kafka/src/test/resources/kafka-source-savepoint/enum2-split0/savepoint-de9fd4-35f289091a1b/_metadata b/flink-connector-kafka/src/test/resources/kafka-source-savepoint/enum2-split0/savepoint-de9fd4-35f289091a1b/_metadata
new file mode 100644
index 00000000..368d3dc5
Binary files /dev/null and b/flink-connector-kafka/src/test/resources/kafka-source-savepoint/enum2-split0/savepoint-de9fd4-35f289091a1b/_metadata differ
diff --git a/flink-connector-kafka/src/test/resources/kafka-source-savepoint/enum3-split0/savepoint-246fa1-85f387ecce0c/_metadata b/flink-connector-kafka/src/test/resources/kafka-source-savepoint/enum3-split0/savepoint-246fa1-85f387ecce0c/_metadata
new file mode 100644
index 00000000..61325695
Binary files /dev/null and b/flink-connector-kafka/src/test/resources/kafka-source-savepoint/enum3-split0/savepoint-246fa1-85f387ecce0c/_metadata differ
