This is an automated email from the ASF dual-hosted git repository.

guoweijie pushed a commit to branch hs_jm_fo
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 7735f359ffabfca74797cb7a8919e0716f0784c1
Author: Weijie Guo <[email protected]>
AuthorDate: Wed Dec 18 19:08:28 2024 +0800

    [FLINK-36880][network] Hybrid shuffle supports job master failover if only the external tier is used.
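
    For illustration only: this failover path is taken when hybrid shuffle runs
    exclusively against an external remote tier and batch job recovery is enabled,
    roughly the setup sketched below (the external tier factory class name is a
    placeholder, not part of this change):

        Configuration conf = new Configuration();
        conf.set(ExecutionOptions.BATCH_SHUFFLE_MODE, BatchShuffleMode.ALL_EXCHANGES_HYBRID_FULL);
        conf.set(BatchExecutionOptions.JOB_RECOVERY_ENABLED, true);
        conf.set(
                NettyShuffleEnvironmentOptions
                        .NETWORK_HYBRID_SHUFFLE_EXTERNAL_REMOTE_TIER_FACTORY_CLASS_NAME,
                "com.example.MyExternalRemoteTierFactory");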
---
 .../executiongraph/ResultPartitionBytes.java       |  21 +++
 .../shuffle/AllTieredShuffleMasterSnapshots.java   |  50 +++++++
 .../shuffle/EmptyTieredShuffleMasterSnapshot.java} |  25 ++--
 .../shuffle/ShuffleDescriptorRetriever.java}       |  29 ++--
 .../shuffle/TieredInternalShuffleMaster.java       |  73 +++++++++-
 .../TieredInternalShuffleMasterSnapshot.java       |  55 +++++++
 .../shuffle/TieredShuffleMasterSnapshot.java}      |  24 ++--
 .../tiered/storage/TieredStorageMasterClient.java  | 159 ++++++++++++++++++++-
 .../partition/hybrid/tiered/tier/TierFactory.java  |   3 +
 .../hybrid/tiered/tier/TierMasterAgent.java        |  58 ++++++++
 .../hybrid/tiered/tier/disk/DiskTierFactory.java   |   5 +
 .../tiered/tier/memory/MemoryTierFactory.java      |   5 +
 .../tiered/tier/remote/RemoteTierFactory.java      |   5 +
 .../tiered/tier/remote/RemoteTierMasterAgent.java  |  20 +++
 .../flink/runtime/shuffle/NettyShuffleMaster.java  | 101 ++++++++++++-
 .../hybrid/tiered/storage/TestingTierFactory.java  |   5 +
 .../adaptivebatch/BatchJobRecoveryTest.java        |  96 +++++++++----
 .../scheduler/adaptivebatch/DummyTierFactory.java} |  82 +++++------
 18 files changed, 692 insertions(+), 124 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
index 630a828c648..e0166d34c83 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
@@ -19,7 +19,9 @@
 package org.apache.flink.runtime.executiongraph;
 
 import java.io.Serializable;
+import java.util.List;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /** This class represents a snapshot of the result partition bytes metrics. */
@@ -34,4 +36,23 @@ public class ResultPartitionBytes implements Serializable {
     public long[] getSubpartitionBytes() {
         return subpartitionBytes;
     }
+
+    /** Merges all {@link ResultPartitionBytes} by summing them up per subpartition. */
+    public static ResultPartitionBytes mergeAll(List<ResultPartitionBytes> 
partitions) {
+        checkArgument(!partitions.isEmpty());
+        int expectedLength = partitions.get(0).getSubpartitionBytes().length;
+        for (ResultPartitionBytes resultPartitionByte : partitions) {
+            if (resultPartitionByte.getSubpartitionBytes().length != 
expectedLength) {
+                throw new IllegalArgumentException(
+                        "Only ResultPartitionBytes of the same length can be merged.");
+            }
+        }
+        long[] mergedSubpartitionBytes = new long[expectedLength];
+        for (int i = 0; i < expectedLength; i++) {
+            for (ResultPartitionBytes resultPartitionByte : partitions) {
+                mergedSubpartitionBytes[i] += 
resultPartitionByte.getSubpartitionBytes()[i];
+            }
+        }
+        return new ResultPartitionBytes(mergedSubpartitionBytes);
+    }
 }
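
The new ResultPartitionBytes#mergeAll sums the per-subpartition byte counts element-wise.
A minimal usage sketch (values are made up; java.util.Arrays assumed imported):

    ResultPartitionBytes diskBytes = new ResultPartitionBytes(new long[] {100L, 200L});
    ResultPartitionBytes remoteBytes = new ResultPartitionBytes(new long[] {10L, 20L});
    ResultPartitionBytes merged =
            ResultPartitionBytes.mergeAll(Arrays.asList(diskBytes, remoteBytes));
    // merged.getSubpartitionBytes() -> {110L, 220L}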
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/AllTieredShuffleMasterSnapshots.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/AllTieredShuffleMasterSnapshots.java
new file mode 100644
index 00000000000..ca3626363a2
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/AllTieredShuffleMasterSnapshots.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.Collections;
+
+/**
+ * This is a collection of all {@link TieredShuffleMasterSnapshot}s from every 
tier in one snapshot
+ * round.
+ */
+public class AllTieredShuffleMasterSnapshots implements Serializable {
+    /**
+     * Snapshots of all tiers. For each tier, it is a tuple of (identifier,
+     * TieredShuffleMasterSnapshot).
+     */
+    private final Collection<Tuple2<String, TieredShuffleMasterSnapshot>> 
snapshots;
+
+    public AllTieredShuffleMasterSnapshots(
+            Collection<Tuple2<String, TieredShuffleMasterSnapshot>> snapshots) 
{
+        this.snapshots = snapshots;
+    }
+
+    public Collection<Tuple2<String, TieredShuffleMasterSnapshot>> 
getSnapshots() {
+        return snapshots;
+    }
+
+    public static AllTieredShuffleMasterSnapshots empty() {
+        return new AllTieredShuffleMasterSnapshots(Collections.emptyList());
+    }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/EmptyTieredShuffleMasterSnapshot.java
similarity index 60%
copy from 
flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
copy to 
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/EmptyTieredShuffleMasterSnapshot.java
index 630a828c648..6fd35cac5f6 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/EmptyTieredShuffleMasterSnapshot.java
@@ -16,22 +16,17 @@
  * limitations under the License.
  */
 
-package org.apache.flink.runtime.executiongraph;
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle;
 
-import java.io.Serializable;
-
-import static org.apache.flink.util.Preconditions.checkNotNull;
-
-/** This class represents a snapshot of the result partition bytes metrics. */
-public class ResultPartitionBytes implements Serializable {
-
-    private final long[] subpartitionBytes;
-
-    public ResultPartitionBytes(long[] subpartitionBytes) {
-        this.subpartitionBytes = checkNotNull(subpartitionBytes);
-    }
+/**
+ * A singleton implementation of {@link TieredShuffleMasterSnapshot} that represents an empty
+ * snapshot of the tiered shuffle master.
+ */
+public class EmptyTieredShuffleMasterSnapshot implements 
TieredShuffleMasterSnapshot {
+    private static final EmptyTieredShuffleMasterSnapshot INSTANCE =
+            new EmptyTieredShuffleMasterSnapshot();
 
-    public long[] getSubpartitionBytes() {
-        return subpartitionBytes;
+    public static EmptyTieredShuffleMasterSnapshot getInstance() {
+        return INSTANCE;
     }
 }
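
One snapshot round collects a (tier identifier, snapshot) pair per tier; tiers without
master-side state fall back to the empty singleton above. A minimal sketch (java.util
imports assumed; the identifiers come from the tier factories further down in this commit):

    Collection<Tuple2<String, TieredShuffleMasterSnapshot>> perTier =
            Arrays.asList(
                    Tuple2.of("memory", EmptyTieredShuffleMasterSnapshot.getInstance()),
                    Tuple2.of("disk", EmptyTieredShuffleMasterSnapshot.getInstance()),
                    Tuple2.of("remote", EmptyTieredShuffleMasterSnapshot.getInstance()));
    AllTieredShuffleMasterSnapshots round = new AllTieredShuffleMasterSnapshots(perTier);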
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/ShuffleDescriptorRetriever.java
similarity index 56%
copy from 
flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
copy to 
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/ShuffleDescriptorRetriever.java
index 630a828c648..d5cd9ddd8a3 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/ShuffleDescriptorRetriever.java
@@ -16,22 +16,21 @@
  * limitations under the License.
  */
 
-package org.apache.flink.runtime.executiongraph;
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle;
 
-import java.io.Serializable;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
 
-import static org.apache.flink.util.Preconditions.checkNotNull;
+import java.util.Optional;
 
-/** This class represents a snapshot of the result partition bytes metrics. */
-public class ResultPartitionBytes implements Serializable {
-
-    private final long[] subpartitionBytes;
-
-    public ResultPartitionBytes(long[] subpartitionBytes) {
-        this.subpartitionBytes = checkNotNull(subpartitionBytes);
-    }
-
-    public long[] getSubpartitionBytes() {
-        return subpartitionBytes;
-    }
+/** The retriever for shuffle descriptors. */
+public interface ShuffleDescriptorRetriever {
+    /**
+     * Gets the shuffle descriptor by JobID and ResultPartitionID.
+     *
+     * @return the shuffle descriptor, or empty if it does not exist.
+     */
+    Optional<ShuffleDescriptor> getShuffleDescriptor(
+            JobID jobID, ResultPartitionID resultPartitionID);
 }
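
A rough sketch of a retriever backed by a per-job map, which is essentially how
NettyShuffleMaster wires this up further down in this commit (class and helper names
here are illustrative, not part of the change):

    class MapBackedShuffleDescriptorRetriever implements ShuffleDescriptorRetriever {
        private final Map<JobID, Map<ResultPartitionID, ShuffleDescriptor>> descriptors =
                new HashMap<>();

        void put(JobID jobID, ResultPartitionID partitionID, ShuffleDescriptor descriptor) {
            descriptors.computeIfAbsent(jobID, k -> new HashMap<>()).put(partitionID, descriptor);
        }

        @Override
        public Optional<ShuffleDescriptor> getShuffleDescriptor(
                JobID jobID, ResultPartitionID resultPartitionID) {
            return Optional.ofNullable(descriptors.get(jobID))
                    .map(descriptorMap -> descriptorMap.get(resultPartitionID));
        }
    }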
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredInternalShuffleMaster.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredInternalShuffleMaster.java
index 2d8ce328ca1..3768b52be72 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredInternalShuffleMaster.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredInternalShuffleMaster.java
@@ -19,7 +19,9 @@
 package org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageConfiguration;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageIdMappingUtils;
@@ -30,11 +32,16 @@ import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierMast
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleHandler;
 import org.apache.flink.runtime.shuffle.JobShuffleContext;
+import org.apache.flink.runtime.shuffle.PartitionWithMetrics;
 import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
 import org.apache.flink.runtime.shuffle.ShuffleMasterContext;
+import org.apache.flink.runtime.shuffle.ShuffleMasterSnapshotContext;
 
+import java.time.Duration;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.stream.Collectors;
 
@@ -48,17 +55,75 @@ public class TieredInternalShuffleMaster {
 
     private final ShuffleMasterContext shuffleMasterContext;
 
-    public TieredInternalShuffleMaster(ShuffleMasterContext 
shuffleMasterContext) {
+    private final boolean useOnlyExternalTier;
+
+    public TieredInternalShuffleMaster(
+            ShuffleMasterContext shuffleMasterContext,
+            ShuffleDescriptorRetriever shuffleDescriptorRetriever) {
         this.shuffleMasterContext = shuffleMasterContext;
         Configuration conf = shuffleMasterContext.getConfiguration();
+        String externalTierFactoryClass =
+                conf.get(
+                        NettyShuffleEnvironmentOptions
+                                
.NETWORK_HYBRID_SHUFFLE_EXTERNAL_REMOTE_TIER_FACTORY_CLASS_NAME);
+        this.useOnlyExternalTier = externalTierFactoryClass != null;
         TieredStorageConfiguration tieredStorageConfiguration =
                 TieredStorageConfiguration.fromConfiguration(conf);
         TieredStorageResourceRegistry resourceRegistry = new 
TieredStorageResourceRegistry();
-        List<TierMasterAgent> tierFactories =
+        List<Tuple2<String, TierMasterAgent>> tierFactories =
                 tieredStorageConfiguration.getTierFactories().stream()
-                        .map(tierFactory -> 
tierFactory.createMasterAgent(resourceRegistry))
+                        .map(
+                                tierFactory ->
+                                        Tuple2.of(
+                                                tierFactory.identifier(),
+                                                
tierFactory.createMasterAgent(resourceRegistry)))
                         .collect(Collectors.toList());
-        this.tieredStorageMasterClient = new 
TieredStorageMasterClient(tierFactories);
+        this.tieredStorageMasterClient =
+                new TieredStorageMasterClient(tierFactories, 
shuffleDescriptorRetriever);
+    }
+
+    public boolean supportsBatchSnapshot() {
+        return useOnlyExternalTier;
+    }
+
+    public void snapshotState(
+            CompletableFuture<AllTieredShuffleMasterSnapshots> snapshotFuture,
+            ShuffleMasterSnapshotContext context,
+            JobID jobId) {
+        // Only the external tier supports snapshots for now.
+        if (useOnlyExternalTier) {
+            tieredStorageMasterClient.snapshotState(snapshotFuture, context, 
jobId);
+        }
+    }
+
+    public void 
snapshotState(CompletableFuture<AllTieredShuffleMasterSnapshots> 
snapshotFuture) {
+        if (useOnlyExternalTier) {
+            tieredStorageMasterClient.snapshotState(snapshotFuture);
+        }
+    }
+
+    public void restoreState(List<TieredInternalShuffleMasterSnapshot> 
snapshots, JobID jobId) {
+        if (useOnlyExternalTier) {
+            tieredStorageMasterClient.restoreState(snapshots, jobId);
+        }
+    }
+
+    public void restoreState(TieredInternalShuffleMasterSnapshot 
clusterSnapshot) {
+        if (useOnlyExternalTier) {
+            tieredStorageMasterClient.restoreState(clusterSnapshot);
+        }
+    }
+
+    public CompletableFuture<Collection<PartitionWithMetrics>> 
getPartitionWithMetrics(
+            JobShuffleContext jobShuffleContext,
+            Duration timeout,
+            Set<ResultPartitionID> expectedPartitions) {
+        if (useOnlyExternalTier) {
+            return tieredStorageMasterClient.getPartitionWithMetrics(
+                    jobShuffleContext, timeout, expectedPartitions);
+        } else {
+            return CompletableFuture.completedFuture(Collections.emptyList());
+        }
     }
 
     /**
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredInternalShuffleMasterSnapshot.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredInternalShuffleMasterSnapshot.java
new file mode 100644
index 00000000000..c349425010e
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredInternalShuffleMasterSnapshot.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle;
+
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import org.apache.flink.runtime.shuffle.ShuffleMasterSnapshot;
+
+import java.util.Map;
+
+/**
+ * The internal {@link ShuffleMasterSnapshot} for hybrid shuffle. It bundles the shuffle
+ * descriptors and the snapshots of all tiers.
+ */
+public class TieredInternalShuffleMasterSnapshot implements 
ShuffleMasterSnapshot {
+    private final Map<ResultPartitionID, ShuffleDescriptor> shuffleDescriptors;
+
+    private final AllTieredShuffleMasterSnapshots allTierSnapshots;
+
+    public TieredInternalShuffleMasterSnapshot(
+            Map<ResultPartitionID, ShuffleDescriptor> shuffleDescriptors,
+            AllTieredShuffleMasterSnapshots allTierSnapshots) {
+        this.shuffleDescriptors = shuffleDescriptors;
+        this.allTierSnapshots = allTierSnapshots;
+    }
+
+    public Map<ResultPartitionID, ShuffleDescriptor> getShuffleDescriptors() {
+        return shuffleDescriptors;
+    }
+
+    public AllTieredShuffleMasterSnapshots getAllTierSnapshots() {
+        return allTierSnapshots;
+    }
+
+    @Override
+    public boolean isIncremental() {
+        return true;
+    }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredShuffleMasterSnapshot.java
similarity index 61%
copy from 
flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
copy to 
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredShuffleMasterSnapshot.java
index 630a828c648..4e9a7d37f72 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ResultPartitionBytes.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredShuffleMasterSnapshot.java
@@ -16,22 +16,14 @@
  * limitations under the License.
  */
 
-package org.apache.flink.runtime.executiongraph;
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle;
 
 import java.io.Serializable;
 
-import static org.apache.flink.util.Preconditions.checkNotNull;
-
-/** This class represents a snapshot of the result partition bytes metrics. */
-public class ResultPartitionBytes implements Serializable {
-
-    private final long[] subpartitionBytes;
-
-    public ResultPartitionBytes(long[] subpartitionBytes) {
-        this.subpartitionBytes = checkNotNull(subpartitionBytes);
-    }
-
-    public long[] getSubpartitionBytes() {
-        return subpartitionBytes;
-    }
-}
+/**
+ * This interface represents a snapshot of the tiered shuffle master, which can be used to
+ * restore the internal state of the shuffle master.
+ *
+ * <p>IMPORTANT: It is incremental.
+ */
+public interface TieredShuffleMasterSnapshot extends Serializable {}
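
Concrete tiers supply their own implementations; a hypothetical external remote tier
might, for example, snapshot the partitions it has registered (class and field names
below are made up, not part of this commit):

    class RemotePartitionsSnapshot implements TieredShuffleMasterSnapshot {
        private final Set<ResultPartitionID> registeredPartitions;

        RemotePartitionsSnapshot(Set<ResultPartitionID> registeredPartitions) {
            this.registeredPartitions = registeredPartitions;
        }

        Set<ResultPartitionID> getRegisteredPartitions() {
            return registeredPartitions;
        }
    }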
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMasterClient.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMasterClient.java
index 5bb857358a7..0dd561147f9 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMasterClient.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageMasterClient.java
@@ -19,33 +19,68 @@
 package org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.AllTieredShuffleMasterSnapshots;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.ShuffleDescriptorRetriever;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.TieredInternalShuffleMasterSnapshot;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.TieredShuffleMasterSnapshot;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierMasterAgent;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleHandler;
+import org.apache.flink.runtime.shuffle.DefaultPartitionWithMetrics;
+import org.apache.flink.runtime.shuffle.DefaultShuffleMetrics;
+import org.apache.flink.runtime.shuffle.JobShuffleContext;
 import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor;
+import org.apache.flink.runtime.shuffle.PartitionWithMetrics;
 import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import org.apache.flink.runtime.shuffle.ShuffleMasterSnapshotContext;
+import org.apache.flink.runtime.shuffle.ShuffleMetrics;
+import org.apache.flink.util.concurrent.FutureUtils;
 
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
 
 import static org.apache.flink.util.Preconditions.checkState;
 
 /** Client of the Tiered Storage used by the master. */
 public class TieredStorageMasterClient {
+    private final List<Tuple2<String, TierMasterAgent>> tiers;
 
-    private final List<TierMasterAgent> tiers;
+    private final Map<String, TierMasterAgent> tierMasterAgentMap;
 
-    public TieredStorageMasterClient(List<TierMasterAgent> tiers) {
+    private final boolean allPartitionInRemote;
+
+    private final ShuffleDescriptorRetriever shuffleDescriptorRetriever;
+
+    public TieredStorageMasterClient(
+            List<Tuple2<String, TierMasterAgent>> tiers,
+            ShuffleDescriptorRetriever shuffleDescriptorRetriever) {
         this.tiers = tiers;
+        this.allPartitionInRemote = tiers.stream().allMatch(tier -> 
tier.f1.partitionInRemote());
+        this.tierMasterAgentMap = new HashMap<>();
+        for (Tuple2<String, TierMasterAgent> tier : tiers) {
+            tierMasterAgentMap.put(tier.f0, tier.f1);
+        }
+        this.shuffleDescriptorRetriever = shuffleDescriptorRetriever;
     }
 
     public void registerJob(JobID jobID, TierShuffleHandler shuffleHandler) {
-        tiers.forEach(tierMasterAgent -> tierMasterAgent.registerJob(jobID, 
shuffleHandler));
+        tiers.forEach(tierMasterAgent -> tierMasterAgent.f1.registerJob(jobID, 
shuffleHandler));
     }
 
     public void unregisterJob(JobID jobID) {
-        tiers.forEach(tierMasterAgent -> tierMasterAgent.unregisterJob(jobID));
+        tiers.forEach(tierMasterAgent -> 
tierMasterAgent.f1.unregisterJob(jobID));
     }
 
     public List<TierShuffleDescriptor> addPartitionAndGetShuffleDescriptor(
@@ -53,7 +88,7 @@ public class TieredStorageMasterClient {
         return tiers.stream()
                 .map(
                         tierMasterAgent ->
-                                
tierMasterAgent.addPartitionAndGetShuffleDescriptor(
+                                
tierMasterAgent.f1.addPartitionAndGetShuffleDescriptor(
                                         jobID, numSubpartitions, 
resultPartitionID))
                 .collect(Collectors.toList());
     }
@@ -65,12 +100,122 @@ public class TieredStorageMasterClient {
         if (tierShuffleDescriptors != null && 
!tierShuffleDescriptors.isEmpty()) {
             checkState(tierShuffleDescriptors.size() == tiers.size());
             for (int i = 0; i < tierShuffleDescriptors.size(); i++) {
-                tiers.get(i).releasePartition(tierShuffleDescriptors.get(i));
+                
tiers.get(i).f1.releasePartition(tierShuffleDescriptors.get(i));
             }
         }
     }
 
+    public void snapshotState(
+            CompletableFuture<AllTieredShuffleMasterSnapshots> snapshotFuture,
+            ShuffleMasterSnapshotContext context,
+            JobID jobId) {
+        snapshotStateInternal(
+                snapshotFuture, (agent, future) -> agent.snapshotState(future, 
context, jobId));
+    }
+
+    public void 
snapshotState(CompletableFuture<AllTieredShuffleMasterSnapshots> 
snapshotFuture) {
+        snapshotStateInternal(snapshotFuture, TierMasterAgent::snapshotState);
+    }
+
+    private void snapshotStateInternal(
+            CompletableFuture<AllTieredShuffleMasterSnapshots> snapshotFuture,
+            BiConsumer<TierMasterAgent, 
CompletableFuture<TieredShuffleMasterSnapshot>>
+                    masterAgentConsumer) {
+        List<CompletableFuture<Tuple2<String, TieredShuffleMasterSnapshot>>> 
futures =
+                new ArrayList<>(tiers.size());
+        for (Tuple2<String, TierMasterAgent> tier : tiers) {
+            CompletableFuture<TieredShuffleMasterSnapshot> future = new 
CompletableFuture<>();
+            futures.add(future.thenApply(snap -> Tuple2.of(tier.f0, snap)));
+            masterAgentConsumer.accept(tier.f1, future);
+        }
+
+        FutureUtils.combineAll(futures)
+                .thenAccept(
+                        snapshotWithIdentifiers ->
+                                snapshotFuture.complete(
+                                        new AllTieredShuffleMasterSnapshots(
+                                                snapshotWithIdentifiers)));
+    }
+
+    public void restoreState(TieredInternalShuffleMasterSnapshot 
clusterSnapshot) {
+        checkState(clusterSnapshot != null);
+        AllTieredShuffleMasterSnapshots allTierSnapshots = 
clusterSnapshot.getAllTierSnapshots();
+        Collection<Tuple2<String, TieredShuffleMasterSnapshot>> snapshots =
+                allTierSnapshots.getSnapshots();
+        for (Tuple2<String, TieredShuffleMasterSnapshot> identifierWithSnap : 
snapshots) {
+            String identifier = identifierWithSnap.f0;
+            
tierMasterAgentMap.get(identifier).restoreState(identifierWithSnap.f1);
+        }
+    }
+
+    public void restoreState(List<TieredInternalShuffleMasterSnapshot> 
snapshots, JobID jobId) {
+        for (TieredInternalShuffleMasterSnapshot internalSnapshot : snapshots) 
{
+            checkState(internalSnapshot != null);
+            AllTieredShuffleMasterSnapshots allTierSnapshots =
+                    internalSnapshot.getAllTierSnapshots();
+            Collection<Tuple2<String, TieredShuffleMasterSnapshot>> 
tierSnapshots =
+                    allTierSnapshots.getSnapshots();
+            for (Tuple2<String, TieredShuffleMasterSnapshot> 
identifierWithSnap : tierSnapshots) {
+                String identifier = identifierWithSnap.f0;
+                
tierMasterAgentMap.get(identifier).restoreState(identifierWithSnap.f1, jobId);
+            }
+        }
+    }
+
+    public CompletableFuture<Collection<PartitionWithMetrics>> 
getPartitionWithMetrics(
+            JobShuffleContext jobShuffleContext,
+            Duration timeout,
+            Set<ResultPartitionID> expectedPartitions) {
+        JobID jobId = jobShuffleContext.getJobId();
+        if (!allPartitionInRemote) {
+            return jobShuffleContext.getPartitionWithMetrics(timeout, 
expectedPartitions);
+        }
+
+        List<CompletableFuture<Map<ResultPartitionID, ShuffleMetrics>>> 
futures =
+                new ArrayList<>(tiers.size());
+        for (Tuple2<String, TierMasterAgent> tier : tiers) {
+            CompletableFuture<Map<ResultPartitionID, ShuffleMetrics>> 
tierPartitionMapFuture =
+                    tier.f1.getPartitionWithMetrics(jobId, timeout, 
expectedPartitions);
+            futures.add(tierPartitionMapFuture);
+        }
+        return FutureUtils.combineAll(futures)
+                .thenApply(
+                        allPartitions -> {
+                            int tierNums = allPartitions.size();
+                            List<PartitionWithMetrics> result = new 
ArrayList<>();
+                            expectedPartitions.forEach(
+                                    partitionId -> {
+                                        List<ResultPartitionBytes> 
partitionBytes =
+                                                new ArrayList<>();
+                                        for (Map<ResultPartitionID, 
ShuffleMetrics> partitionMap :
+                                                allPartitions) {
+                                            ShuffleMetrics shuffleMetrics =
+                                                    
partitionMap.get(partitionId);
+                                            if (shuffleMetrics == null) {
+                                                break;
+                                            }
+                                            
partitionBytes.add(shuffleMetrics.getPartitionBytes());
+                                        }
+                                        if (partitionBytes.size() == tierNums) {
+                                            Optional<ShuffleDescriptor> 
shuffleDescriptor =
+                                                    
shuffleDescriptorRetriever.getShuffleDescriptor(
+                                                            jobId, 
partitionId);
+                                            shuffleDescriptor.ifPresent(
+                                                    descriptor ->
+                                                            result.add(
+                                                                    new 
DefaultPartitionWithMetrics(
+                                                                            
descriptor,
+                                                                            
new DefaultShuffleMetrics(
+                                                                               
     ResultPartitionBytes
+                                                                               
             .mergeAll(
+                                                                               
                     partitionBytes)))));
+                                        }
+                                    });
+                            return result;
+                        });
+    }
+
     public void close() {
-        tiers.forEach(TierMasterAgent::close);
+        tiers.forEach(tier -> tier.f1.close());
     }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierFactory.java
index 11fb089a4b6..8bddb494ff2 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierFactory.java
@@ -72,4 +72,7 @@ public interface TierFactory {
             List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs,
             List<TierShuffleDescriptor> shuffleDescriptors,
             TieredStorageNettyService nettyService);
+
+    /** The unique identifier of this tier. */
+    String identifier();
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierMasterAgent.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierMasterAgent.java
index ee02e14655f..596b84d286b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierMasterAgent.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierMasterAgent.java
@@ -20,6 +20,16 @@ package 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.EmptyTieredShuffleMasterSnapshot;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.TieredShuffleMasterSnapshot;
+import org.apache.flink.runtime.shuffle.ShuffleMasterSnapshotContext;
+import org.apache.flink.runtime.shuffle.ShuffleMetrics;
+
+import java.time.Duration;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 
 /** The master-side agent of a Tier. */
 public interface TierMasterAgent {
@@ -34,6 +44,49 @@ public interface TierMasterAgent {
     TierShuffleDescriptor addPartitionAndGetShuffleDescriptor(
             JobID jobID, int numSubpartitions, ResultPartitionID 
resultPartitionID);
 
+    /** Triggers a snapshot of the tier master agent's state related to the specified job. */
+    default void snapshotState(
+            CompletableFuture<TieredShuffleMasterSnapshot> snapshotFuture,
+            ShuffleMasterSnapshotContext context,
+            JobID jobId) {
+        
snapshotFuture.complete(EmptyTieredShuffleMasterSnapshot.getInstance());
+    }
+
+    /** Triggers a snapshot of the tier master agent's state. */
+    default void snapshotState(CompletableFuture<TieredShuffleMasterSnapshot> 
snapshotFuture) {
+        
snapshotFuture.complete(EmptyTieredShuffleMasterSnapshot.getInstance());
+    }
+
+    /**
+     * Restores the state of the tier master agent from the provided snapshot for the
+     * specified job.
+     */
+    default void restoreState(TieredShuffleMasterSnapshot snapshot, JobID 
jobId) {}
+
+    /** Restores the state of the tier master agent from the provided snapshot. */
+    default void restoreState(TieredShuffleMasterSnapshot snapshot) {}
+
+    /**
+     * Retrieves the specified partitions and their metrics (identified by {@code
+     * expectedPartitions}); the metrics include the sizes of the sub-partitions in a result
+     * partition.
+     *
+     * @param jobId ID of the target job
+     * @param timeout The timeout used to retrieve the specified partitions.
+     * @param expectedPartitions The set of identifiers for the result partitions whose metrics
+     *     are to be fetched.
+     * @return A future that will contain a map of the partitions with their metrics that could
+     *     be retrieved from the expected partitions within the specified timeout period.
+     */
+    default CompletableFuture<Map<ResultPartitionID, ShuffleMetrics>> 
getPartitionWithMetrics(
+            JobID jobId, Duration timeout, Set<ResultPartitionID> 
expectedPartitions) {
+        if (!partitionInRemote()) {
+            return CompletableFuture.completedFuture(Collections.emptyMap());
+        } else {
+            throw new UnsupportedOperationException(
+                    "Remote partitions should be reported by the tier itself.");
+        }
+    }
+
     /**
      * Release a tiered storage partition.
      *
@@ -43,4 +96,9 @@ public interface TierMasterAgent {
 
     /** Close this tier master agent. */
     void close();
+
+    /** Whether this tier manages partitions in a remote cluster rather than on TaskManagers. */
+    default boolean partitionInRemote() {
+        return false;
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/disk/DiskTierFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/disk/DiskTierFactory.java
index c3a1ee2c4f9..f010a27b2fa 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/disk/DiskTierFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/disk/DiskTierFactory.java
@@ -154,4 +154,9 @@ public class DiskTierFactory implements TierFactory {
             TieredStorageNettyService nettyService) {
         return new DiskTierConsumerAgent(tieredStorageConsumerSpecs, 
nettyService);
     }
+
+    @Override
+    public String identifier() {
+        return "disk";
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierFactory.java
index 8ef55c6b59e..23300b451d2 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierFactory.java
@@ -118,4 +118,9 @@ public class MemoryTierFactory implements TierFactory {
             TieredStorageNettyService nettyService) {
         return new MemoryTierConsumerAgent(tieredStorageConsumerSpecs, 
nettyService);
     }
+
+    @Override
+    public String identifier() {
+        return "memory";
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/remote/RemoteTierFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/remote/RemoteTierFactory.java
index 3d374715a9b..8d8aa67d859 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/remote/RemoteTierFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/remote/RemoteTierFactory.java
@@ -137,4 +137,9 @@ public class RemoteTierFactory implements TierFactory {
                 partitionFileReader,
                 bufferSizeBytes);
     }
+
+    @Override
+    public String identifier() {
+        return "remote";
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/remote/RemoteTierMasterAgent.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/remote/RemoteTierMasterAgent.java
index 7ff29fe01bc..e10e7d728ec 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/remote/RemoteTierMasterAgent.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/remote/RemoteTierMasterAgent.java
@@ -25,6 +25,13 @@ import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.Tiere
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierMasterAgent;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleHandler;
+import org.apache.flink.runtime.shuffle.ShuffleMetrics;
+
+import java.time.Duration;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 
 import static 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageIdMappingUtils.convertId;
 import static 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.file.SegmentPartitionFile.deletePathQuietly;
@@ -64,6 +71,14 @@ public class RemoteTierMasterAgent implements 
TierMasterAgent {
         return new RemoteTierShuffleDescriptor(partitionId);
     }
 
+    @Override
+    public CompletableFuture<Map<ResultPartitionID, ShuffleMetrics>> 
getPartitionWithMetrics(
+            JobID jobId, Duration timeout, Set<ResultPartitionID> 
expectedPartitions) {
+        // TODO: we could list the remote path to get all result partitions. Currently, this
+        // method is only used for the external tier, so it is safe to return an empty map.
+        return CompletableFuture.completedFuture(Collections.emptyMap());
+    }
+
     @Override
     public void releasePartition(TierShuffleDescriptor shuffleDescriptor) {
         checkState(shuffleDescriptor instanceof RemoteTierShuffleDescriptor);
@@ -75,4 +90,9 @@ public class RemoteTierMasterAgent implements TierMasterAgent 
{
     public void close() {
         // noop
     }
+
+    @Override
+    public boolean partitionInRemote() {
+        return true;
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleMaster.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleMaster.java
index c186505301d..35de4d733d4 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleMaster.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleMaster.java
@@ -19,11 +19,14 @@
 package org.apache.flink.runtime.shuffle;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.configuration.BatchExecutionOptions;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.MemorySize;
 import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.AllTieredShuffleMasterSnapshots;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.TieredInternalShuffleMaster;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.TieredInternalShuffleMasterSnapshot;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
 import 
org.apache.flink.runtime.shuffle.NettyShuffleDescriptor.LocalExecutionPartitionConnectionInfo;
 import 
org.apache.flink.runtime.shuffle.NettyShuffleDescriptor.NetworkPartitionConnectionInfo;
@@ -40,12 +43,14 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
+import java.util.stream.Collectors;
 
 import static 
org.apache.flink.api.common.BatchShuffleMode.ALL_EXCHANGES_HYBRID_FULL;
 import static 
org.apache.flink.api.common.BatchShuffleMode.ALL_EXCHANGES_HYBRID_SELECTIVE;
 import static 
org.apache.flink.configuration.ExecutionOptions.BATCH_SHUFFLE_MODE;
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
 /** Default {@link ShuffleMaster} for netty and local file based shuffle 
implementation. */
 public class NettyShuffleMaster implements 
ShuffleMaster<NettyShuffleDescriptor> {
@@ -62,10 +67,15 @@ public class NettyShuffleMaster implements 
ShuffleMaster<NettyShuffleDescriptor>
 
     private final int networkBufferSize;
 
+    private final boolean enableJobMasterFailover;
+
     @Nullable private final TieredInternalShuffleMaster 
tieredInternalShuffleMaster;
 
     private final Map<JobID, JobShuffleContext> jobShuffleContexts = new 
HashMap<>();
 
+    private final Map<JobID, Map<ResultPartitionID, ShuffleDescriptor>> 
jobShuffleDescriptors =
+            new HashMap<>();
+
     public NettyShuffleMaster(ShuffleMasterContext shuffleMasterContext) {
         Configuration conf = shuffleMasterContext.getConfiguration();
         checkNotNull(conf);
@@ -80,11 +90,16 @@ public class NettyShuffleMaster implements 
ShuffleMaster<NettyShuffleDescriptor>
         networkBufferSize = ConfigurationParserUtils.getPageSize(conf);
 
         if (isHybridShuffleEnabled(conf)) {
-            tieredInternalShuffleMaster = new 
TieredInternalShuffleMaster(shuffleMasterContext);
+            tieredInternalShuffleMaster =
+                    new TieredInternalShuffleMaster(
+                            shuffleMasterContext, this::getShuffleDescriptor);
         } else {
             tieredInternalShuffleMaster = null;
         }
 
+        enableJobMasterFailover =
+                conf.get(BatchExecutionOptions.JOB_RECOVERY_ENABLED) && 
supportsBatchSnapshot();
+
         checkArgument(
                 !maxRequiredBuffersPerGate.isPresent() || 
maxRequiredBuffersPerGate.get() >= 1,
                 String.format(
@@ -120,6 +135,11 @@ public class NettyShuffleMaster implements 
ShuffleMaster<NettyShuffleDescriptor>
                                 producerDescriptor, 
partitionDescriptor.getConnectionIndex()),
                         resultPartitionID,
                         tierShuffleDescriptors);
+        if (enableJobMasterFailover) {
+            Map<ResultPartitionID, ShuffleDescriptor> shuffleDescriptorMap =
+                    jobShuffleDescriptors.computeIfAbsent(jobID, k -> new 
HashMap<>());
+            shuffleDescriptorMap.put(resultPartitionID, 
shuffleDeploymentDescriptor);
+        }
         return CompletableFuture.completedFuture(shuffleDeploymentDescriptor);
     }
 
@@ -130,6 +150,12 @@ public class NettyShuffleMaster implements 
ShuffleMaster<NettyShuffleDescriptor>
         }
     }
 
+    public Optional<ShuffleDescriptor> getShuffleDescriptor(
+            JobID jobID, ResultPartitionID resultPartitionID) {
+        return Optional.ofNullable(jobShuffleDescriptors.get(jobID))
+                .map(descriptorMap -> descriptorMap.get(resultPartitionID));
+    }
+
     private static PartitionConnectionInfo createConnectionInfo(
             ProducerDescriptor producerDescriptor, int connectionIndex) {
         return producerDescriptor.getDataPort() >= 0
@@ -173,6 +199,11 @@ public class NettyShuffleMaster implements 
ShuffleMaster<NettyShuffleDescriptor>
     @Override
     public CompletableFuture<Collection<PartitionWithMetrics>> 
getPartitionWithMetrics(
             JobID jobId, Duration timeout, Set<ResultPartitionID> 
expectedPartitions) {
+        if (tieredInternalShuffleMaster != null) {
+            return tieredInternalShuffleMaster.getPartitionWithMetrics(
+                    jobShuffleContexts.get(jobId), timeout, 
expectedPartitions);
+        }
+
         return checkNotNull(jobShuffleContexts.get(jobId))
                 .getPartitionWithMetrics(timeout, expectedPartitions);
     }
@@ -189,12 +220,19 @@ public class NettyShuffleMaster implements 
ShuffleMaster<NettyShuffleDescriptor>
     public void unregisterJob(JobID jobId) {
         jobShuffleContexts.remove(jobId);
         if (tieredInternalShuffleMaster != null) {
+            if (enableJobMasterFailover) {
+                jobShuffleDescriptors.remove(jobId);
+            }
             tieredInternalShuffleMaster.unregisterJob(jobId);
         }
     }
 
     @Override
     public boolean supportsBatchSnapshot() {
+        if (tieredInternalShuffleMaster != null) {
+            return tieredInternalShuffleMaster.supportsBatchSnapshot();
+        }
+
         return true;
     }
 
@@ -203,14 +241,75 @@ public class NettyShuffleMaster implements 
ShuffleMaster<NettyShuffleDescriptor>
             CompletableFuture<ShuffleMasterSnapshot> snapshotFuture,
             ShuffleMasterSnapshotContext context,
             JobID jobId) {
+        if (tieredInternalShuffleMaster != null) {
+            Map<ResultPartitionID, ShuffleDescriptor> shuffleDescriptorMap =
+                    jobShuffleDescriptors.remove(jobId);
+            CompletableFuture<AllTieredShuffleMasterSnapshots> 
allSnapshotFuture =
+                    new CompletableFuture<>();
+            tieredInternalShuffleMaster.snapshotState(allSnapshotFuture, 
context, jobId);
+            allSnapshotFuture.thenAccept(
+                    allSnap ->
+                            snapshotFuture.complete(
+                                    new TieredInternalShuffleMasterSnapshot(
+                                            shuffleDescriptorMap, allSnap)));
+            return;
+        }
+
         snapshotFuture.complete(EmptyShuffleMasterSnapshot.getInstance());
     }
 
     @Override
     public void snapshotState(CompletableFuture<ShuffleMasterSnapshot> 
snapshotFuture) {
+        if (tieredInternalShuffleMaster != null) {
+            CompletableFuture<AllTieredShuffleMasterSnapshots> 
allSnapshotFuture =
+                    new CompletableFuture<>();
+            tieredInternalShuffleMaster.snapshotState(allSnapshotFuture);
+            allSnapshotFuture.thenAccept(
+                    allSnap ->
+                            snapshotFuture.complete(
+                                    new 
TieredInternalShuffleMasterSnapshot(null, allSnap)));
+            return;
+        }
+
         snapshotFuture.complete(EmptyShuffleMasterSnapshot.getInstance());
     }
 
+    @Override
+    public void restoreState(ShuffleMasterSnapshot snapshot) {
+        if (tieredInternalShuffleMaster != null) {
+            checkState(snapshot instanceof 
TieredInternalShuffleMasterSnapshot);
+            tieredInternalShuffleMaster.restoreState(
+                    (TieredInternalShuffleMasterSnapshot) snapshot);
+        }
+    }
+
+    @Override
+    public void restoreState(List<ShuffleMasterSnapshot> snapshots, JobID 
jobId) {
+        if (tieredInternalShuffleMaster != null) {
+            List<TieredInternalShuffleMasterSnapshot> snapshotList =
+                    snapshots.stream()
+                            .map(
+                                    snap -> {
+                                        checkState(
+                                                snap
+                                                        instanceof
+                                                        
TieredInternalShuffleMasterSnapshot);
+                                        Map<ResultPartitionID, 
ShuffleDescriptor>
+                                                shuffleDescriptors =
+                                                        
((TieredInternalShuffleMasterSnapshot) snap)
+                                                                
.getShuffleDescriptors();
+                                        if (shuffleDescriptors != null) {
+                                            jobShuffleDescriptors
+                                                    .computeIfAbsent(jobId, k 
-> new HashMap<>())
+                                                    
.putAll(shuffleDescriptors);
+                                        }
+                                        return 
(TieredInternalShuffleMasterSnapshot) snap;
+                                    })
+                            .collect(Collectors.toList());
+            tieredInternalShuffleMaster.restoreState(snapshotList, jobId);
+        }
+    }
+
     @Override
     public void notifyPartitionRecoveryStarted(JobID jobId) {
         
checkNotNull(jobShuffleContexts.get(jobId)).notifyPartitionRecoveryStarted();
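
Putting the pieces together, a job-master failover with only the external tier roughly
follows the sequence below (a hedged sketch; variable names are illustrative and the job
is assumed to be registered with both shuffle masters):

    // On the old JobMaster: capture the job's shuffle descriptors plus all tier snapshots.
    CompletableFuture<ShuffleMasterSnapshot> snapshotFuture = new CompletableFuture<>();
    shuffleMaster.snapshotState(snapshotFuture, context, jobId);
    ShuffleMasterSnapshot snapshot = snapshotFuture.join();

    // On the new JobMaster: restore the descriptors, then ask the tiers which of the
    // expected partitions are still available.
    newShuffleMaster.restoreState(Collections.singletonList(snapshot), jobId);
    CompletableFuture<Collection<PartitionWithMetrics>> available =
            newShuffleMaster.getPartitionWithMetrics(jobId, timeout, expectedPartitions);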
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TestingTierFactory.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TestingTierFactory.java
index 9d87d8e9bc9..7d9e891839f 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TestingTierFactory.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TestingTierFactory.java
@@ -131,6 +131,11 @@ public class TestingTierFactory implements TierFactory {
         return tierConsumerAgentSupplier.apply(tieredStorageConsumerSpecs, 
nettyService);
     }
 
+    @Override
+    public String identifier() {
+        return "test";
+    }
+
     /** Builder for {@link TestingTierFactory}. */
     public static class Builder {
 
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryTest.java
index 9e5d852743c..57a235765f7 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchJobRecoveryTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.api.common.BatchShuffleMode;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.eventtime.WatermarkAlignmentParams;
 import org.apache.flink.api.connector.source.Boundedness;
@@ -26,6 +27,8 @@ import 
org.apache.flink.api.connector.source.mocks.MockSourceSplit;
 import org.apache.flink.api.connector.source.mocks.MockSplitEnumerator;
 import org.apache.flink.configuration.BatchExecutionOptions;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.ExecutionOptions;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.execution.ExecutionState;
@@ -177,9 +180,18 @@ public class BatchJobRecoveryTest {
 
     @Parameter public boolean enableSpeculativeExecution;
 
-    @Parameters(name = "enableSpeculativeExecution={0}")
-    public static Collection<Boolean> parameters() {
-        return Arrays.asList(false, true);
+    @Parameter(value = 1)
+    public boolean isBlockingShuffle;
+
+    @Parameters(name = "enableSpeculativeExecution={0}, isBlockingShuffle={1}")
+    public static Collection<Object[]> parameters() {
+        Object[][] params = {
+            {false, false},
+            {false, true},
+            {true, true},
+            {true, false}
+        };
+        return Arrays.asList(params);
     }
 
     @BeforeEach
@@ -401,7 +413,9 @@ public class BatchJobRecoveryTest {
 
             // check middle task0 is CREATED because it's waiting source task0 
finished.
             if (vertex.getParallelSubtaskIndex() == subtaskIndex) {
-                
assertThat(vertex.getExecutionState()).isEqualTo(ExecutionState.CREATED);
+                ExecutionState expectedState =
+                        isBlockingShuffle ? ExecutionState.CREATED : 
ExecutionState.DEPLOYING;
+                
assertThat(vertex.getExecutionState()).isEqualTo(expectedState);
                 continue;
             }
 
@@ -689,7 +703,9 @@ public class BatchJobRecoveryTest {
                 getExecutionVertex(MIDDLE_ID, 0, 
newScheduler.getExecutionGraph());
         triggerFailedByDataConsumptionException(newScheduler, firstMiddle0);
         // wait until reset done.
-        waitUntilExecutionVertexState(firstMiddle0, ExecutionState.CREATED, 
15000L);
+        ExecutionState expectedState =
+                isBlockingShuffle ? ExecutionState.CREATED : 
ExecutionState.DEPLOYING;
+        waitUntilExecutionVertexState(firstMiddle0, expectedState, 15000L);
         // Check whether the splits have been returned.
         runInMainThread(() -> checkUnassignedSplits(sourceCoordinator, 2));
 
@@ -717,7 +733,7 @@ public class BatchJobRecoveryTest {
                 getExecutionVertex(MIDDLE_ID, 1, 
newScheduler.getExecutionGraph());
         triggerFailedByDataConsumptionException(newScheduler, firstMiddle1);
         // wait until reset done.
-        waitUntilExecutionVertexState(firstMiddle1, ExecutionState.CREATED, 
15000L);
+        waitUntilExecutionVertexState(firstMiddle1, expectedState, 15000L);
 
         // Check whether the splits have been returned.
         runInMainThread(() -> checkUnassignedSplits(sourceCoordinator, 2));
@@ -969,8 +985,8 @@ public class BatchJobRecoveryTest {
      *
      * <p>Parallelism of source and middle is 5.
      *
-     * <p>Edge (source --> middle) is BLOCKING and POINTWISE. Edge (middle --> sink) is BLOCKING and
-     * ALL_TO_ALL.
+     * <p>Edge (source --> middle) is BLOCKING/HYBRID and POINTWISE. Edge (middle --> sink) is
+     * BLOCKING/HYBRID and ALL_TO_ALL.
      *
      * <p>Source has an operator coordinator.
      */
@@ -992,10 +1008,37 @@ public class BatchJobRecoveryTest {
         sink.setInvokableClass(NoOpInvokable.class);
         jobVertices.add(sink);
 
+        ResultPartitionType resultPartitionType =
+                isBlockingShuffle ? ResultPartitionType.BLOCKING : ResultPartitionType.HYBRID_FULL;
         connectNewDataSetAsInput(
-                middle, source, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
+                middle, source, DistributionPattern.POINTWISE, resultPartitionType);
+        connectNewDataSetAsInput(sink, middle, DistributionPattern.ALL_TO_ALL, resultPartitionType);
+
+        return new JobGraph(JOB_ID, "TestJob", jobVertices.toArray(new JobVertex[0]));
+    }
+
+    private JobGraph createDefaultHybridJobGraph() throws IOException {
+        List<JobVertex> jobVertices = new ArrayList<>();
+
+        final JobVertex source = new JobVertex("source", SOURCE_ID);
+        source.setInvokableClass(NoOpInvokable.class);
+        source.addOperatorCoordinator(new SerializedValue<>(provider));
+        source.setParallelism(SOURCE_PARALLELISM);
+        jobVertices.add(source);
+
+        final JobVertex middle = new JobVertex("middle", MIDDLE_ID);
+        middle.setInvokableClass(NoOpInvokable.class);
+        middle.setParallelism(MIDDLE_PARALLELISM);
+        jobVertices.add(middle);
+
+        final JobVertex sink = new JobVertex("sink", SINK_ID);
+        sink.setInvokableClass(NoOpInvokable.class);
+        jobVertices.add(sink);
+
         connectNewDataSetAsInput(
-                sink, middle, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
+                middle, source, DistributionPattern.POINTWISE, ResultPartitionType.HYBRID_FULL);
+        connectNewDataSetAsInput(
+                sink, middle, DistributionPattern.ALL_TO_ALL, ResultPartitionType.HYBRID_FULL);
 
         return new JobGraph(JOB_ID, "TestJob", jobVertices.toArray(new JobVertex[0]));
     }
@@ -1042,10 +1085,26 @@ public class BatchJobRecoveryTest {
             int defaultMaxParallelism,
             Duration jobRecoverySnapshotMinPause)
             throws Exception {
+        Configuration jobMasterConfig = new Configuration();
+        jobMasterConfig.set(
+                BatchExecutionOptions.JOB_RECOVERY_SNAPSHOT_MIN_PAUSE, jobRecoverySnapshotMinPause);
+        jobMasterConfig.set(BatchExecutionOptions.JOB_RECOVERY_ENABLED, true);
+        jobMasterConfig.set(
+                BatchExecutionOptions.JOB_RECOVERY_PREVIOUS_WORKER_RECOVERY_TIMEOUT,
+                previousWorkerRecoveryTimeout);
+        if (!isBlockingShuffle) {
+            jobMasterConfig.set(
+                    ExecutionOptions.BATCH_SHUFFLE_MODE,
+                    BatchShuffleMode.ALL_EXCHANGES_HYBRID_FULL);
+            jobMasterConfig.set(
+                    NettyShuffleEnvironmentOptions
+                            .NETWORK_HYBRID_SHUFFLE_EXTERNAL_REMOTE_TIER_FACTORY_CLASS_NAME,
+                    DummyTierFactory.class.getName());
+        }
 
         final ShuffleMaster<NettyShuffleDescriptor> shuffleMaster =
                 new NettyShuffleMaster(
-                        new ShuffleMasterContextImpl(new Configuration(), throwable -> {}));
+                        new ShuffleMasterContextImpl(jobMasterConfig, throwable -> {}));
         TestingJobMasterGateway jobMasterGateway =
                 new TestingJobMasterGatewayBuilder()
                         .setGetPartitionWithMetricsFunction(
@@ -1057,14 +1116,6 @@ public class BatchJobRecoveryTest {
                 new JobMasterPartitionTrackerImpl(
                         jobGraph.getJobID(), shuffleMaster, ignored -> Optional.empty());
 
-        Configuration jobMasterConfig = new Configuration();
-        jobMasterConfig.set(
-                BatchExecutionOptions.JOB_RECOVERY_SNAPSHOT_MIN_PAUSE, jobRecoverySnapshotMinPause);
-        jobMasterConfig.set(BatchExecutionOptions.JOB_RECOVERY_ENABLED, true);
-        jobMasterConfig.set(
-                BatchExecutionOptions.JOB_RECOVERY_PREVIOUS_WORKER_RECOVERY_TIMEOUT,
-                previousWorkerRecoveryTimeout);
-
         DefaultSchedulerBuilder schedulerBuilder =
                 new DefaultSchedulerBuilder(
                                 jobGraph,
@@ -1150,12 +1201,7 @@ public class BatchJobRecoveryTest {
 
         @Override
         public ShuffleDescriptor getPartition() {
-            return new ShuffleDescriptor() {
-                @Override
-                public ResultPartitionID getResultPartitionID() {
-                    return resultPartitionID;
-                }
-
+            return new NettyShuffleDescriptor(ResourceID.generate(), null, resultPartitionID) {
                 @Override
                 public Optional<ResourceID> storesLocalResourcesOn() {
                     return Optional.empty();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierFactory.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DummyTierFactory.java
similarity index 65%
copy from flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierFactory.java
copy to flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DummyTierFactory.java
index 8ef55c6b59e..6091ce0c3d4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/memory/MemoryTierFactory.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DummyTierFactory.java
@@ -16,69 +16,55 @@
  * limitations under the License.
  */
 
-package org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.memory;
+package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.io.disk.BatchShuffleReadBufferPool;
 import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty.TieredStorageNettyService;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageConsumerSpec;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemorySpec;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageResourceRegistry;
-import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.NoOpMasterAgent;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierConsumerAgent;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierFactory;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierMasterAgent;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
-import org.apache.flink.runtime.util.ConfigurationParserUtils;
+import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleHandler;
 
 import javax.annotation.Nullable;
 
 import java.util.List;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageUtils.getMemoryTierName;
-import static org.apache.flink.util.Preconditions.checkState;
-
-/** The implementation of {@link TierFactory} for memory tier. */
-public class MemoryTierFactory implements TierFactory {
-
-    private static final int DEFAULT_MEMORY_TIER_EXCLUSIVE_BUFFERS = 100;
-
-    private static final int DEFAULT_MEMORY_TIER_SUBPARTITION_MAX_QUEUED_BUFFERS = 3;
-
-    private static final int DEFAULT_MEMORY_TIER_NUM_BYTES_PER_SEGMENT = 2 * 32 * 1024;
-
-    private int bufferSizeBytes = -1;
-
+/** Dummy {@link TierFactory} for testing purposes only. */
+public class DummyTierFactory implements TierFactory {
     @Override
-    public void setup(Configuration configuration) {
-        this.bufferSizeBytes = ConfigurationParserUtils.getPageSize(configuration);
-    }
+    public void setup(Configuration configuration) {}
 
     @Override
     public TieredStorageMemorySpec getMasterAgentMemorySpec() {
-        return new TieredStorageMemorySpec(getMemoryTierName(), 0);
+        return null;
     }
 
     @Override
     public TieredStorageMemorySpec getProducerAgentMemorySpec() {
-        return new TieredStorageMemorySpec(
-                getMemoryTierName(), DEFAULT_MEMORY_TIER_EXCLUSIVE_BUFFERS);
+        return null;
     }
 
     @Override
     public TieredStorageMemorySpec getConsumerAgentMemorySpec() {
-        return new TieredStorageMemorySpec(getMemoryTierName(), 0);
+        return null;
     }
 
     @Override
     public TierMasterAgent createMasterAgent(
             TieredStorageResourceRegistry tieredStorageResourceRegistry) {
-        return NoOpMasterAgent.INSTANCE;
+        return new DummyTierMasterAgent();
     }
 
     @Override
@@ -88,27 +74,15 @@ public class MemoryTierFactory implements TierFactory {
             TieredStoragePartitionId partitionID,
             String dataFileBasePath,
             boolean isBroadcastOnly,
-            TieredStorageMemoryManager memoryManager,
+            TieredStorageMemoryManager storageMemoryManager,
             TieredStorageNettyService nettyService,
             TieredStorageResourceRegistry resourceRegistry,
             BatchShuffleReadBufferPool bufferPool,
             ScheduledExecutorService ioExecutor,
             List<TierShuffleDescriptor> shuffleDescriptors,
-            int maxRequestedBuffers,
+            int maxRequestedBuffer,
             @Nullable BufferCompressor bufferCompressor) {
-        checkState(bufferSizeBytes > 0);
-
-        return new MemoryTierProducerAgent(
-                partitionID,
-                numSubpartitions,
-                bufferSizeBytes,
-                DEFAULT_MEMORY_TIER_NUM_BYTES_PER_SEGMENT,
-                DEFAULT_MEMORY_TIER_SUBPARTITION_MAX_QUEUED_BUFFERS,
-                isBroadcastOnly,
-                memoryManager,
-                nettyService,
-                resourceRegistry,
-                bufferCompressor);
+        return null;
     }
 
     @Override
@@ -116,6 +90,32 @@ public class MemoryTierFactory implements TierFactory {
             List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs,
             List<TierShuffleDescriptor> shuffleDescriptors,
             TieredStorageNettyService nettyService) {
-        return new MemoryTierConsumerAgent(tieredStorageConsumerSpecs, nettyService);
+        return null;
+    }
+
+    @Override
+    public String identifier() {
+        return "dummy";
+    }
+
+    public static class DummyTierMasterAgent implements TierMasterAgent {
+
+        @Override
+        public void registerJob(JobID jobID, TierShuffleHandler tierShuffleHandler) {}
+
+        @Override
+        public void unregisterJob(JobID jobID) {}
+
+        @Override
+        public TierShuffleDescriptor addPartitionAndGetShuffleDescriptor(
+                JobID jobID, int numSubpartitions, ResultPartitionID resultPartitionID) {
+            return null;
+        }
+
+        @Override
+        public void releasePartition(TierShuffleDescriptor shuffleDescriptor) {}
+
+        @Override
+        public void close() {}
     }
 }
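
As a point of reference, the job-master configuration used by the parameterized hybrid runs in
BatchJobRecoveryTest boils down to the following sketch. This is illustrative only and not part of
the commit; "com.example.MyExternalTierFactory" is a placeholder for a user-provided TierFactory
implementation that must be on the classpath, while all option keys and classes are the ones
imported by the test above.

    import org.apache.flink.api.common.BatchShuffleMode;
    import org.apache.flink.configuration.BatchExecutionOptions;
    import org.apache.flink.configuration.Configuration;
    import org.apache.flink.configuration.ExecutionOptions;
    import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;

    // Illustrative sketch: run all exchanges through hybrid shuffle backed only by an
    // external remote tier, with batch job recovery enabled so the scheduler can restore
    // shuffle state on job master failover.
    Configuration conf = new Configuration();
    conf.set(ExecutionOptions.BATCH_SHUFFLE_MODE, BatchShuffleMode.ALL_EXCHANGES_HYBRID_FULL);
    conf.set(
            NettyShuffleEnvironmentOptions
                    .NETWORK_HYBRID_SHUFFLE_EXTERNAL_REMOTE_TIER_FACTORY_CLASS_NAME,
            "com.example.MyExternalTierFactory"); // hypothetical external TierFactory
    conf.set(BatchExecutionOptions.JOB_RECOVERY_ENABLED, true);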
