This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 558ca75da2fcec875d1e04a8d75a24fd0ad42ccc
Author: JunRuiLee <[email protected]>
AuthorDate: Fri Mar 1 11:04:04 2024 +0800

    [FLINK-33984][runtime] Support batch snapshot for OperatorCoordinator.
    
    This closes #24415.
---
 .../source/hybrid/HybridSourceSplitEnumerator.java |  4 +-
 .../src/impl/ContinuousFileSplitEnumerator.java    |  4 +-
 .../file/src/impl/StaticFileSplitEnumerator.java   |  4 +-
 .../hive/ContinuousHiveSplitEnumerator.java        |  4 +-
 .../connector/source/SupportsBatchSnapshot.java    | 29 ++++++++++
 .../source/lib/util/IteratorSourceEnumerator.java  |  3 +-
 .../coordination/OperatorCoordinator.java          | 14 +++++
 .../coordination/OperatorCoordinatorHolder.java    |  3 +-
 .../RecreateOnResetOperatorCoordinator.java        | 11 ++++
 .../operators/coordination/SubtaskGatewayImpl.java |  4 +-
 .../source/coordinator/SourceCoordinator.java      | 67 ++++++++++++++++++++--
 .../coordinator/SourceCoordinatorContext.java      |  4 ++
 .../coordinator/SourceCoordinatorSerdeUtils.java   | 64 +++++++++++++++++++++
 .../source/coordinator/SplitAssignmentTracker.java | 22 +++++++
 .../source/coordinator/SourceCoordinatorTest.java  | 36 ++++++++++++
 .../coordinator/SplitAssignmentTrackerTest.java    | 22 +++++++
 .../source/coordinator/TestingSplitEnumerator.java |  3 +-
 17 files changed, 284 insertions(+), 14 deletions(-)

diff --git 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumerator.java
 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumerator.java
index 3cba9635d83..48bec9753c5 100644
--- 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumerator.java
+++ 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumerator.java
@@ -25,6 +25,7 @@ import org.apache.flink.api.connector.source.SourceSplit;
 import org.apache.flink.api.connector.source.SplitEnumerator;
 import org.apache.flink.api.connector.source.SplitEnumeratorContext;
 import org.apache.flink.api.connector.source.SplitsAssignment;
+import org.apache.flink.api.connector.source.SupportsBatchSnapshot;
 import org.apache.flink.api.connector.source.SupportsIntermediateNoMoreSplits;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.metrics.groups.SplitEnumeratorMetricGroup;
@@ -66,7 +67,8 @@ import java.util.function.BiConsumer;
  * returned splits were processed, delegation to the current underlying 
enumerator resumes.
  */
 public class HybridSourceSplitEnumerator
