This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 735650e63 [CELEBORN-1211] Add extension for celeborn shuffle handler
735650e63 is described below

commit 735650e634508edda1e335751fbadb8f6054170a
Author: mingji <[email protected]>
AuthorDate: Fri Jan 5 15:56:29 2024 +0800

    [CELEBORN-1211] Add extension for celeborn shuffle handler
    
    ### What changes were proposed in this pull request?
    1. Add extension API to CelebornShuffleHandler.
    
    ### Why are the changes needed?
    
    ### Does this PR introduce _any_ user-facing change?
    NO.
    
    ### How was this patch tested?
    GA.
    
    Closes #2206 from FMX/b1211.
    
    Authored-by: mingji <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../shuffle/celeborn/SparkShuffleManager.java      |  3 ++-
 .../shuffle/celeborn/CelebornShuffleHandle.scala   | 24 ++++++++++++++++++++--
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  3 ++-
 .../shuffle/celeborn/SparkShuffleManager.java      |  3 ++-
 .../shuffle/celeborn/CelebornShuffleHandle.scala   | 24 ++++++++++++++++++++--
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  3 ++-
 .../org/apache/celeborn/client/ShuffleClient.java  | 18 ++++++++++++++++
 .../apache/celeborn/client/ShuffleClientImpl.java  |  7 +++++++
 .../apache/celeborn/client/DummyShuffleClient.java |  3 +++
 .../tests/spark/CelebornFetchFailureSuite.scala    |  3 ++-
 10 files changed, 82 insertions(+), 9 deletions(-)

diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index df2dee534..cda992b29 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -185,7 +185,8 @@ public class SparkShuffleManager implements ShuffleManager {
                 h.lifecycleManagerHost(),
                 h.lifecycleManagerPort(),
                 celebornConf,
-                h.userIdentifier());
+                h.userIdentifier(),
+                h.extension());
         int shuffleId = SparkUtils.celebornShuffleId(client, h, context, true);
         shuffleIdTracker.track(h.shuffleId(), shuffleId);
 
diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
index 4f67edaf3..4ae52720c 100644
--- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
+++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
@@ -30,5 +30,25 @@ class CelebornShuffleHandle[K, V, C](
     shuffleId: Int,
     val throwsFetchFailure: Boolean,
     numMappers: Int,
-    dependency: ShuffleDependency[K, V, C])
-  extends BaseShuffleHandle(shuffleId, numMappers, dependency)
+    dependency: ShuffleDependency[K, V, C],
+    val extension: Array[Byte])
+  extends BaseShuffleHandle(shuffleId, numMappers, dependency) {
+  def this(
+      appUniqueId: String,
+      lifecycleManagerHost: String,
+      lifecycleManagerPort: Int,
+      userIdentifier: UserIdentifier,
+      shuffleId: Int,
+      throwsFetchFailure: Boolean,
+      numMappers: Int,
+      dependency: ShuffleDependency[K, V, C]) = this(
+    appUniqueId,
+    lifecycleManagerHost,
+    lifecycleManagerPort,
+    userIdentifier,
+    shuffleId,
+    throwsFetchFailure,
+    numMappers,
+    dependency,
+    null)
+}
diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index dec305225..ad1a620bd 100644
--- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -52,7 +52,8 @@ class CelebornShuffleReader[K, C](
     handle.lifecycleManagerHost,
     handle.lifecycleManagerPort,
     conf,
-    handle.userIdentifier)
+    handle.userIdentifier,
+    handle.extension)
 
   private val exceptionRef = new AtomicReference[IOException]
 
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 7cf7b0979..264bf7766 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -233,7 +233,8 @@ public class SparkShuffleManager implements ShuffleManager {
                 h.lifecycleManagerHost(),
                 h.lifecycleManagerPort(),
                 celebornConf,
-                h.userIdentifier());
+                h.userIdentifier(),
+                h.extension());
         int shuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, true);
         shuffleIdTracker.track(h.shuffleId(), shuffleId);
 
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
index 18a3053e0..3d0180a26 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
@@ -30,5 +30,25 @@ class CelebornShuffleHandle[K, V, C](
     shuffleId: Int,
     val throwsFetchFailure: Boolean,
     val numMappers: Int,
-    dependency: ShuffleDependency[K, V, C])
-  extends BaseShuffleHandle(shuffleId, dependency)
+    dependency: ShuffleDependency[K, V, C],
+    val extension: Array[Byte])
+  extends BaseShuffleHandle(shuffleId, dependency) {
+  def this(
+      appUniqueId: String,
+      lifecycleManagerHost: String,
+      lifecycleManagerPort: Int,
+      userIdentifier: UserIdentifier,
+      shuffleId: Int,
+      throwsFetchFailure: Boolean,
+      numMappers: Int,
+      dependency: ShuffleDependency[K, V, C]) = this(
+    appUniqueId,
+    lifecycleManagerHost,
+    lifecycleManagerPort,
+    userIdentifier,
+    shuffleId,
+    throwsFetchFailure,
+    numMappers,
+    dependency,
+    null)
+}
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index fe7af8309..663983b58 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -53,7 +53,8 @@ class CelebornShuffleReader[K, C](
     handle.lifecycleManagerHost,
     handle.lifecycleManagerPort,
     conf,
-    handle.userIdentifier)
+    handle.userIdentifier,
+    handle.extension)
 
   private val exceptionRef = new AtomicReference[IOException]
 
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 72230a536..af7b4e3d5 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -61,6 +61,16 @@ public abstract class ShuffleClient {
       int port,
       CelebornConf conf,
       UserIdentifier userIdentifier) {
+    return ShuffleClient.get(appUniqueId, driverHost, port, conf, userIdentifier, null);
+  }
+
+  public static ShuffleClient get(
+      String appUniqueId,
+      String driverHost,
+      int port,
+      CelebornConf conf,
+      UserIdentifier userIdentifier,
+      byte[] extension) {
     if (null == _instance || !initialized) {
       synchronized (ShuffleClient.class) {
         if (null == _instance) {
@@ -72,11 +82,13 @@ public abstract class ShuffleClient {
           // when communicating with LifecycleManager, it will cause a NullPointerException.
           _instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier);
           _instance.setupLifecycleManagerRef(driverHost, port);
+          _instance.setExtension(extension);
           initialized = true;
         } else if (!initialized) {
           _instance.shutdown();
           _instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier);
           _instance.setupLifecycleManagerRef(driverHost, port);
+          _instance.setExtension(extension);
           initialized = true;
         }
       }
@@ -122,6 +134,12 @@ public abstract class ShuffleClient {
 
   public abstract void setupLifecycleManagerRef(RpcEndpointRef endpointRef);
 
+  /**
+   * @param extension Extension for shuffle client, it's a byte array. Used in derived shuffle
+   *     client implementation.
+   */
+  public abstract void setExtension(byte[] extension);
+
   /**
    * Write data to a specific reduce partition
    *
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 27aaf1817..a7b4e7083 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -86,6 +86,8 @@ public class ShuffleClientImpl extends ShuffleClient {
 
   protected final int BATCH_HEADER_SIZE = 4 * 4;
 
+  protected byte[] extension;
+
   // key: appShuffleIdentifier, value: shuffleId
   protected Map<String, Integer> shuffleIdCache = JavaUtils.newConcurrentHashMap();
 
@@ -1703,6 +1705,11 @@ public class ShuffleClientImpl extends ShuffleClient {
     lifecycleManagerRef = endpointRef;
   }
 
+  @Override
+  public void setExtension(byte[] extension) {
+    this.extension = extension;
+  }
+
   boolean mapperEnded(int shuffleId, int mapId) {
     return (mapperEndMap.containsKey(shuffleId) && mapperEndMap.get(shuffleId).contains(mapId))
         || isStageEnded(shuffleId);
diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index bb4f4fe41..dda76b1da 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -62,6 +62,9 @@ public class DummyShuffleClient extends ShuffleClient {
   @Override
   public void setupLifecycleManagerRef(RpcEndpointRef endpointRef) {}
 
+  @Override
+  public void setExtension(byte[] extension) {}
+
   @Override
   public int pushData(
       int shuffleId,
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
index ebb916d1a..8983f6bd6 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
@@ -72,7 +72,8 @@ class CelebornFetchFailureSuite extends AnyFunSuite
               h.lifecycleManagerHost,
               h.lifecycleManagerPort,
               conf,
-              h.userIdentifier)
+              h.userIdentifier,
+              h.extension)
             val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
             val datafile =
               workerDirs.map(dir => {

Reply via email to