This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 298eadb8f [CELEBORN-1645] Introduce ShuffleFallbackPolicy to support 
custom implementation of shuffle fallback policy for 
CelebornShuffleFallbackPolicyRunner
298eadb8f is described below

commit 298eadb8fdd03c4b05282f507fdd7e34b7386f3f
Author: SteNicholas <[email protected]>
AuthorDate: Tue Oct 15 21:57:04 2024 +0800

    [CELEBORN-1645] Introduce ShuffleFallbackPolicy to support custom 
implementation of shuffle fallback policy for 
CelebornShuffleFallbackPolicyRunner
    
    ### What changes were proposed in this pull request?
    
    Introduce `ShuffleFallbackPolicy` to support custom implementations of the shuffle fallback policy for `CelebornShuffleFallbackPolicyRunner`.
    
    ### Why are the changes needed?
    
    The shuffle fallback policy is hard-coded in `CelebornShuffleFallbackPolicyRunner`, so user-defined fallback policy implementations cannot be plugged in. This PR introduces `ShuffleFallbackPolicy` to support custom implementations of the shuffle fallback policy for `CelebornShuffleFallbackPolicyRunner`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Introduce the `ShuffleFallbackPolicy` interface, which determines whether to fall back to the vanilla Spark built-in shuffle implementation.
    
    ```
    /**
     * The shuffle fallback policy determines whether to fall back to the vanilla Spark built-in
     * shuffle implementation.
     */
    public interface ShuffleFallbackPolicy {
    
      /**
       * Returns whether to fall back to the vanilla Spark built-in shuffle implementation.
       *
       * @param shuffleDependency The shuffle dependency of Spark.
       * @param celebornConf The configuration of Celeborn.
       * @param lifecycleManager The {@link LifecycleManager} of Celeborn.
       * @return Whether to fall back to the vanilla Spark built-in shuffle implementation.
       */
      boolean needFallback(
          ShuffleDependency<?, ?, ?> shuffleDependency,
          CelebornConf celebornConf,
          LifecycleManager lifecycleManager);
    }
    ```
    
    ### How was this patch tested?
    
    `SparkShuffleManagerSuite`
    
    Closes #2807 from SteNicholas/CELEBORN-1645.
    
    Authored-by: SteNicholas <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../shuffle/celeborn/ForceFallbackPolicy.java      | 58 ++++++++++++++++
 .../shuffle/celeborn/QuotaFallbackPolicy.java      | 61 ++++++++++++++++
 .../shuffle/celeborn/ShuffleFallbackPolicy.java    | 43 ++++++++++++
 .../celeborn/ShuffleFallbackPolicyFactory.java     | 46 ++++++++++++
 .../celeborn/ShufflePartitionsFallbackPolicy.java  | 59 ++++++++++++++++
 .../celeborn/WorkersAvailableFallbackPolicy.java   | 59 ++++++++++++++++
 .../shuffle/celeborn/SparkShuffleManager.java      |  3 +-
 .../CelebornShuffleFallbackPolicyRunner.scala      | 81 +++-------------------
 .../shuffle/celeborn/SparkShuffleManager.java      |  3 +-
 .../CelebornShuffleFallbackPolicyRunner.scala      | 81 +++-------------------
 .../celeborn/CelebornShuffleManagerSuite.scala     |  1 +
 11 files changed, 349 insertions(+), 146 deletions(-)

diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ForceFallbackPolicy.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ForceFallbackPolicy.java
new file mode 100644
index 000000000..63d29fd70
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ForceFallbackPolicy.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import org.apache.spark.ShuffleDependency;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.LifecycleManager;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.protocol.FallbackPolicy;
+
+public class ForceFallbackPolicy implements ShuffleFallbackPolicy {
+
+  private static final Logger LOG = LoggerFactory.getLogger(ForceFallbackPolicy.class);
+
+  public static final ForceFallbackPolicy INSTANCE = new ForceFallbackPolicy();
+
+  /**
+   * If celeborn.client.spark.shuffle.fallback.policy is ALWAYS, fallback to spark built-in shuffle
+   * implementation.
+   *
+   * @param shuffleDependency The shuffle dependency of Spark.
+   * @param celebornConf The configuration of Celeborn.
+   * @param lifecycleManager The {@link LifecycleManager} of Celeborn.
+   * @return Return true if celeborn.client.spark.shuffle.fallback.policy is ALWAYS, otherwise
+   *     false.
+   */
+  @Override
+  public boolean needFallback(
+      ShuffleDependency<?, ?, ?> shuffleDependency,
+      CelebornConf celebornConf,
+      LifecycleManager lifecycleManager) {
+    FallbackPolicy shuffleFallbackPolicy = celebornConf.shuffleFallbackPolicy();
+    if (FallbackPolicy.ALWAYS.equals(shuffleFallbackPolicy)) {
+      LOG.warn(
+          "{} is {}, forcibly fallback to spark built-in shuffle implementation.",
+          CelebornConf.SPARK_SHUFFLE_FALLBACK_POLICY().key(),
+          FallbackPolicy.ALWAYS.name());
+    }
+    return FallbackPolicy.ALWAYS.equals(shuffleFallbackPolicy);
+  }
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/QuotaFallbackPolicy.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/QuotaFallbackPolicy.java
new file mode 100644
index 000000000..1778746d1
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/QuotaFallbackPolicy.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import org.apache.spark.ShuffleDependency;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.LifecycleManager;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.protocol.message.ControlMessages.CheckQuotaResponse;
+
+public class QuotaFallbackPolicy implements ShuffleFallbackPolicy {
+
+  private static final Logger LOG = LoggerFactory.getLogger(QuotaFallbackPolicy.class);
+
+  public static final QuotaFallbackPolicy INSTANCE = new QuotaFallbackPolicy();
+
+  /**
+   * If celeborn cluster exceeds current user's quota, fallback to spark built-in shuffle
+   * implementation.
+   *
+   * @param shuffleDependency The shuffle dependency of Spark.
+   * @param celebornConf The configuration of Celeborn.
+   * @param lifecycleManager The {@link LifecycleManager} of Celeborn.
+   * @return Whether celeborn cluster has no available space for current user.
+   */
+  @Override
+  public boolean needFallback(
+      ShuffleDependency<?, ?, ?> shuffleDependency,
+      CelebornConf celebornConf,
+      LifecycleManager lifecycleManager) {
+    if (!celebornConf.quotaEnabled()) {
+      return false;
+    }
+    CheckQuotaResponse response = lifecycleManager.checkQuota();
+    boolean needFallback = !response.isAvailable();
+    if (needFallback) {
+      LOG.warn(
+          "Quota exceeds for current user {}. Because {}",
+          lifecycleManager.getUserIdentifier(),
+          response.reason());
+    }
+    return needFallback;
+  }
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ShuffleFallbackPolicy.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ShuffleFallbackPolicy.java
new file mode 100644
index 000000000..0b1d68929
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ShuffleFallbackPolicy.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import org.apache.spark.ShuffleDependency;
+
+import org.apache.celeborn.client.LifecycleManager;
+import org.apache.celeborn.common.CelebornConf;
+
+/**
+ * The shuffle fallback policy determines whether fallback to vanilla Spark built-in shuffle
+ * implementation.
+ */
+public interface ShuffleFallbackPolicy {
+
+  /**
+   * Returns whether fallback to vanilla spark built-in shuffle implementation.
+   *
+   * @param shuffleDependency The shuffle dependency of Spark.
+   * @param celebornConf The configuration of Celeborn.
+   * @param lifecycleManager The {@link LifecycleManager} of Celeborn.
+   * @return Whether fallback to vanilla spark built-in shuffle implementation.
+   */
+  boolean needFallback(
+      ShuffleDependency<?, ?, ?> shuffleDependency,
+      CelebornConf celebornConf,
+      LifecycleManager lifecycleManager);
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ShuffleFallbackPolicyFactory.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ShuffleFallbackPolicyFactory.java
new file mode 100644
index 000000000..2641df7f9
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ShuffleFallbackPolicyFactory.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.ServiceLoader;
+
+public class ShuffleFallbackPolicyFactory {
+
+  public static List<ShuffleFallbackPolicy> getShuffleFallbackPolicies() {
+    List<ShuffleFallbackPolicy> shuffleFallbackPolicies = new ArrayList<>();
+    // Loading order of ShuffleFallbackPolicy should be ForceFallbackPolicy,
+    // ShufflePartitionsFallbackPolicy, QuotaFallbackPolicy, WorkersAvailableFallbackPolicy, Custom
+    // to avoid unnecessary RPCs when checking whether to fall back.
+    shuffleFallbackPolicies.add(ForceFallbackPolicy.INSTANCE);
+    shuffleFallbackPolicies.add(ShufflePartitionsFallbackPolicy.INSTANCE);
+    shuffleFallbackPolicies.add(QuotaFallbackPolicy.INSTANCE);
+    shuffleFallbackPolicies.add(WorkersAvailableFallbackPolicy.INSTANCE);
+    for (ShuffleFallbackPolicy shuffleFallbackPolicy :
+        ServiceLoader.load(ShuffleFallbackPolicy.class)) {
+      // Skip a provider only if its exact class is already registered, so built-in policies are
+      // not evaluated twice while subclasses of built-in policies are still honored as custom.
+      Class<?> policyClass = shuffleFallbackPolicy.getClass();
+      if (shuffleFallbackPolicies.stream().noneMatch(p -> p.getClass() == policyClass)) {
+        shuffleFallbackPolicies.add(shuffleFallbackPolicy);
+      }
+    }
+    return shuffleFallbackPolicies;
+  }
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ShufflePartitionsFallbackPolicy.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ShufflePartitionsFallbackPolicy.java
new file mode 100644
index 000000000..a405a64c8
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ShufflePartitionsFallbackPolicy.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import org.apache.spark.ShuffleDependency;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.LifecycleManager;
+import org.apache.celeborn.common.CelebornConf;
+
+public class ShufflePartitionsFallbackPolicy implements ShuffleFallbackPolicy {
+
+  private static final Logger LOG = LoggerFactory.getLogger(ShufflePartitionsFallbackPolicy.class);
+
+  public static final ShufflePartitionsFallbackPolicy INSTANCE =
+      new ShufflePartitionsFallbackPolicy();
+
+  /**
+   * If shuffle partitions >= celeborn.shuffle.fallback.numPartitionsThreshold, fall back to spark
+   * built-in shuffle implementation.
+   *
+   * @param shuffleDependency The shuffle dependency of Spark.
+   * @param celebornConf The configuration of Celeborn.
+   * @param lifecycleManager The {@link LifecycleManager} of Celeborn.
+   * @return Return true if shuffle partitions reach the threshold, otherwise false.
+   */
+  @Override
+  public boolean needFallback(
+      ShuffleDependency<?, ?, ?> shuffleDependency,
+      CelebornConf celebornConf,
+      LifecycleManager lifecycleManager) {
+    int numPartitions = shuffleDependency.partitioner().numPartitions();
+    long numPartitionsThreshold = celebornConf.shuffleFallbackPartitionThreshold();
+    boolean needFallback = numPartitions >= numPartitionsThreshold;
+    if (needFallback) {
+      LOG.warn(
+          "Shuffle partition number {} exceeds threshold {}, fallback to spark built-in shuffle implementation.",
+          numPartitions,
+          numPartitionsThreshold);
+    }
+    return needFallback;
+  }
+}
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/WorkersAvailableFallbackPolicy.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/WorkersAvailableFallbackPolicy.java
new file mode 100644
index 000000000..b0bdd9f5d
--- /dev/null
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/WorkersAvailableFallbackPolicy.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import org.apache.spark.ShuffleDependency;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.LifecycleManager;
+import org.apache.celeborn.common.CelebornConf;
+
+public class WorkersAvailableFallbackPolicy implements ShuffleFallbackPolicy {
+
+  private static final Logger LOG = LoggerFactory.getLogger(WorkersAvailableFallbackPolicy.class);
+
+  public static final WorkersAvailableFallbackPolicy INSTANCE =
+      new WorkersAvailableFallbackPolicy();
+
+  /**
+   * If celeborn cluster has no available workers, fallback to spark built-in shuffle
+   * implementation.
+   *
+   * @param shuffleDependency The shuffle dependency of Spark.
+   * @param celebornConf The configuration of Celeborn.
+   * @param lifecycleManager The {@link LifecycleManager} of Celeborn.
+   * @return Whether celeborn cluster has no available workers.
+   */
+  @Override
+  public boolean needFallback(
+      ShuffleDependency<?, ?, ?> shuffleDependency,
+      CelebornConf celebornConf,
+      LifecycleManager lifecycleManager) {
+    if (!celebornConf.checkWorkerEnabled()) {
+      return false;
+    }
+    boolean needFallback = !lifecycleManager.checkWorkersAvailable().getAvailable();
+    if (needFallback) {
+      LOG.warn(
+          "No celeborn workers available for current user {}.",
+          lifecycleManager.getUserIdentifier());
+    }
+    return needFallback;
+  }
+}
diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 4f6e835e7..861070830 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -120,8 +120,7 @@ public class SparkShuffleManager implements ShuffleManager {
         shuffleId,
         !DeterministicLevel.INDETERMINATE().equals(dependency.rdd().getOutputDeterministicLevel()));
 
-    if (fallbackPolicyRunner.applyAllFallbackPolicy(
-        lifecycleManager, dependency.partitioner().numPartitions())) {
+    if (fallbackPolicyRunner.applyFallbackPolicies(dependency, lifecycleManager)) {
       logger.warn("Fallback to SortShuffleManager!");
       sortShuffleIds.add(shuffleId);
       return sortShuffleManager().registerShuffle(shuffleId, numMaps, dependency);
diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala
index 48e0825e6..c1e93ec76 100644
--- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala
+++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.shuffle.celeborn
 
+import scala.collection.JavaConverters._
+
+import org.apache.spark.ShuffleDependency
+
 import org.apache.celeborn.client.LifecycleManager
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.CelebornIOException
@@ -25,83 +29,18 @@ import org.apache.celeborn.common.protocol.FallbackPolicy
 
 class CelebornShuffleFallbackPolicyRunner(conf: CelebornConf) extends Logging {
   private val shuffleFallbackPolicy = conf.shuffleFallbackPolicy
-  private val checkWorkerEnabled = conf.checkWorkerEnabled
-  private val quotaEnabled = conf.quotaEnabled
-  private val numPartitionsThreshold = conf.shuffleFallbackPartitionThreshold
+  private val shuffleFallbackPolicies =
+    ShuffleFallbackPolicyFactory.getShuffleFallbackPolicies.asScala
 
-  def applyAllFallbackPolicy(lifecycleManager: LifecycleManager, numPartitions: Int): Boolean = {
+  def applyFallbackPolicies[K, V, C](
+      dependency: ShuffleDependency[K, V, C],
+      lifecycleManager: LifecycleManager): Boolean = {
     val needFallback =
-      applyForceFallbackPolicy() || applyShufflePartitionsFallbackPolicy(numPartitions) ||
-        !checkQuota(lifecycleManager) || !checkWorkersAvailable(lifecycleManager)
+      shuffleFallbackPolicies.exists(_.needFallback(dependency, conf, lifecycleManager))
     if (needFallback && FallbackPolicy.NEVER.equals(shuffleFallbackPolicy)) {
       throw new CelebornIOException(
         "Fallback to spark built-in shuffle implementation is prohibited.")
     }
     needFallback
   }
-
-  /**
-   * if celeborn.client.spark.shuffle.fallback.policy is ALWAYS, fallback to spark built-in shuffle implementation
-   * @return return true if celeborn.client.spark.shuffle.fallback.policy is ALWAYS, otherwise false
-   */
-  def applyForceFallbackPolicy(): Boolean = {
-    if (FallbackPolicy.ALWAYS.equals(shuffleFallbackPolicy)) {
-      logWarning(
-        s"${CelebornConf.SPARK_SHUFFLE_FALLBACK_POLICY.key} is ${FallbackPolicy.ALWAYS.name}, " +
-          s"forcibly fallback to spark built-in shuffle implementation.")
-    }
-    FallbackPolicy.ALWAYS.equals(shuffleFallbackPolicy)
-  }
-
-  /**
-   * if shuffle partitions > celeborn.shuffle.fallback.numPartitionsThreshold, fallback to spark built-in
-   * shuffle implementation
-   * @param numPartitions shuffle partitions
-   * @return return true if shuffle partitions bigger than limit, otherwise false
-   */
-  def applyShufflePartitionsFallbackPolicy(numPartitions: Int): Boolean = {
-    val needFallback = numPartitions >= numPartitionsThreshold
-    if (needFallback) {
-      logWarning(
-        s"Shuffle partition number: $numPartitions exceeds threshold: $numPartitionsThreshold, " +
-          "need to fallback to spark built-in shuffle implementation.")
-    }
-    needFallback
-  }
-
-  /**
-   * If celeborn cluster is exceed current user's quota, fallback to spark built-in shuffle implementation
-   *
-   * @return if celeborn cluster have available space for current user
-   */
-  def checkQuota(lifecycleManager: LifecycleManager): Boolean = {
-    if (!quotaEnabled) {
-      return true
-    }
-
-    val resp = lifecycleManager.checkQuota()
-    if (!resp.isAvailable) {
-      logWarning(
-        s"Quota exceed for current user ${lifecycleManager.getUserIdentifier}. Because: ${resp.reason}")
-    }
-    resp.isAvailable
-  }
-
-  /**
-   * If celeborn cluster has no available workers, fallback to spark built-in shuffle implementation
-   *
-   * @return if celeborn cluster has available workers.
-   */
-  def checkWorkersAvailable(lifecycleManager: LifecycleManager): Boolean = {
-    if (!checkWorkerEnabled) {
-      return true
-    }
-
-    val resp = lifecycleManager.checkWorkersAvailable()
-    if (!resp.getAvailable) {
-      logWarning(
-        s"No celeborn workers available for current user ${lifecycleManager.getUserIdentifier}.")
-    }
-    resp.getAvailable
-  }
 }
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index da785886c..a2a6f7a37 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -163,8 +163,7 @@ public class SparkShuffleManager implements ShuffleManager {
         shuffleId,
         !DeterministicLevel.INDETERMINATE().equals(dependency.rdd().getOutputDeterministicLevel()));
 
-    if (fallbackPolicyRunner.applyAllFallbackPolicy(
-        lifecycleManager, dependency.partitioner().numPartitions())) {
+    if (fallbackPolicyRunner.applyFallbackPolicies(dependency, lifecycleManager)) {
       if (conf.getBoolean("spark.dynamicAllocation.enabled", false)
           && !conf.getBoolean("spark.shuffle.service.enabled", false)) {
         logger.error(
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala
index 48e0825e6..c1e93ec76 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleFallbackPolicyRunner.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.shuffle.celeborn
 
+import scala.collection.JavaConverters._
+
+import org.apache.spark.ShuffleDependency
+
 import org.apache.celeborn.client.LifecycleManager
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.CelebornIOException
@@ -25,83 +29,18 @@ import org.apache.celeborn.common.protocol.FallbackPolicy
 
 class CelebornShuffleFallbackPolicyRunner(conf: CelebornConf) extends Logging {
   private val shuffleFallbackPolicy = conf.shuffleFallbackPolicy
-  private val checkWorkerEnabled = conf.checkWorkerEnabled
-  private val quotaEnabled = conf.quotaEnabled
-  private val numPartitionsThreshold = conf.shuffleFallbackPartitionThreshold
+  private val shuffleFallbackPolicies =
+    ShuffleFallbackPolicyFactory.getShuffleFallbackPolicies.asScala
 
-  def applyAllFallbackPolicy(lifecycleManager: LifecycleManager, numPartitions: Int): Boolean = {
+  def applyFallbackPolicies[K, V, C](
+      dependency: ShuffleDependency[K, V, C],
+      lifecycleManager: LifecycleManager): Boolean = {
     val needFallback =
-      applyForceFallbackPolicy() || applyShufflePartitionsFallbackPolicy(numPartitions) ||
-        !checkQuota(lifecycleManager) || !checkWorkersAvailable(lifecycleManager)
+      shuffleFallbackPolicies.exists(_.needFallback(dependency, conf, lifecycleManager))
     if (needFallback && FallbackPolicy.NEVER.equals(shuffleFallbackPolicy)) {
       throw new CelebornIOException(
         "Fallback to spark built-in shuffle implementation is prohibited.")
     }
     needFallback
   }
-
-  /**
-   * if celeborn.client.spark.shuffle.fallback.policy is ALWAYS, fallback to spark built-in shuffle implementation
-   * @return return true if celeborn.client.spark.shuffle.fallback.policy is ALWAYS, otherwise false
-   */
-  def applyForceFallbackPolicy(): Boolean = {
-    if (FallbackPolicy.ALWAYS.equals(shuffleFallbackPolicy)) {
-      logWarning(
-        s"${CelebornConf.SPARK_SHUFFLE_FALLBACK_POLICY.key} is ${FallbackPolicy.ALWAYS.name}, " +
-          s"forcibly fallback to spark built-in shuffle implementation.")
-    }
-    FallbackPolicy.ALWAYS.equals(shuffleFallbackPolicy)
-  }
-
-  /**
-   * if shuffle partitions > celeborn.shuffle.fallback.numPartitionsThreshold, fallback to spark built-in
-   * shuffle implementation
-   * @param numPartitions shuffle partitions
-   * @return return true if shuffle partitions bigger than limit, otherwise false
-   */
-  def applyShufflePartitionsFallbackPolicy(numPartitions: Int): Boolean = {
-    val needFallback = numPartitions >= numPartitionsThreshold
-    if (needFallback) {
-      logWarning(
-        s"Shuffle partition number: $numPartitions exceeds threshold: $numPartitionsThreshold, " +
-          "need to fallback to spark built-in shuffle implementation.")
-    }
-    needFallback
-  }
-
-  /**
-   * If celeborn cluster is exceed current user's quota, fallback to spark built-in shuffle implementation
-   *
-   * @return if celeborn cluster have available space for current user
-   */
-  def checkQuota(lifecycleManager: LifecycleManager): Boolean = {
-    if (!quotaEnabled) {
-      return true
-    }
-
-    val resp = lifecycleManager.checkQuota()
-    if (!resp.isAvailable) {
-      logWarning(
-        s"Quota exceed for current user ${lifecycleManager.getUserIdentifier}. Because: ${resp.reason}")
-    }
-    resp.isAvailable
-  }
-
-  /**
-   * If celeborn cluster has no available workers, fallback to spark built-in shuffle implementation
-   *
-   * @return if celeborn cluster has available workers.
-   */
-  def checkWorkersAvailable(lifecycleManager: LifecycleManager): Boolean = {
-    if (!checkWorkerEnabled) {
-      return true
-    }
-
-    val resp = lifecycleManager.checkWorkersAvailable()
-    if (!resp.getAvailable) {
-      logWarning(
-        s"No celeborn workers available for current user ${lifecycleManager.getUserIdentifier}.")
-    }
-    resp.getAvailable
-  }
 }
diff --git a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
index 7b790cbbf..bded50941 100644
--- a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
+++ b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleManagerSuite.scala
@@ -69,6 +69,7 @@ class SparkShuffleManagerSuite extends Logging {
     sc.stop()
   }
 
+  @junit.Test
   def testChangeWriteModeByPartitionCount(): Unit = {
     val conf = new SparkConf().setIfMissing("spark.master", "local")
       .setIfMissing(

Reply via email to