This is an automated email from the ASF dual-hosted git repository.

xianjingfeng pushed a commit to branch branch-0.9
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/branch-0.9 by this push:
     new 4944d5481 [#1751][0.9] improvement: support gluten (#1753)
4944d5481 is described below

commit 4944d5481e7b64e75ddf8bf6eee03b27490a3667
Author: xianjingfeng <[email protected]>
AuthorDate: Tue Jun 18 09:13:24 2024 +0800

    [#1751][0.9] improvement: support gluten (#1753)
    
    * support gluten
    
    * optimize
    
    * fix bug
    
    * nit
    
    * fix spotless
    
    * nit
    
    * nit
    
    * fix bug
    
    * optimize
    
    * optimize
    
    * nit
    
    * nit
    
    * nit
    
    * nit
    
    * nit
    
    * Update RssShuffleWriter.java
---
 .../apache/spark/shuffle/RssShuffleManager.java    | 24 +++++++-------
 .../spark/shuffle/writer/RssShuffleWriter.java     |  6 ++--
 .../apache/spark/shuffle/RssShuffleManager.java    | 37 ++++++++--------------
 .../spark/shuffle/writer/RssShuffleWriter.java     | 19 ++++++++---
 4 files changed, 44 insertions(+), 42 deletions(-)

diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 78bcc2c17..45d338e39 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -475,15 +475,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
 
       int shuffleId = rssHandle.getShuffleId();
       String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
-      ShuffleHandleInfo shuffleHandleInfo;
-      if (shuffleManagerRpcServiceEnabled) {
-        // Get the ShuffleServer list from the Driver based on the shuffleId
-        shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
-      } else {
-        shuffleHandleInfo =
-            new ShuffleHandleInfo(
-                shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
-      }
       ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
       return new RssShuffleWriter<>(
           rssHandle.getAppId(),
@@ -496,8 +487,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
           shuffleWriteClient,
           rssHandle,
           this::markFailedTask,
-          context,
-          shuffleHandleInfo);
+          context);
     } else {
       throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
     }
@@ -806,6 +796,18 @@ public class RssShuffleManager extends RssShuffleManagerBase {
         .createShuffleManagerClient(ClientType.GRPC, host, port);
   }
 
+  public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
+    if (shuffleManagerRpcServiceEnabled) {
+      // Get the ShuffleServer list from the Driver based on the shuffleId
+      return getRemoteShuffleHandleInfo(rssHandle.getShuffleId());
+    } else {
+      return new ShuffleHandleInfo(
+          rssHandle.getShuffleId(),
+          rssHandle.getPartitionToServers(),
+          rssHandle.getRemoteStorage());
+    }
+  }
+
   /**
    * Get the ShuffleServer list from the Driver based on the shuffleId
    *
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 9e64b2fd5..37576c1c9 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -188,8 +188,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
-      TaskContext context,
-      ShuffleHandleInfo shuffleHandleInfo) {
+      TaskContext context) {
     this(
         appId,
         shuffleId,
@@ -201,9 +200,10 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         shuffleWriteClient,
         rssHandle,
         taskFailureCallback,
-        shuffleHandleInfo,
+        shuffleManager.getShuffleHandleInfo(rssHandle),
         context);
     BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
+    ShuffleHandleInfo shuffleHandleInfo = shuffleManager.getShuffleHandleInfo(rssHandle);
     final WriteBufferManager bufferManager =
         new WriteBufferManager(
             shuffleId,
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 6d9487ca4..700b7691b 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -141,7 +141,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
   private boolean rssResubmitStage;
 
   private boolean taskBlockSendFailureRetryEnabled;
-
   private boolean shuffleManagerRpcServiceEnabled;
   /** A list of shuffleServer for Write failures */
   private Set<String> failuresShuffleServerIds;
@@ -514,15 +513,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
     } else {
       writeMetrics = context.taskMetrics().shuffleWriteMetrics();
     }
-    ShuffleHandleInfo shuffleHandleInfo;
-    if (shuffleManagerRpcServiceEnabled) {
-      // Get the ShuffleServer list from the Driver based on the shuffleId
-      shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
-    } else {
-      shuffleHandleInfo =
-          new ShuffleHandleInfo(
-              shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
-    }
     String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
     LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId());
     return new RssShuffleWriter<>(
@@ -536,8 +526,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
         shuffleWriteClient,
         rssHandle,
         this::markFailedTask,
-        context,
-        shuffleHandleInfo);
+        context);
   }
 
   @Override
