This is an automated email from the ASF dual-hosted git repository.

angerszhuuuu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new c8ad39d9b [CELEBORN-809] Directly use isDriver passed from SparkEnv
c8ad39d9b is described below

commit c8ad39d9bddde7f07214b4aa91074ca0840d738c
Author: Angerszhuuuu <[email protected]>
AuthorDate: Wed Jul 19 15:20:01 2023 +0800

    [CELEBORN-809] Directly use isDriver passed from SparkEnv
    
    ### What changes were proposed in this pull request?
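    Remove the private `isDriver()` helper, which compared
    `SparkEnv.get().executorId()` with "driver", from the Spark 2 and Spark 3
    `SparkShuffleManager`. Instead, store the `isDriver` flag passed to the
    constructor and use it when creating the `SortShuffleManager` and the
    `LifecycleManager`, and when unregistering shuffles. `RssShuffleManager`
    forwards the same flag to its superclass.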
    <img width="1051" alt="Screenshot 2023-07-19 at 13:01:25" src="https://github.com/apache/incubator-celeborn/assets/46485123/26d506b2-bab9-43f5-9bbe-58d22a761bab">
    
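    ### Why are the changes needed?
    Spark already supplies `isDriver` when it instantiates the shuffle manager.
    Recomputing it from `SparkEnv.get().executorId()` on every call is redundant
    and makes the manager depend on the global `SparkEnv` at call time.
    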
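    ### Does this PR introduce _any_ user-facing change?
    No new configuration. However, the public `SparkShuffleManager(SparkConf)`
    and `RssShuffleManager(SparkConf)` constructors are replaced by
    `(SparkConf, boolean isDriver)` constructors. Code that calls the old
    constructors directly must be updated.
    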
    ### How was this patch tested?
    
    Closes #1732 from AngersZhuuuu/CELEBORN-809.
    
    Authored-by: Angerszhuuuu <[email protected]>
    Signed-off-by: Angerszhuuuu <[email protected]>
---
 .../apache/spark/shuffle/celeborn/RssShuffleManager.java  |  4 ++--
 .../spark/shuffle/celeborn/SparkShuffleManager.java       | 15 ++++++---------
 .../apache/spark/shuffle/celeborn/RssShuffleManager.java  |  4 ++--
 .../spark/shuffle/celeborn/SparkShuffleManager.java       | 14 ++++++--------
 4 files changed, 16 insertions(+), 21 deletions(-)

diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
index 2df1faaf7..c12a908cd 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
@@ -21,7 +21,7 @@ import org.apache.spark.SparkConf;
 
 public class RssShuffleManager extends SparkShuffleManager {
 
-  public RssShuffleManager(SparkConf conf) {
-    super(conf);
+  public RssShuffleManager(SparkConf conf, boolean isDriver) {
+    super(conf, isDriver);
   }
 }
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 95641a6fc..4383fea88 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -47,6 +47,7 @@ public class SparkShuffleManager implements ShuffleManager {
       "org.apache.spark.shuffle.sort.SortShuffleManager";
 
   private final SparkConf conf;
+  private final Boolean isDriver;
   private final CelebornConf celebornConf;
   private final int cores;
   // either be "{appId}_{appAttemptId}" or "{appId}"
@@ -62,8 +63,9 @@ public class SparkShuffleManager implements ShuffleManager {
   private final ExecutorService[] asyncPushers;
   private AtomicInteger pusherIdx = new AtomicInteger(0);
 
-  public SparkShuffleManager(SparkConf conf) {
+  public SparkShuffleManager(SparkConf conf, boolean isDriver) {
     this.conf = conf;
+    this.isDriver = isDriver;
     this.celebornConf = SparkUtils.fromSparkConf(conf);
     this.cores = conf.getInt(SparkLauncher.EXECUTOR_CORES, 1);
     this.fallbackPolicyRunner = new 
CelebornShuffleFallbackPolicyRunner(celebornConf);
@@ -78,16 +80,11 @@ public class SparkShuffleManager implements ShuffleManager {
     }
   }
 
-  private boolean isDriver() {
-    return "driver".equals(SparkEnv.get().executorId());
-  }
-
   private SortShuffleManager sortShuffleManager() {
     if (_sortShuffleManager == null) {
       synchronized (this) {
         if (_sortShuffleManager == null) {
-          _sortShuffleManager =
-              SparkUtils.instantiateClass(sortShuffleManagerName, conf, 
isDriver());
+          _sortShuffleManager = 
SparkUtils.instantiateClass(sortShuffleManagerName, conf, isDriver);
         }
       }
     }
@@ -99,7 +96,7 @@ public class SparkShuffleManager implements ShuffleManager {
     // need to ensure that LifecycleManager will only be created once. 
Parallelism needs to be
     // considered in this place, because if there is one RDD that depends on 
multiple RDDs
     // at the same time, it may bring parallel `register shuffle`, such as 
Join in Sql.
-    if (isDriver() && lifecycleManager == null) {
+    if (isDriver && lifecycleManager == null) {
       synchronized (this) {
         if (lifecycleManager == null) {
           lifecycleManager = new LifecycleManager(appId, celebornConf);
@@ -152,7 +149,7 @@ public class SparkShuffleManager implements ShuffleManager {
     if (shuffleClient == null) {
       return false;
     }
-    return shuffleClient.unregisterShuffle(shuffleId, isDriver());
+    return shuffleClient.unregisterShuffle(shuffleId, isDriver);
   }
 
   @Override
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
index 2df1faaf7..c12a908cd 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
@@ -21,7 +21,7 @@ import org.apache.spark.SparkConf;
 
 public class RssShuffleManager extends SparkShuffleManager {
 
-  public RssShuffleManager(SparkConf conf) {
-    super(conf);
+  public RssShuffleManager(SparkConf conf, boolean isDriver) {
+    super(conf, isDriver);
   }
 }
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 9695fd8ac..79ac75df1 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -48,6 +48,7 @@ public class SparkShuffleManager implements ShuffleManager {
       "spark.sql.adaptive.localShuffleReader.enabled";
 
   private final SparkConf conf;
+  private final Boolean isDriver;
   private final CelebornConf celebornConf;
   private final int cores;
   // either be "{appId}_{appAttemptId}" or "{appId}"
@@ -63,7 +64,7 @@ public class SparkShuffleManager implements ShuffleManager {
   private final ExecutorService[] asyncPushers;
   private AtomicInteger pusherIdx = new AtomicInteger(0);
 
-  public SparkShuffleManager(SparkConf conf) {
+  public SparkShuffleManager(SparkConf conf, boolean isDriver) {
     if (conf.getBoolean(LOCAL_SHUFFLE_READER_KEY, true)) {
       logger.warn(
           "Detected {} (default is true) is enabled, it's highly recommended 
to disable it when "
@@ -71,6 +72,7 @@ public class SparkShuffleManager implements ShuffleManager {
           LOCAL_SHUFFLE_READER_KEY);
     }
     this.conf = conf;
+    this.isDriver = isDriver;
     this.celebornConf = SparkUtils.fromSparkConf(conf);
     this.cores = executorCores(conf);
     this.fallbackPolicyRunner = new 
CelebornShuffleFallbackPolicyRunner(celebornConf);
@@ -85,16 +87,12 @@ public class SparkShuffleManager implements ShuffleManager {
     }
   }
 
-  private boolean isDriver() {
-    return "driver".equals(SparkEnv.get().executorId());
-  }
-
   private SortShuffleManager sortShuffleManager() {
     if (_sortShuffleManager == null) {
       synchronized (this) {
         if (_sortShuffleManager == null) {
           _sortShuffleManager =
-              SparkUtils.instantiateClass(SORT_SHUFFLE_MANAGER_NAME, conf, 
isDriver());
+              SparkUtils.instantiateClass(SORT_SHUFFLE_MANAGER_NAME, conf, 
isDriver);
         }
       }
     }
@@ -106,7 +104,7 @@ public class SparkShuffleManager implements ShuffleManager {
     // need to ensure that LifecycleManager will only be created once. 
Parallelism needs to be
     // considered in this place, because if there is one RDD that depends on 
multiple RDDs
     // at the same time, it may bring parallel `register shuffle`, such as 
Join in Sql.
-    if (isDriver() && lifecycleManager == null) {
+    if (isDriver && lifecycleManager == null) {
       synchronized (this) {
         if (lifecycleManager == null) {
           lifecycleManager = new LifecycleManager(appUniqueId, celebornConf);
@@ -159,7 +157,7 @@ public class SparkShuffleManager implements ShuffleManager {
     if (shuffleClient == null) {
       return false;
     }
-    return shuffleClient.unregisterShuffle(shuffleId, isDriver());
+    return shuffleClient.unregisterShuffle(shuffleId, isDriver);
   }
 
   @Override

Reply via email to