-        implements SplitEnumerator<HybridSourceSplit, 
HybridSourceEnumeratorState> {
+        implements SplitEnumerator<HybridSourceSplit, 
HybridSourceEnumeratorState>,
+                SupportsBatchSnapshot {
     private static final Logger LOG = 
LoggerFactory.getLogger(HybridSourceSplitEnumerator.class);
 
     private final SplitEnumeratorContext<HybridSourceSplit> context;
diff --git 
a/flink-connectors/flink-connector-files/src/main/java/org/apache/flink/connector/file/src/impl/ContinuousFileSplitEnumerator.java
 
b/flink-connectors/flink-connector-files/src/main/java/org/apache/flink/connector/file/src/impl/ContinuousFileSplitEnumerator.java
index 8102e0decba..33ca4e45131 100644
--- 
a/flink-connectors/flink-connector-files/src/main/java/org/apache/flink/connector/file/src/impl/ContinuousFileSplitEnumerator.java
+++ 
b/flink-connectors/flink-connector-files/src/main/java/org/apache/flink/connector/file/src/impl/ContinuousFileSplitEnumerator.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.connector.source.SourceEvent;
 import org.apache.flink.api.connector.source.SplitEnumerator;
 import org.apache.flink.api.connector.source.SplitEnumeratorContext;
+import org.apache.flink.api.connector.source.SupportsBatchSnapshot;
 import org.apache.flink.connector.file.src.FileSourceSplit;
 import org.apache.flink.connector.file.src.PendingSplitsCheckpoint;
 import org.apache.flink.connector.file.src.assigners.FileSplitAssigner;
@@ -49,7 +50,8 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
 /** A continuously monitoring enumerator. */
 @Internal
 public class ContinuousFileSplitEnumerator
-        implements SplitEnumerator<FileSourceSplit, 
PendingSplitsCheckpoint<FileSourceSplit>> {
+        implements SplitEnumerator<FileSourceSplit, 
PendingSplitsCheckpoint<FileSourceSplit>>,
+                SupportsBatchSnapshot {
 
     private static final Logger LOG = 
LoggerFactory.getLogger(ContinuousFileSplitEnumerator.class);
 
diff --git 
a/flink-connectors/flink-connector-files/src/main/java/org/apache/flink/connector/file/src/impl/StaticFileSplitEnumerator.java
 
b/flink-connectors/flink-connector-files/src/main/java/org/apache/flink/connector/file/src/impl/StaticFileSplitEnumerator.java
index 140f52053b1..fb56767240b 100644
--- 
a/flink-connectors/flink-connector-files/src/main/java/org/apache/flink/connector/file/src/impl/StaticFileSplitEnumerator.java
+++ 
b/flink-connectors/flink-connector-files/src/main/java/org/apache/flink/connector/file/src/impl/StaticFileSplitEnumerator.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.connector.source.SourceEvent;
 import org.apache.flink.api.connector.source.SplitEnumerator;
 import org.apache.flink.api.connector.source.SplitEnumeratorContext;
+import org.apache.flink.api.connector.source.SupportsBatchSnapshot;
 import org.apache.flink.connector.file.src.FileSource;
 import org.apache.flink.connector.file.src.FileSourceSplit;
 import org.apache.flink.connector.file.src.PendingSplitsCheckpoint;
@@ -51,7 +52,8 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
  */
 @Internal
 public class StaticFileSplitEnumerator
-        implements SplitEnumerator<FileSourceSplit, 
PendingSplitsCheckpoint<FileSourceSplit>> {
+        implements SplitEnumerator<FileSourceSplit, 
PendingSplitsCheckpoint<FileSourceSplit>>,
+                SupportsBatchSnapshot {
 
     private static final Logger LOG = 
LoggerFactory.getLogger(StaticFileSplitEnumerator.class);
 
diff --git 
a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/ContinuousHiveSplitEnumerator.java
 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/ContinuousHiveSplitEnumerator.java
index 526158fd57e..7b115a0f719 100644
--- 
a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/ContinuousHiveSplitEnumerator.java
+++ 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/ContinuousHiveSplitEnumerator.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.connector.source.SourceEvent;
 import org.apache.flink.api.connector.source.SplitEnumerator;
 import org.apache.flink.api.connector.source.SplitEnumeratorContext;
+import org.apache.flink.api.connector.source.SupportsBatchSnapshot;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.connector.file.src.FileSourceSplit;
 import org.apache.flink.connector.file.src.PendingSplitsCheckpoint;
@@ -54,7 +55,8 @@ import java.util.concurrent.Callable;
 
 /** A continuously monitoring {@link SplitEnumerator} for hive source. */
 public class ContinuousHiveSplitEnumerator<T extends Comparable<T>>
-        implements SplitEnumerator<HiveSourceSplit, 
PendingSplitsCheckpoint<HiveSourceSplit>> {
+        implements SplitEnumerator<HiveSourceSplit, 
PendingSplitsCheckpoint<HiveSourceSplit>>,
+                SupportsBatchSnapshot {
 
     private static final Logger LOG = 
LoggerFactory.getLogger(ContinuousHiveSplitEnumerator.class);
 
diff --git 
a/flink-core/src/main/java/org/apache/flink/api/connector/source/SupportsBatchSnapshot.java
 
b/flink-core/src/main/java/org/apache/flink/api/connector/source/SupportsBatchSnapshot.java
new file mode 100644
index 00000000000..d690f0efa37
--- /dev/null
+++ 
b/flink-core/src/main/java/org/apache/flink/api/connector/source/SupportsBatchSnapshot.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.connector.source;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+/**
+ * A decorative interface for {@link SplitEnumerator}. Implementing it means 
the split enumerator's {@link
+ * SplitEnumerator#snapshotState} method supports taking snapshots in batch 
processing scenarios. In
+ * such scenarios, the checkpointId will always be -1.
+ */
+@PublicEvolving
+public interface SupportsBatchSnapshot {}
diff --git 
a/flink-core/src/main/java/org/apache/flink/api/connector/source/lib/util/IteratorSourceEnumerator.java
 
b/flink-core/src/main/java/org/apache/flink/api/connector/source/lib/util/IteratorSourceEnumerator.java
index 1d7028b6aa7..f7963134af3 100644
--- 
a/flink-core/src/main/java/org/apache/flink/api/connector/source/lib/util/IteratorSourceEnumerator.java
+++ 
b/flink-core/src/main/java/org/apache/flink/api/connector/source/lib/util/IteratorSourceEnumerator.java
@@ -21,6 +21,7 @@ package org.apache.flink.api.connector.source.lib.util;
 import org.apache.flink.annotation.Public;
 import org.apache.flink.api.connector.source.SplitEnumerator;
 import org.apache.flink.api.connector.source.SplitEnumeratorContext;
+import org.apache.flink.api.connector.source.SupportsBatchSnapshot;
 
 import javax.annotation.Nullable;
 
@@ -39,7 +40,7 @@ import static 
org.apache.flink.util.Preconditions.checkNotNull;
  */
 @Public
 public class IteratorSourceEnumerator<SplitT extends IteratorSourceSplit<?, ?>>
-        implements SplitEnumerator<SplitT, Collection<SplitT>> {
+        implements SplitEnumerator<SplitT, Collection<SplitT>>, 
SupportsBatchSnapshot {
 
     private final SplitEnumeratorContext<SplitT> context;
     private final Queue<SplitT> remainingSplits;
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/OperatorCoordinator.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/OperatorCoordinator.java
index eac555b8fbf..8b8b77b06e6 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/OperatorCoordinator.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/OperatorCoordinator.java
@@ -86,6 +86,9 @@ public interface OperatorCoordinator extends 
CheckpointListener, AutoCloseable {
      */
     long NO_CHECKPOINT = -1L;
 
+    /** The checkpoint ID passed to the restore methods in batch scenarios. */
+    long BATCH_CHECKPOINT_ID = -1L;
+
     // ------------------------------------------------------------------------
 
     /**
@@ -234,6 +237,17 @@ public interface OperatorCoordinator extends 
CheckpointListener, AutoCloseable {
      */
     void executionAttemptReady(int subtask, int attemptNumber, SubtaskGateway 
gateway);
 
+    /**
+     * Whether the operator coordinator supports taking snapshot in 
no-checkpoint/batch scenarios.
+     * If it returns true, the {@link 
OperatorCoordinator#checkpointCoordinator} and {@link
+     * OperatorCoordinator#resetToCheckpoint} methods support taking a snapshot 
and restoring from a
+     * snapshot in batch processing scenarios. In such scenarios, the 
checkpointId will always be
+     * -1.
+     */
+    default boolean supportsBatchSnapshot() {
+        return false;
+    }
+
     // ------------------------------------------------------------------------
     // ------------------------------------------------------------------------
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/OperatorCoordinatorHolder.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/OperatorCoordinatorHolder.java
index 8aff60fb766..ced81933406 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/OperatorCoordinatorHolder.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/OperatorCoordinatorHolder.java
@@ -314,7 +314,8 @@ public class OperatorCoordinatorHolder
                         (success, failure) -> {
                             if (failure != null) {
                                 result.completeExceptionally(failure);
-                            } else if (closeGateways(checkpointId)) {
+                            } else if (checkpointId == 
OperatorCoordinator.BATCH_CHECKPOINT_ID
+                                    || closeGateways(checkpointId)) {
                                 
completeCheckpointOnceEventsAreDone(checkpointId, result, success);
                             } else {
                                 // if we cannot close the gateway, this means 
the checkpoint
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/RecreateOnResetOperatorCoordinator.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/RecreateOnResetOperatorCoordinator.java
index d2d6d5681e8..2209e911b47 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/RecreateOnResetOperatorCoordinator.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/RecreateOnResetOperatorCoordinator.java
@@ -160,6 +160,17 @@ public class RecreateOnResetOperatorCoordinator implements 
OperatorCoordinator {
                 });
     }
 
+    @Override
+    public boolean supportsBatchSnapshot() {
+        try {
+            return getInternalCoordinator().supportsBatchSnapshot();
+        } catch (Exception e) {
+            String msg = "Could not get internal coordinator";
+            LOG.error(msg, e);
+            throw new RuntimeException(msg, e);
+        }
+    }
+
     // ---------------------
 
     public OperatorCoordinator getInternalCoordinator() throws Exception {
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/SubtaskGatewayImpl.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/SubtaskGatewayImpl.java
index 7c0e8a673a8..1f6a120c400 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/SubtaskGatewayImpl.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/coordination/SubtaskGatewayImpl.java
@@ -37,6 +37,8 @@ import java.util.TreeSet;
 import java.util.concurrent.Callable;
 import java.util.concurrent.CompletableFuture;
 
+import static 
org.apache.flink.runtime.operators.coordination.OperatorCoordinator.BATCH_CHECKPOINT_ID;
+
 /**
  * Implementation of the {@link OperatorCoordinator.SubtaskGateway} interface 
that access to
  * subtasks for status and event sending via {@link SubtaskAccess}.
@@ -185,7 +187,7 @@ class SubtaskGatewayImpl implements 
OperatorCoordinator.SubtaskGateway {
         if (checkpointId > latestAttemptedCheckpointId) {
             currentMarkedCheckpointIds.add(checkpointId);
             latestAttemptedCheckpointId = checkpointId;
-        } else {
+        } else if (checkpointId != BATCH_CHECKPOINT_ID) {
             throw new IllegalStateException(
                     String.format(
                             "Regressing checkpoint IDs. Previous checkpointId 
= %d, new checkpointId = %d",
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java
index c9e6eb8a2a2..21c692e09a2 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java
@@ -29,6 +29,7 @@ import org.apache.flink.api.connector.source.Source;
 import org.apache.flink.api.connector.source.SourceEvent;
 import org.apache.flink.api.connector.source.SourceSplit;
 import org.apache.flink.api.connector.source.SplitEnumerator;
+import org.apache.flink.api.connector.source.SupportsBatchSnapshot;
 import 
org.apache.flink.api.connector.source.SupportsHandleExecutionAttemptSourceEvent;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
@@ -76,6 +77,7 @@ import static 
org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerde
 import static 
org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerdeUtils.writeCoordinatorSerdeVersion;
 import static org.apache.flink.util.IOUtils.closeQuietly;
 import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
@@ -375,6 +377,16 @@ public class SourceCoordinator<SplitT extends SourceSplit, 
EnumChkT>
                 attemptNumber);
     }
 
+    /**
+     * Whether the enumerator supports batch snapshot. Note the enumerator is 
created either during
+     * resetting the coordinator to a checkpoint, or when the coordinator is 
started.
+     */
+    @Override
+    public boolean supportsBatchSnapshot() {
+        checkNotNull(enumerator);
+        return enumerator instanceof SupportsBatchSnapshot;
+    }
+
     @Override
     public void checkpointCoordinator(long checkpointId, 
CompletableFuture<byte[]> result) {
         runInEventLoop(
@@ -384,8 +396,28 @@ public class SourceCoordinator<SplitT extends SourceSplit, 
EnumChkT>
                             operatorName,
                             checkpointId);
                     try {
-                        context.onCheckpoint(checkpointId);
-                        result.complete(toBytes(checkpointId));
+                        if (checkpointId == BATCH_CHECKPOINT_ID) {
+                            checkState(supportsBatchSnapshot());
+                            try (ByteArrayOutputStream baos = new 
ByteArrayOutputStream();
+                                    DataOutputStream out = new 
DataOutputViewStreamWrapper(baos)) {
+                                // assignments
+                                byte[] assignmentData =
+                                        context.getAssignmentTracker()
+                                                
.snapshotState(source.getSplitSerializer());
+                                out.writeInt(assignmentData.length);
+                                out.write(assignmentData);
+
+                                // enumerator
+                                byte[] enumeratorData = toBytes(checkpointId);
+                                out.writeInt(enumeratorData.length);
+                                out.write(enumeratorData);
+                                out.flush();
+                                result.complete(baos.toByteArray());
+                            }
+                        } else {
+                            context.onCheckpoint(checkpointId);
+                            result.complete(toBytes(checkpointId));
+                        }
                     } catch (Throwable e) {
                         ExceptionUtils.rethrowIfFatalErrorOrOOM(e);
                         result.completeExceptionally(
@@ -446,10 +478,33 @@ public class SourceCoordinator<SplitT extends 
SourceSplit, EnumChkT>
 
         final ClassLoader userCodeClassLoader =
                 context.getCoordinatorContext().getUserCodeClassloader();
-        try (TemporaryClassLoaderContext ignored =
-                TemporaryClassLoaderContext.of(userCodeClassLoader)) {
-            final EnumChkT enumeratorCheckpoint = 
deserializeCheckpoint(checkpointData);
-            enumerator = source.restoreEnumerator(context, 
enumeratorCheckpoint);
+
+        if (checkpointId == BATCH_CHECKPOINT_ID) {
+            try (TemporaryClassLoaderContext ignored =
+                    TemporaryClassLoaderContext.of(userCodeClassLoader)) {
+                try (ByteArrayInputStream bais = new 
ByteArrayInputStream(checkpointData);
+                        DataInputStream in = new 
DataInputViewStreamWrapper(bais)) {
+                    int assignmentDataLength = in.readInt();
+                    byte[] assignmentData =
+                            SourceCoordinatorSerdeUtils.readBytes(in, 
assignmentDataLength);
+
+                    int enumeratorDataLength = in.readInt();
+                    byte[] enumeratorData =
+                            SourceCoordinatorSerdeUtils.readBytes(in, 
enumeratorDataLength);
+
+                    final EnumChkT enumeratorCheckpoint = 
deserializeCheckpoint(enumeratorData);
+                    enumerator = source.restoreEnumerator(context, 
enumeratorCheckpoint);
+
+                    context.getAssignmentTracker()
+                            .restoreState(source.getSplitSerializer(), 
assignmentData);
+                }
+            }
+        } else {
+            try (TemporaryClassLoaderContext ignored =
+                    TemporaryClassLoaderContext.of(userCodeClassLoader)) {
+                final EnumChkT enumeratorCheckpoint = 
deserializeCheckpoint(checkpointData);
+                enumerator = source.restoreEnumerator(context, 
enumeratorCheckpoint);
+            }
         }
     }
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java
index fa764389797..934c222af92 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java
@@ -511,6 +511,10 @@ public class SourceCoordinatorContext<SplitT extends 
SourceSplit>
         return operatorCoordinatorContext;
     }
 
+    SplitAssignmentTracker<SplitT> getAssignmentTracker() {
+        return assignmentTracker;
+    }
+
     // ---------------- Executor methods to avoid use coordinatorExecutor 
directly -----------------
 
     Future<?> submitTask(Runnable task) {
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorSerdeUtils.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorSerdeUtils.java
index 8134834c6a1..ce1e9fe0a6f 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorSerdeUtils.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorSerdeUtils.java
@@ -18,9 +18,18 @@ limitations under the License.
 
 package org.apache.flink.runtime.source.coordinator;
 
+import org.apache.flink.core.io.SimpleVersionedSerializer;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
 
 /** A serialization util class for the {@link SourceCoordinator}. */
 public class SourceCoordinatorSerdeUtils {
@@ -53,4 +62,59 @@ public class SourceCoordinatorSerdeUtils {
         in.readFully(bytes);
         return bytes;
     }
+
+    static <SplitT> byte[] serializeAssignments(
+            Map<Integer, LinkedHashSet<SplitT>> assignments,
+            SimpleVersionedSerializer<SplitT> splitSerializer)
+            throws IOException {
+        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+                DataOutputStream out = new DataOutputViewStreamWrapper(baos)) {
+            out.writeInt(splitSerializer.getVersion());
+
+            int numSubtasks = assignments.size();
+            out.writeInt(numSubtasks);
+            for (Map.Entry<Integer, LinkedHashSet<SplitT>> assignment : 
assignments.entrySet()) {
+                int subtaskId = assignment.getKey();
+                out.writeInt(subtaskId);
+
+                int numAssignedSplits = assignment.getValue().size();
+                out.writeInt(numAssignedSplits);
+                for (SplitT split : assignment.getValue()) {
+                    byte[] serializedSplit = splitSerializer.serialize(split);
+                    out.writeInt(serializedSplit.length);
+                    out.write(serializedSplit);
+                }
+            }
+            out.flush();
+            return baos.toByteArray();
+        }
+    }
+
+    static <SplitT> Map<Integer, LinkedHashSet<SplitT>> deserializeAssignments(
+            byte[] assignmentData, SimpleVersionedSerializer<SplitT> 
splitSerializer)
+            throws IOException {
+
+        try (ByteArrayInputStream bais = new 
ByteArrayInputStream(assignmentData);
+                DataInputStream in = new DataInputViewStreamWrapper(bais)) {
+            int splitSerializerVersion = in.readInt();
+
+            int numSubtasks = in.readInt();
+            Map<Integer, LinkedHashSet<SplitT>> assignments = new HashMap<>();
+            for (int j = 0; j < numSubtasks; j++) {
+                int subtaskId = in.readInt();
+                int numAssignedSplits = in.readInt();
+                LinkedHashSet<SplitT> splits = new 
LinkedHashSet<>(numAssignedSplits);
+                assignments.put(subtaskId, splits);
+                for (int k = 0; k < numAssignedSplits; k++) {
+                    int serializedSplitSize = in.readInt();
+                    byte[] serializedSplit = readBytes(in, 
serializedSplitSize);
+                    SplitT split =
+                            
splitSerializer.deserialize(splitSerializerVersion, serializedSplit);
+                    splits.add(split);
+                }
+            }
+
+            return assignments;
+        }
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SplitAssignmentTracker.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SplitAssignmentTracker.java
index d5b7c25ca5b..87525647e1b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SplitAssignmentTracker.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SplitAssignmentTracker.java
@@ -23,6 +23,7 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.connector.source.SourceSplit;
 import org.apache.flink.api.connector.source.SplitEnumerator;
 import org.apache.flink.api.connector.source.SplitsAssignment;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -65,6 +66,27 @@ public class SplitAssignmentTracker<SplitT extends 
SourceSplit> {
         uncheckpointedAssignments = new HashMap<>();
     }
 
+    /** Take a snapshot of the split assignments. */
+    public byte[] snapshotState(SimpleVersionedSerializer<SplitT> 
splitSerializer)
+            throws Exception {
+        return SourceCoordinatorSerdeUtils.serializeAssignments(
+                uncheckpointedAssignments, splitSerializer);
+    }
+
+    /**
+     * Restore the state of the SplitAssignmentTracker.
+     *
+     * @param splitSerializer The serializer of the splits.
+     * @param assignmentData The state of the SplitAssignmentTracker.
+     * @throws Exception when the state deserialization fails.
+     */
+    public void restoreState(
+            SimpleVersionedSerializer<SplitT> splitSerializer, byte[] 
assignmentData)
+            throws Exception {
+        uncheckpointedAssignments =
+                
SourceCoordinatorSerdeUtils.deserializeAssignments(assignmentData, 
splitSerializer);
+    }
+
     /**
      * when a checkpoint has been successfully made, this method is invoked to 
clean up the
      * assignment history before this successful checkpoint.
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java
index ea9455ac45f..06343c87413 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java
@@ -168,6 +168,42 @@ class SourceCoordinatorTest extends 
SourceCoordinatorTestBase {
                 .isEmpty();
     }
 
+    @Test
+    void testBatchSnapshotCoordinatorAndRestore() throws Exception {
+        sourceReady();
+        addTestingSplitSet(6);
+
+        registerReader(0);
+        getEnumerator().executeAssignOneSplit(0);
+        getEnumerator().executeAssignOneSplit(0);
+
+        final CompletableFuture<byte[]> checkpointFuture = new 
CompletableFuture<>();
+        sourceCoordinator.checkpointCoordinator(
+                OperatorCoordinator.BATCH_CHECKPOINT_ID, checkpointFuture);
+        final byte[] bytes = checkpointFuture.get();
+
+        // restore from the batch snapshot.
+        SourceCoordinator<?, ?> restoredCoordinator = 
getNewSourceCoordinator();
+        
restoredCoordinator.resetToCheckpoint(OperatorCoordinator.BATCH_CHECKPOINT_ID, 
bytes);
+        TestingSplitEnumerator<?> restoredEnumerator =
+                (TestingSplitEnumerator<?>) 
restoredCoordinator.getEnumerator();
+        SourceCoordinatorContext<?> restoredContext = 
restoredCoordinator.getContext();
+        assertThat(restoredEnumerator.getUnassignedSplits())
+                .as("2 splits should have been assigned to reader 0")
+                .hasSize(4);
+        
assertThat(restoredEnumerator.getContext().registeredReaders()).isEmpty();
+        assertThat(restoredContext.registeredReaders())
+                .as("Registered readers should not be recovered by restoring")
+                .isEmpty();
+
+        
assertThat(restoredContext.getAssignmentTracker().uncheckpointedAssignments())
+                .isEqualTo(
+                        sourceCoordinator
+                                .getContext()
+                                .getAssignmentTracker()
+                                .uncheckpointedAssignments());
+    }
+
     @Test
     void testSubtaskFailedAndRevertUncompletedAssignments() throws Exception {
         sourceReady();
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SplitAssignmentTrackerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SplitAssignmentTrackerTest.java
index 3c8c26047ca..bc0f0182585 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SplitAssignmentTrackerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SplitAssignmentTrackerTest.java
@@ -19,6 +19,7 @@ limitations under the License.
 package org.apache.flink.runtime.source.coordinator;
 
 import org.apache.flink.api.connector.source.mocks.MockSourceSplit;
+import org.apache.flink.api.connector.source.mocks.MockSourceSplitSerializer;
 
 import org.junit.jupiter.api.Test;
 
@@ -47,6 +48,27 @@ class SplitAssignmentTrackerTest {
         verifyAssignment(Arrays.asList("3", "4", "5"), 
tracker.uncheckpointedAssignments().get(2));
     }
 
+    @Test
+    void testSnapshotStateAndRestoreState() throws Exception {
+        SplitAssignmentTracker<MockSourceSplit> tracker = new 
SplitAssignmentTracker<>();
+        tracker.recordSplitAssignment(getSplitsAssignment(3, 0));
+        tracker.recordSplitAssignment(getSplitsAssignment(2, 6));
+
+        byte[] snapshotState = tracker.snapshotState(new 
MockSourceSplitSerializer());
+
+        SplitAssignmentTracker<MockSourceSplit> trackerToRestore = new 
SplitAssignmentTracker<>();
+        assertThat(trackerToRestore.uncheckpointedAssignments()).isEmpty();
+        trackerToRestore.restoreState(new MockSourceSplitSerializer(), 
snapshotState);
+
+        verifyAssignment(
+                Arrays.asList("0", "6"), 
trackerToRestore.uncheckpointedAssignments().get(0));
+        verifyAssignment(
+                Arrays.asList("1", "2", "7", "8"),
+                trackerToRestore.uncheckpointedAssignments().get(1));
+        verifyAssignment(
+                Arrays.asList("3", "4", "5"), 
trackerToRestore.uncheckpointedAssignments().get(2));
+    }
+
     @Test
     void testOnCheckpoint() throws Exception {
         final long checkpointId = 123L;
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/TestingSplitEnumerator.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/TestingSplitEnumerator.java
index 7eba9e64c16..1690c083fce 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/TestingSplitEnumerator.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/TestingSplitEnumerator.java
@@ -27,6 +27,7 @@ import 
org.apache.flink.api.connector.source.SourceReaderContext;
 import org.apache.flink.api.connector.source.SourceSplit;
 import org.apache.flink.api.connector.source.SplitEnumerator;
 import org.apache.flink.api.connector.source.SplitEnumeratorContext;
+import org.apache.flink.api.connector.source.SupportsBatchSnapshot;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.Preconditions;
@@ -55,7 +56,7 @@ import java.util.concurrent.ExecutionException;
  * assertions.
  */
 public class TestingSplitEnumerator<SplitT extends SourceSplit>
-        implements SplitEnumerator<SplitT, Set<SplitT>> {
+        implements SplitEnumerator<SplitT, Set<SplitT>>, SupportsBatchSnapshot 
{
 
     private final SplitEnumeratorContext<SplitT> context;
 

Reply via email to