@@ -656,17 +645,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
     RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>) handle;
     final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
     int shuffleId = rssShuffleHandle.getShuffleId();
-    ShuffleHandleInfo shuffleHandleInfo;
-    if (shuffleManagerRpcServiceEnabled) {
-      // Get the ShuffleServer list from the Driver based on the shuffleId
-      shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
-    } else {
-      shuffleHandleInfo =
-          new ShuffleHandleInfo(
-              shuffleId,
-              rssShuffleHandle.getPartitionToServers(),
-              rssShuffleHandle.getRemoteStorage());
-    }
+    ShuffleHandleInfo shuffleHandleInfo = getShuffleHandleInfo(rssShuffleHandle);
     Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
         shuffleHandleInfo.getPartitionToServers();
     Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
@@ -1101,6 +1080,18 @@ public class RssShuffleManager extends RssShuffleManagerBase {
         .createShuffleManagerClient(ClientType.GRPC, host, port);
   }
 
+  public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
+    if (shuffleManagerRpcServiceEnabled) {
+      // Get the ShuffleServer list from the Driver based on the shuffleId
+      return getRemoteShuffleHandleInfo(rssHandle.getShuffleId());
+    } else {
+      return new ShuffleHandleInfo(
+          rssHandle.getShuffleId(),
+          rssHandle.getPartitionToServers(),
+          rssHandle.getRemoteStorage());
+    }
+  }
+
   /**
    * Get the ShuffleServer list from the Driver based on the shuffleId
    *
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 8a22b73ba..70ae3d8f6 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -95,6 +95,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   private final String appId;
   private final int shuffleId;
+  private final ShuffleHandleInfo shuffleHandleInfo;
   private WriteBufferManager bufferManager;
   private String taskId;
   private final int numMaps;
@@ -110,7 +111,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
   private final ShuffleWriteClient shuffleWriteClient;
   private final Set<ShuffleServerInfo> shuffleServersForData;
   private final long[] partitionLengths;
-  private final boolean isMemoryShuffleEnabled;
+  // Gluten needs this variable
+  protected final boolean isMemoryShuffleEnabled;
   private final Function<String, Boolean> taskFailureCallback;
   private final Set<Long> blockIds = Sets.newConcurrentHashSet();
   private TaskContext taskContext;
@@ -195,6 +197,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     this.isMemoryShuffleEnabled =
         isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
     this.taskFailureCallback = taskFailureCallback;
+    this.shuffleHandleInfo = shuffleHandleInfo;
     this.taskContext = context;
     this.sparkConf = sparkConf;
     this.blockFailSentRetryEnabled =
@@ -204,6 +207,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
             RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.defaultValue());
   }
 
+  // Gluten needs this constructor
   public RssShuffleWriter(
       String appId,
       int shuffleId,
@@ -215,8 +219,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
-      TaskContext context,
-      ShuffleHandleInfo shuffleHandleInfo) {
+      TaskContext context) {
     this(
         appId,
         shuffleId,
@@ -228,7 +231,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         shuffleWriteClient,
         rssHandle,
         taskFailureCallback,
-        shuffleHandleInfo,
+        shuffleManager.getShuffleHandleInfo(rssHandle),
         context);
     BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
     final WriteBufferManager bufferManager =
@@ -264,7 +267,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     }
   }
 
-  private void writeImpl(Iterator<Product2<K, V>> records) {
+  // Gluten needs this method.
+  protected void writeImpl(Iterator<Product2<K, V>> records) {
     List<ShuffleBlockInfo> shuffleBlockInfos;
     boolean isCombine = shuffleDependency.mapSideCombine();
     Function1<V, C> createCombiner = null;
@@ -322,6 +326,11 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
             + bufferManager.getManagerCostInfo());
   }
 
+  // Gluten needs this method
+  protected void internalCheckBlockSendResult() {
+    this.checkBlockSendResult(this.blockIds);
+  }
+
   private void checkSentRecordCount(long recordCount) {
     if (recordCount != bufferManager.getRecordCount()) {
       String errorMsg =

Reply via email to