This is an automated email from the ASF dual-hosted git repository.

feiwang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 5e12b7d60 [CELEBORN-1921] Broadcast large GetReducerFileGroupResponse 
to prevent Spark driver network exhausted
5e12b7d60 is described below

commit 5e12b7d6075787d2f70382674b5b552c31aa9080
Author: Wang, Fei <[email protected]>
AuthorDate: Tue Apr 1 08:29:21 2025 -0700

    [CELEBORN-1921] Broadcast large GetReducerFileGroupResponse to prevent 
Spark driver network exhausted
    
    ### What changes were proposed in this pull request?
    
    For Spark applications using Celeborn, if the GetReducerFileGroupResponse is
    larger than the threshold, the Spark driver broadcasts the
    GetReducerFileGroupResponse to the executors. This prevents the driver from
    becoming the bottleneck when sending out multiple copies of the
    GetReducerFileGroupResponse (one per executor).
    
    ### Why are the changes needed?
    To prevent the driver from becoming the bottleneck when sending out multiple
    copies of the GetReducerFileGroupResponse (one per executor), which can
    otherwise exhaust the Spark driver's network.
    
    ### Does this PR introduce _any_ user-facing change?
    No, the feature is not enabled by default.
    
    ### How was this patch tested?
    
    UT.
    
    Cluster testing with 
`spark.celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled=true`.
    
    The broadcast response size should always be about 1 KB.
    
![image](https://github.com/user-attachments/assets/d5d1b751-762d-43c8-8a84-0674630a5638)
    
![image](https://github.com/user-attachments/assets/4841a29e-5d11-4932-9fa5-f6e78b7bc521)
    The application succeeded.
    
![image](https://github.com/user-attachments/assets/9b570f70-1433-4457-90ae-b8292e5476ba)
    
    Closes #3158 from turboFei/broadcast_rgf.
    
    Authored-by: Wang, Fei <[email protected]>
    Signed-off-by: Wang, Fei <[email protected]>
---
 LICENSE                                            |   1 +
 client-spark/spark-2-shaded/pom.xml                |   1 +
 .../shuffle/celeborn/SparkShuffleManager.java      |   9 ++
 .../apache/spark/shuffle/celeborn/SparkUtils.java  | 129 ++++++++++++++++++++
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  11 ++
 client-spark/spark-3-shaded/pom.xml                |   1 +
 .../shuffle/celeborn/SparkShuffleManager.java      |   9 ++
 .../apache/spark/shuffle/celeborn/SparkUtils.java  | 130 +++++++++++++++++++++
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  13 ++-
 .../org/apache/celeborn/client/ShuffleClient.java  |  24 ++++
 .../apache/celeborn/client/ShuffleClientImpl.java  |   8 ++
 .../org/apache/celeborn/client/CommitManager.scala |   3 +-
 .../apache/celeborn/client/LifecycleManager.scala  |  32 +++++
 .../commit/ReducePartitionCommitHandler.scala      |  52 ++++++++-
 .../celeborn/client/ShuffleClientSuiteJ.java       |  24 ++--
 common/src/main/proto/TransportMessages.proto      |   2 +
 .../org/apache/celeborn/common/CelebornConf.scala  |  25 ++++
 .../common/protocol/message/ControlMessages.scala  |  21 +++-
 .../org/apache/celeborn/common/util/KeyLock.scala  |  70 +++++++++++
 docs/configuration/client.md                       |   2 +
 .../celeborn/tests/spark/CelebornHashSuite.scala   |  41 +++++++
 .../celeborn/tests/spark/CelebornSortSuite.scala   |  42 +++++++
 .../spark/shuffle/celeborn/SparkUtilsSuite.scala   |  61 +++++++++-
 23 files changed, 690 insertions(+), 21 deletions(-)

diff --git a/LICENSE b/LICENSE
index 402c728a7..f757490f9 100644
--- a/LICENSE
+++ b/LICENSE
@@ -223,6 +223,7 @@ Apache Spark
 
./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
 
./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java
 
./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
+./common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
 
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java
 
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DBIterator.java
 
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/LevelDB.java
diff --git a/client-spark/spark-2-shaded/pom.xml 
b/client-spark/spark-2-shaded/pom.xml
index 3551b1caa..cfbc02079 100644
--- a/client-spark/spark-2-shaded/pom.xml
+++ b/client-spark/spark-2-shaded/pom.xml
@@ -73,6 +73,7 @@
               <include>io.netty:*</include>
               <include>org.apache.commons:commons-lang3</include>
               <include>org.roaringbitmap:RoaringBitmap</include>
+              <include>commons-io:commons-io</include>
             </includes>
           </artifactSet>
           <filters>
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 67f29eced..41fe777b7 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -109,6 +109,15 @@ public class SparkShuffleManager implements ShuffleManager 
{
             lifecycleManager.registerShuffleTrackerCallback(
                 shuffleId -> 
mapOutputTracker.unregisterAllMapOutput(shuffleId));
           }
+
+          if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
+            
lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
+                (shuffleId, getReducerFileGroupResponse) ->
+                    SparkUtils.serializeGetReducerFileGroupResponse(
+                        shuffleId, getReducerFileGroupResponse));
+            lifecycleManager.registerInvalidatedBroadcastCallback(
+                shuffleId -> 
SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
+          }
         }
       }
     }
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 7ac0f6583..c708c1e84 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -17,7 +17,10 @@
 
 package org.apache.spark.shuffle.celeborn;
 
+import java.io.ByteArrayInputStream;
 import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.lang.reflect.Field;
 import java.lang.reflect.Method;
 import java.util.HashSet;
@@ -25,6 +28,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.LongAdder;
 import java.util.stream.Collectors;
 
@@ -37,7 +41,12 @@ import org.apache.spark.BarrierTaskContext;
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkContext;
 import org.apache.spark.SparkContext$;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.SparkEnv$;
 import org.apache.spark.TaskContext;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.scheduler.DAGScheduler;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -54,7 +63,10 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.client.ShuffleClient;
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
 import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.util.KeyLock;
 import org.apache.celeborn.common.util.Utils;
 import org.apache.celeborn.reflect.DynFields;
 
@@ -346,4 +358,121 @@ public class SparkUtils {
       sparkContext.addSparkListener(listener);
     }
   }
+
+  /**
+   * A [[KeyLock]] whose key is a shuffle id to ensure there is only one 
thread accessing the
+   * broadcast belonging to the shuffle id at a time.
+   */
+  private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock<>();
+
+  @VisibleForTesting
+  public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new 
AtomicInteger();
+
+  @VisibleForTesting
+  public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>>
+      getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
+
+  public static byte[] serializeGetReducerFileGroupResponse(
+      Integer shuffleId, GetReducerFileGroupResponse response) {
+    SparkContext sparkContext = 
SparkContext$.MODULE$.getActive().getOrElse(null);
+    if (sparkContext == null) {
+      logger.error("Can not get active SparkContext.");
+      return null;
+    }
+
+    return shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          Tuple2<Broadcast<TransportMessage>, byte[]> 
cachedSerializeGetReducerFileGroupResponse =
+              getReducerFileGroupResponseBroadcasts.get(shuffleId);
+          if (cachedSerializeGetReducerFileGroupResponse != null) {
+            return cachedSerializeGetReducerFileGroupResponse._2;
+          }
+
+          try {
+            logger.info("Broadcasting GetReducerFileGroupResponse for shuffle: 
{}", shuffleId);
+            TransportMessage transportMessage =
+                (TransportMessage) Utils.toTransportMessage(response);
+            Broadcast<TransportMessage> broadcast =
+                sparkContext.broadcast(
+                    transportMessage,
+                    
scala.reflect.ClassManifestFactory.fromClass(TransportMessage.class));
+
+            CompressionCodec codec = 
CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
+            // Using `org.apache.commons.io.output.ByteArrayOutputStream` 
instead of the standard
+            // one
+            // This implementation doesn't reallocate the whole memory block 
but allocates
+            // additional buffers. This way no buffers need to be garbage 
collected and
+            // the contents don't have to be copied to the new buffer.
+            org.apache.commons.io.output.ByteArrayOutputStream out =
+                new org.apache.commons.io.output.ByteArrayOutputStream();
+            try (ObjectOutputStream oos =
+                new ObjectOutputStream(codec.compressedOutputStream(out))) {
+              oos.writeObject(broadcast);
+            }
+            byte[] _serializeResult = out.toByteArray();
+            getReducerFileGroupResponseBroadcasts.put(
+                shuffleId, Tuple2.apply(broadcast, _serializeResult));
+            getReducerFileGroupResponseBroadcastNum.incrementAndGet();
+            return _serializeResult;
+          } catch (Throwable e) {
+            logger.error(
+                "Failed to serialize GetReducerFileGroupResponse for shuffle: 
{}", shuffleId, e);
+            return null;
+          }
+        });
+  }
+
+  public static GetReducerFileGroupResponse 
deserializeGetReducerFileGroupResponse(
+      Integer shuffleId, byte[] bytes) {
+    SparkEnv sparkEnv = SparkEnv$.MODULE$.get();
+    if (sparkEnv == null) {
+      logger.error("Can not get SparkEnv.");
+      return null;
+    }
+
+    return shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          GetReducerFileGroupResponse response = null;
+          logger.info(
+              "Deserializing GetReducerFileGroupResponse broadcast for 
shuffle: {}", shuffleId);
+
+          try {
+            CompressionCodec codec = 
CompressionCodec$.MODULE$.createCodec(sparkEnv.conf());
+            try (ObjectInputStream objIn =
+                new ObjectInputStream(
+                    codec.compressedInputStream(new 
ByteArrayInputStream(bytes)))) {
+              Broadcast<TransportMessage> broadcast =
+                  (Broadcast<TransportMessage>) objIn.readObject();
+              response =
+                  (GetReducerFileGroupResponse) 
Utils.fromTransportMessage(broadcast.value());
+            }
+          } catch (Throwable e) {
+            logger.error(
+                "Failed to deserialize GetReducerFileGroupResponse for 
shuffle: " + shuffleId, e);
+          }
+          return response;
+        });
+  }
+
+  public static void invalidateSerializedGetReducerFileGroupResponse(Integer 
shuffleId) {
+    shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          try {
+            Tuple2<Broadcast<TransportMessage>, byte[]> 
cachedSerializeGetReducerFileGroupResponse =
+                getReducerFileGroupResponseBroadcasts.remove(shuffleId);
+            if (cachedSerializeGetReducerFileGroupResponse != null) {
+              cachedSerializeGetReducerFileGroupResponse._1().destroy();
+            }
+          } catch (Throwable e) {
+            logger.error(
+                "Failed to invalidate serialized GetReducerFileGroupResponse 
for shuffle: "
+                    + shuffleId,
+                e);
+          }
+          return null;
+        });
+  }
 }
diff --git 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index a60c68f3d..2269aaf68 100644
--- 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
 import java.io.IOException
 import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
 import java.util.concurrent.atomic.AtomicReference
+import java.util.function.BiFunction
 
 import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
 import org.apache.spark.internal.Logging
@@ -33,6 +34,7 @@ import org.apache.celeborn.client.read.CelebornInputStream
 import org.apache.celeborn.client.read.MetricsCallback
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.{CelebornIOException, 
PartitionUnRetryAbleException}
+import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
 import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
 
 class CelebornShuffleReader[K, C](
@@ -254,4 +256,13 @@ class CelebornShuffleReader[K, C](
 
 object CelebornShuffleReader {
   var streamCreatorPool: ThreadPoolExecutor = null
+  // Register the deserializer for GetReducerFileGroupResponse broadcast
+  ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new 
BiFunction[
+    Integer,
+    Array[Byte],
+    GetReducerFileGroupResponse] {
+    override def apply(shuffleId: Integer, broadcast: Array[Byte]): 
GetReducerFileGroupResponse = {
+      SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
+    }
+  })
 }
diff --git a/client-spark/spark-3-shaded/pom.xml 
b/client-spark/spark-3-shaded/pom.xml
index d3d59cb87..8cce8577a 100644
--- a/client-spark/spark-3-shaded/pom.xml
+++ b/client-spark/spark-3-shaded/pom.xml
@@ -77,6 +77,7 @@
               <include>io.netty:*</include>
               <include>org.apache.commons:commons-lang3</include>
               <include>org.roaringbitmap:RoaringBitmap</include>
+              <include>commons-io:commons-io</include>
             </includes>
           </artifactSet>
           <filters>
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 234fba1fe..a3e75cd10 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -156,6 +156,15 @@ public class SparkShuffleManager implements ShuffleManager 
{
                   SparkUtils::isCelebornSkewShuffleOrChildShuffle);
             }
           }
+
+          if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
+            
lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
+                (shuffleId, getReducerFileGroupResponse) ->
+                    SparkUtils.serializeGetReducerFileGroupResponse(
+                        shuffleId, getReducerFileGroupResponse));
+            lifecycleManager.registerInvalidatedBroadcastCallback(
+                shuffleId -> 
SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
+          }
         }
       }
     }
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 6443c2163..4efc29a90 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -17,11 +17,15 @@
 
 package org.apache.spark.shuffle.celeborn;
 
+import java.io.ByteArrayInputStream;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.LongAdder;
 import java.util.stream.Collectors;
 
@@ -35,7 +39,12 @@ import org.apache.spark.MapOutputTrackerMaster;
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkContext;
 import org.apache.spark.SparkContext$;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.SparkEnv$;
 import org.apache.spark.TaskContext;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.scheduler.DAGScheduler;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -57,7 +66,11 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.client.ShuffleClient;
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
 import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.util.KeyLock;
+import org.apache.celeborn.common.util.Utils;
 import org.apache.celeborn.reflect.DynConstructors;
 import org.apache.celeborn.reflect.DynFields;
 import org.apache.celeborn.reflect.DynMethods;
@@ -476,4 +489,121 @@ public class SparkUtils {
       return false;
     }
   }
+
+  /**
+   * A [[KeyLock]] whose key is a shuffle id to ensure there is only one 
thread accessing the
+   * broadcast belonging to the shuffle id at a time.
+   */
+  private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock<>();
+
+  @VisibleForTesting
+  public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new 
AtomicInteger();
+
+  @VisibleForTesting
+  public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>>
+      getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
+
+  public static byte[] serializeGetReducerFileGroupResponse(
+      Integer shuffleId, GetReducerFileGroupResponse response) {
+    SparkContext sparkContext = 
SparkContext$.MODULE$.getActive().getOrElse(null);
+    if (sparkContext == null) {
+      LOG.error("Can not get active SparkContext.");
+      return null;
+    }
+
+    return shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          Tuple2<Broadcast<TransportMessage>, byte[]> 
cachedSerializeGetReducerFileGroupResponse =
+              getReducerFileGroupResponseBroadcasts.get(shuffleId);
+          if (cachedSerializeGetReducerFileGroupResponse != null) {
+            return cachedSerializeGetReducerFileGroupResponse._2;
+          }
+
+          try {
+            LOG.info("Broadcasting GetReducerFileGroupResponse for shuffle: 
{}", shuffleId);
+            TransportMessage transportMessage =
+                (TransportMessage) Utils.toTransportMessage(response);
+            Broadcast<TransportMessage> broadcast =
+                sparkContext.broadcast(
+                    transportMessage,
+                    
scala.reflect.ClassManifestFactory.fromClass(TransportMessage.class));
+
+            CompressionCodec codec = 
CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
+            // Using `org.apache.commons.io.output.ByteArrayOutputStream` 
instead of the standard
+            // one
+            // This implementation doesn't reallocate the whole memory block 
but allocates
+            // additional buffers. This way no buffers need to be garbage 
collected and
+            // the contents don't have to be copied to the new buffer.
+            org.apache.commons.io.output.ByteArrayOutputStream out =
+                new org.apache.commons.io.output.ByteArrayOutputStream();
+            try (ObjectOutputStream oos =
+                new ObjectOutputStream(codec.compressedOutputStream(out))) {
+              oos.writeObject(broadcast);
+            }
+            byte[] _serializeResult = out.toByteArray();
+            getReducerFileGroupResponseBroadcasts.put(
+                shuffleId, Tuple2.apply(broadcast, _serializeResult));
+            getReducerFileGroupResponseBroadcastNum.incrementAndGet();
+            return _serializeResult;
+          } catch (Throwable e) {
+            LOG.error(
+                "Failed to serialize GetReducerFileGroupResponse for shuffle: 
{}", shuffleId, e);
+            return null;
+          }
+        });
+  }
+
+  public static GetReducerFileGroupResponse 
deserializeGetReducerFileGroupResponse(
+      Integer shuffleId, byte[] bytes) {
+    SparkEnv sparkEnv = SparkEnv$.MODULE$.get();
+    if (sparkEnv == null) {
+      LOG.error("Can not get SparkEnv.");
+      return null;
+    }
+
+    return shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          GetReducerFileGroupResponse response = null;
+          LOG.info(
+              "Deserializing GetReducerFileGroupResponse broadcast for 
shuffle: {}", shuffleId);
+
+          try {
+            CompressionCodec codec = 
CompressionCodec$.MODULE$.createCodec(sparkEnv.conf());
+            try (ObjectInputStream objIn =
+                new ObjectInputStream(
+                    codec.compressedInputStream(new 
ByteArrayInputStream(bytes)))) {
+              Broadcast<TransportMessage> broadcast =
+                  (Broadcast<TransportMessage>) objIn.readObject();
+              response =
+                  (GetReducerFileGroupResponse) 
Utils.fromTransportMessage(broadcast.value());
+            }
+          } catch (Throwable e) {
+            LOG.error(
+                "Failed to deserialize GetReducerFileGroupResponse for 
shuffle: " + shuffleId, e);
+          }
+          return response;
+        });
+  }
+
+  public static void invalidateSerializedGetReducerFileGroupResponse(Integer 
shuffleId) {
+    shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          try {
+            Tuple2<Broadcast<TransportMessage>, byte[]> 
cachedSerializeGetReducerFileGroupResponse =
+                getReducerFileGroupResponseBroadcasts.remove(shuffleId);
+            if (cachedSerializeGetReducerFileGroupResponse != null) {
+              cachedSerializeGetReducerFileGroupResponse._1().destroy();
+            }
+          } catch (Throwable e) {
+            LOG.error(
+                "Failed to invalidate serialized GetReducerFileGroupResponse 
for shuffle: "
+                    + shuffleId,
+                e);
+          }
+          return null;
+        });
+  }
 }
diff --git 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index af29b57dd..3e296b310 100644
--- 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -21,6 +21,7 @@ import java.io.IOException
 import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, 
Set => JSet}
 import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, 
TimeoutException, TimeUnit}
 import java.util.concurrent.atomic.AtomicReference
+import java.util.function.BiFunction
 
 import scala.collection.JavaConverters._
 
@@ -43,7 +44,8 @@ import 
org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRet
 import org.apache.celeborn.common.network.client.TransportClient
 import org.apache.celeborn.common.network.protocol.TransportMessage
 import org.apache.celeborn.common.protocol._
-import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.protocol.message.{ControlMessages, 
StatusCode}
+import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
 import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils}
 
 class CelebornShuffleReader[K, C](
@@ -465,4 +467,13 @@ class CelebornShuffleReader[K, C](
 
 object CelebornShuffleReader {
   var streamCreatorPool: ThreadPoolExecutor = null
+  // Register the deserializer for GetReducerFileGroupResponse broadcast
+  ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new 
BiFunction[
+    Integer,
+    Array[Byte],
+    GetReducerFileGroupResponse] {
+    override def apply(shuffleId: Integer, broadcast: Array[Byte]): 
GetReducerFileGroupResponse = {
+      SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
+    }
+  })
 }
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 8690f5ceb..dde2b36c4 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -20,9 +20,11 @@ package org.apache.celeborn.client;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.LongAdder;
+import java.util.function.BiFunction;
 
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.hadoop.fs.FileSystem;
@@ -38,6 +40,7 @@ import 
org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.protocol.PartitionLocation;
 import org.apache.celeborn.common.protocol.PbStreamHandler;
 import org.apache.celeborn.common.protocol.StorageInfo;
+import org.apache.celeborn.common.protocol.message.ControlMessages;
 import org.apache.celeborn.common.rpc.RpcEndpointRef;
 import org.apache.celeborn.common.util.CelebornHadoopUtils;
 import org.apache.celeborn.common.util.ExceptionMaker;
@@ -56,6 +59,10 @@ public abstract class ShuffleClient {
   private static LongAdder totalReadCounter = new LongAdder();
   private static LongAdder localShuffleReadCounter = new LongAdder();
 
+  private static volatile Optional<
+          BiFunction<Integer, byte[], 
ControlMessages.GetReducerFileGroupResponse>>
+      deserializeReducerFileGroupResponseFunction = Optional.empty();
+
   // for testing
   public static void reset() {
     _instance = null;
@@ -297,4 +304,21 @@ public abstract class ShuffleClient {
   public abstract TransportClientFactory getDataClientFactory();
 
   public abstract void excludeFailedFetchLocation(String hostAndFetchPort, 
Exception e);
+
+  public static void registerDeserializeReducerFileGroupResponseFunction(
+      BiFunction<Integer, byte[], ControlMessages.GetReducerFileGroupResponse> 
function) {
+    if (!deserializeReducerFileGroupResponseFunction.isPresent()) {
+      deserializeReducerFileGroupResponseFunction = 
Optional.ofNullable(function);
+    }
+  }
+
+  public static ControlMessages.GetReducerFileGroupResponse 
deserializeReducerFileGroupResponse(
+      int shuffleId, byte[] bytes) {
+    if (!deserializeReducerFileGroupResponseFunction.isPresent()) {
+      // Should never happen
+      logger.warn("DeserializeReducerFileGroupResponseFunction is not 
registered.");
+      return null;
+    }
+    return deserializeReducerFileGroupResponseFunction.get().apply(shuffleId, 
bytes);
+  }
 }
diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index be6b5aa3f..e7f4a9083 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -1814,6 +1814,14 @@ public class ShuffleClientImpl extends ShuffleClient {
               ClassTag$.MODULE$.apply(GetReducerFileGroupResponse.class));
       switch (response.status()) {
         case SUCCESS:
+          if (response.broadcast() != null && response.broadcast().length > 0) 
{
+            response =
+                ShuffleClient.deserializeReducerFileGroupResponse(shuffleId, 
response.broadcast());
+            if (response == null) {
+              throw new CelebornIOException(
+                  "Failed to get GetReducerFileGroupResponse broadcast for 
shuffle: " + shuffleId);
+            }
+          }
           logger.info(
               "Shuffle {} request reducer file group success using {} ms, 
result partition size {}.",
               shuffleId,
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index aaeb4462f..ad0d66e3a 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -292,7 +292,8 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage
               lifecycleManager.shuffleAllocatedWorkers,
               committedPartitionInfo,
               lifecycleManager.workerStatusTracker,
-              lifecycleManager.rpcSharedThreadPool)
+              lifecycleManager.rpcSharedThreadPool,
+              lifecycleManager)
           case PartitionType.MAP => new MapPartitionCommitHandler(
               appUniqueId,
               conf,
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index f706eeb90..35d1945e1 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -1677,6 +1677,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
         } else {
           batchRemoveShuffleIds += shuffleId
         }
+        invalidatedBroadcastGetReducerFileGroupResponse(shuffleId)
       }
     }
     if (batchRemoveShuffleIds.nonEmpty) {
@@ -1848,6 +1849,22 @@ class LifecycleManager(val appUniqueId: String, val 
conf: CelebornConf) extends
     cancelShuffleCallback = Some(callback)
   }
 
+  @volatile private var broadcastGetReducerFileGroupResponseCallback
+      : Option[java.util.function.BiFunction[Integer, 
GetReducerFileGroupResponse, Array[Byte]]] =
+    None // callback(shuffleId, response) => bytes for the response's `broadcast` field (null allowed)
+  def registerBroadcastGetReducerFileGroupResponseCallback(call: 
java.util.function.BiFunction[
+    Integer,
+    GetReducerFileGroupResponse,
+    Array[Byte]]): Unit = {
+    broadcastGetReducerFileGroupResponseCallback = Some(call) // replaces any earlier callback
+  }
+
+  @volatile private var invalidatedBroadcastCallback: 
Option[Consumer[Integer]] =
+    None // invoked with each shuffleId on the shuffle-removal path
+  def registerInvalidatedBroadcastCallback(call: Consumer[Integer]): Unit = {
+    invalidatedBroadcastCallback = Some(call) // replaces any earlier callback
+  }
+
   def invalidateLatestMaxLocsCache(shuffleId: Int): Unit = {
     registerShuffleResponseRpcCache.invalidate(shuffleId)
   }
@@ -1889,4 +1906,19 @@ class LifecycleManager(val appUniqueId: String, val 
conf: CelebornConf) extends
     case _ =>
   }
 
+  def broadcastGetReducerFileGroupResponse(
+      shuffleId: Int,
+      response: GetReducerFileGroupResponse): Option[Array[Byte]] = {
+    // None when no callback is registered or when the callback yields null.
+    broadcastGetReducerFileGroupResponseCallback.flatMap { callback =>
+      Option(callback.apply(shuffleId, response))
+    }
+  }
+
+  private def invalidatedBroadcastGetReducerFileGroupResponse(shuffleId: Int): Unit = {
+    // No-op unless registerInvalidatedBroadcastCallback has installed a callback.
+    invalidatedBroadcastCallback.foreach { callback =>
+      callback.accept(shuffleId)
+    }
+  }
 }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 55639764c..98fe624fb 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -28,7 +28,7 @@ import scala.collection.mutable
 import com.google.common.cache.{Cache, CacheBuilder}
 import com.google.common.collect.Sets
 
-import org.apache.celeborn.client.{ClientUtils, ShuffleCommittedInfo, 
WorkerStatusTracker}
+import org.apache.celeborn.client.{ClientUtils, LifecycleManager, 
ShuffleCommittedInfo, WorkerStatusTracker}
 import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
 import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, 
ShuffleFailedWorkers}
 import org.apache.celeborn.common.CelebornConf
@@ -55,7 +55,8 @@ class ReducePartitionCommitHandler(
     shuffleAllocatedWorkers: ShuffleAllocatedWorkers,
     committedPartitionInfo: CommittedPartitionInfo,
     workerStatusTracker: WorkerStatusTracker,
-    sharedRpcPool: ThreadPoolExecutor)
+    sharedRpcPool: ThreadPoolExecutor,
+    lifecycleManager: LifecycleManager)
   extends CommitHandler(
     appUniqueId,
     conf,
@@ -78,6 +79,10 @@ class ReducePartitionCommitHandler(
   private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
   private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime
 
+  private val getReducerFileGroupResponseBroadcastEnabled = 
conf.getReducerFileGroupBroadcastEnabled // gates both the local (UT) and remote reply paths
+  private val getReducerFileGroupResponseBroadcastMiniSize =
+    conf.getReducerFileGroupBroadcastMiniSize // remote path only: broadcast when serialized size >= this
+
   // noinspection UnstableApiUsage
   private val getReducerFileGroupRpcCache: Cache[Int, ByteBuffer] = 
CacheBuilder.newBuilder()
     .concurrencyLevel(rpcCacheConcurrencyLevel)
@@ -320,10 +325,17 @@ class ReducePartitionCommitHandler(
     } else {
       // LocalNettyRpcCallContext is for the UTs
       if (context.isInstanceOf[LocalNettyRpcCallContext]) {
-        context.reply(GetReducerFileGroupResponse(
+        var response = GetReducerFileGroupResponse(
           StatusCode.SUCCESS,
           reducerFileGroupsMap.getOrDefault(shuffleId, 
JavaUtils.newConcurrentHashMap()),
-          getMapperAttempts(shuffleId)))
+          getMapperAttempts(shuffleId))
+
+        // only check whether broadcast enabled for the UTs
+        if (getReducerFileGroupResponseBroadcastEnabled) {
+          response = broadcastGetReducerFileGroup(shuffleId, response)
+        }
+
+        context.reply(response)
       } else {
         val cachedMsg = getReducerFileGroupRpcCache.get(
           shuffleId,
@@ -337,7 +349,27 @@ class ReducePartitionCommitHandler(
                   shufflePushFailedBatches.getOrDefault(
                     shuffleId,
                     new util.HashMap[String, util.Set[PushFailedBatch]]()))
-              
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
+
+              val serializedMsg =
+                
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
+
+              if (getReducerFileGroupResponseBroadcastEnabled &&
+                serializedMsg.capacity() >= 
getReducerFileGroupResponseBroadcastMiniSize) {
+                val broadcastMsg = broadcastGetReducerFileGroup(shuffleId, 
returnedMsg)
+                if (broadcastMsg != returnedMsg) {
+                  val serializedBroadcastMsg =
+                    
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(broadcastMsg)
+                  logInfo(s"Shuffle $shuffleId GetReducerFileGroupResponse 
size" +
+                    s" ${serializedMsg.capacity()} reached the broadcast 
threshold" +
+                    s" $getReducerFileGroupResponseBroadcastMiniSize," +
+                    s" the broadcast response size is 
${serializedBroadcastMsg.capacity()}.")
+                  serializedBroadcastMsg
+                } else {
+                  serializedMsg
+                }
+              } else {
+                serializedMsg
+              }
             }
           })
         
context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
@@ -345,6 +377,16 @@ class ReducePartitionCommitHandler(
     }
   }
 
+  private def broadcastGetReducerFileGroup(
+      shuffleId: Int,
+      response: GetReducerFileGroupResponse): GetReducerFileGroupResponse = {
+    // Swap in a slim response (status + broadcast bytes) when non-empty bytes come back.
+    val bytesOpt = lifecycleManager.broadcastGetReducerFileGroupResponse(shuffleId, response)
+    bytesOpt.filter(_.nonEmpty)
+      .map(bytes => GetReducerFileGroupResponse(response.status, broadcast = bytes))
+      .getOrElse(response)
+  }
+
   override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: 
Int): Unit = {
     // Quick return for ended stage, avoid occupy sync lock.
     if (isStageEnd(shuffleId)) {
diff --git 
a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java 
b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
index ade8b1cdb..a5076a59f 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
@@ -427,7 +427,8 @@ public class ShuffleClientSuiteJ {
                   locations,
                   new int[0],
                   Collections.emptySet(),
-                  Collections.emptyMap());
+                  Collections.emptyMap(),
+                  new byte[0]);
             });
 
     when(endpointRef.askSync(any(), any(), any(Integer.class), 
any(Long.class), any()))
@@ -439,7 +440,8 @@ public class ShuffleClientSuiteJ {
                   locations,
                   new int[0],
                   Collections.emptySet(),
-                  Collections.emptyMap());
+                  Collections.emptyMap(),
+                  new byte[0]);
             });
 
     shuffleClient =
@@ -482,7 +484,8 @@ public class ShuffleClientSuiteJ {
                   locations,
                   new int[0],
                   Collections.emptySet(),
-                  Collections.emptyMap());
+                  Collections.emptyMap(),
+                  new byte[0]);
             });
 
     when(endpointRef.askSync(any(), any(), any(Integer.class), 
any(Long.class), any()))
@@ -493,7 +496,8 @@ public class ShuffleClientSuiteJ {
                   locations,
                   new int[0],
                   Collections.emptySet(),
-                  Collections.emptyMap());
+                  Collections.emptyMap(),
+                  new byte[0]);
             });
 
     shuffleClient =
@@ -514,7 +518,8 @@ public class ShuffleClientSuiteJ {
                   locations,
                   new int[0],
                   Collections.emptySet(),
-                  Collections.emptyMap());
+                  Collections.emptyMap(),
+                  new byte[0]);
             });
 
     when(endpointRef.askSync(any(), any(), any(Integer.class), 
any(Long.class), any()))
@@ -525,7 +530,8 @@ public class ShuffleClientSuiteJ {
                   locations,
                   new int[0],
                   Collections.emptySet(),
-                  Collections.emptyMap());
+                  Collections.emptyMap(),
+                  new byte[0]);
             });
 
     shuffleClient =
@@ -546,7 +552,8 @@ public class ShuffleClientSuiteJ {
                   locations,
                   new int[0],
                   Collections.emptySet(),
-                  Collections.emptyMap());
+                  Collections.emptyMap(),
+                  new byte[0]);
             });
 
     when(endpointRef.askSync(any(), any(), any(Integer.class), 
any(Long.class), any()))
@@ -557,7 +564,8 @@ public class ShuffleClientSuiteJ {
                   locations,
                   new int[0],
                   Collections.emptySet(),
-                  Collections.emptyMap());
+                  Collections.emptyMap(),
+                  new byte[0]);
             });
 
     shuffleClient =
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 553e95aa7..acf355756 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -390,6 +390,8 @@ message PbGetReducerFileGroupResponse {
   repeated int32 partitionIds = 4;
 
   map<string, PbPushFailedBatchSet> pushFailedBatches = 5;
+
+  bytes broadcast = 6;
 }
 
 message PbGetShuffleId {
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index f132cd189..b7698d4c4 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1057,6 +1057,10 @@ class CelebornConf(loadDefaults: Boolean) extends 
Cloneable with Logging with Se
     get(CLIENT_PUSH_DYNAMIC_WRITE_MODE_ENABLED)
   def dynamicWriteModePartitionNumThreshold =
     get(CLIENT_PUSH_DYNAMIC_WRITE_MODE_PARTITION_NUM_THRESHOLD)
+  def getReducerFileGroupBroadcastEnabled: Boolean =
+    get(CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED) // opt-in, defaults to false
+  def getReducerFileGroupBroadcastMiniSize: Long =
+    get(CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE) // in bytes, defaults to 512k
   def shufflePartitionType: PartitionType = 
PartitionType.valueOf(get(SHUFFLE_PARTITION_TYPE))
   def shuffleRangeReadFilterEnabled: Boolean = 
get(SHUFFLE_RANGE_READ_FILTER_ENABLED)
   def shuffleForceFallbackEnabled: Boolean = 
get(SPARK_SHUFFLE_FORCE_FALLBACK_ENABLED)
@@ -5213,6 +5217,27 @@ object CelebornConf extends Logging {
       .intConf
       .createWithDefault(2000)
 
+  val CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED =
+    
buildConf("celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled")
+      .categories("client")
+      .doc(
+        "Whether to leverage Spark broadcast mechanism to send the 
GetReducerFileGroupResponse. " +
+          "If the response size is large and Spark executor number is large, 
the Spark driver network " +
+          "may be exhausted because each executor will pull the response from 
the driver. With broadcasting " +
+          "GetReducerFileGroupResponse, it prevents the driver from being the 
bottleneck in sending out multiple " +
+          "copies of the GetReducerFileGroupResponse (one per executor).")
+      .version("0.6.0")
+      .booleanConf
+      .createWithDefault(false) // opt-in: while false, no reply path attempts a broadcast
+
+  val CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE =
+    
buildConf("celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize")
+      .categories("client")
+      .doc("The size at which we use Broadcast to send the 
GetReducerFileGroupResponse to the executors.")
+      .version("0.6.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("512k") // remote replies are broadcast when serialized size >= this
+
   val SPARK_SHUFFLE_WRITER_MODE: ConfigEntry[String] =
     buildConf("celeborn.client.spark.shuffle.writer")
       .withAlternative("celeborn.shuffle.writer")
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index e730b9b7a..949b13322 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -22,6 +22,7 @@ import java.util.{Collections, UUID}
 
 import scala.collection.JavaConverters._
 
+import com.google.protobuf.ByteString
 import org.roaringbitmap.RoaringBitmap
 
 import org.apache.celeborn.common.identity.UserIdentifier
@@ -285,10 +286,11 @@ object ControlMessages extends Logging {
   // Path can't be serialized
   case class GetReducerFileGroupResponse(
       status: StatusCode,
-      fileGroup: util.Map[Integer, util.Set[PartitionLocation]],
-      attempts: Array[Int],
+      fileGroup: util.Map[Integer, util.Set[PartitionLocation]] = 
Collections.emptyMap(),
+      attempts: Array[Int] = Array.emptyIntArray,
       partitionIds: util.Set[Integer] = Collections.emptySet[Integer](),
-      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = 
Collections.emptyMap())
+      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = 
Collections.emptyMap(),
+      broadcast: Array[Byte] = Array.emptyByteArray) // if non-empty, the client rebuilds the full response from these bytes
     extends MasterMessage
 
   object WorkerExclude {
@@ -752,7 +754,13 @@ object ControlMessages extends Logging {
         .build().toByteArray
       new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload)
 
-    case GetReducerFileGroupResponse(status, fileGroup, attempts, 
partitionIds, failedBatches) =>
+    case GetReducerFileGroupResponse(
+          status,
+          fileGroup,
+          attempts,
+          partitionIds,
+          failedBatches,
+          broadcast) =>
       val builder = PbGetReducerFileGroupResponse
         .newBuilder()
         .setStatus(status.getValue)
@@ -770,6 +778,7 @@ object ControlMessages extends Logging {
           case (uniqueId, pushFailedBatchSet) =>
             (uniqueId, PbSerDeUtils.toPbPushFailedBatchSet(pushFailedBatchSet))
         }.asJava)
+      builder.setBroadcast(ByteString.copyFrom(broadcast))
       val payload = builder.build().toByteArray
       new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, 
payload)
 
@@ -1198,12 +1207,14 @@ object ControlMessages extends Logging {
           case (uniqueId, pushFailedBatchSet) =>
             (uniqueId, 
PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
         }.toMap.asJava
+        val broadcast = pbGetReducerFileGroupResponse.getBroadcast.toByteArray
         GetReducerFileGroupResponse(
           Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus),
           fileGroup,
           attempts,
           partitionIds,
-          pushFailedBatches)
+          pushFailedBatches,
+          broadcast)
 
       case GET_SHUFFLE_ID_VALUE =>
         message.getParsedPayload()
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala 
b/common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
new file mode 100644
index 000000000..97f7d2c65
--- /dev/null
+++ b/common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.util
+
+import java.util.concurrent.{Callable, ConcurrentHashMap}
+
+/**
+ * This class is copied from Apache Spark.
+ * A special locking mechanism to provide locking with a given key. By 
providing the same key
+ * (identity is tested using the `equals` method), we ensure there is only one 
`func` running at
+ * the same time.
+ *
+ * @tparam K the type of key to identify a lock. This type must implement 
`equals` and `hashCode`
+ *           correctly as it will be the key type of an internal Map.
+ */
+class KeyLock[K] {
+
+  private val lockMap = new ConcurrentHashMap[K, AnyRef]()  // key -> monitor installed by the current holder
+
+  private def acquireLock(key: K): Unit = {  // blocks until this thread's monitor is installed for `key`
+    while (true) {
+      val lock = lockMap.putIfAbsent(key, new Object)
+      if (lock == null) return  // putIfAbsent installed our monitor: lock acquired
+      lock.synchronized {  // another thread holds `key`: wait on its monitor
+        while (lockMap.get(key) eq lock) {  // re-check guards against spurious wakeups
+          lock.wait()
+        }
+      }
+    }
+  }
+
+  private def releaseLock(key: K): Unit = {  // only called by withLock after acquireLock succeeded
+    val lock = lockMap.remove(key)
+    lock.synchronized {
+      lock.notifyAll()  // wake every waiter; each retries putIfAbsent
+    }
+  }
+
+  /**
+   * Run `func` under a lock identified by the given key. Multiple calls with 
the same key
+   * (identity is tested using the `equals` method) will be locked properly to 
ensure there is only
+   * one `func` running at the same time.
+   */
+  def withLock[T](key: K)(func: Callable[T]): T = {
+    if (key == null) {
+      throw new NullPointerException("key must not be null")
+    }
+    acquireLock(key)
+    try {
+      func.call()
+    } finally {
+      releaseLock(key)
+    }
+  }
+}
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index f42994cb3..5b3a25308 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -122,6 +122,8 @@ license: |
 | celeborn.client.spark.shuffle.fallback.numPartitionsThreshold | 2147483647 | 
false | Celeborn will only accept shuffle of partition number lower than this 
configuration value. This configuration only takes effect when 
`celeborn.client.spark.shuffle.fallback.policy` is `AUTO`. | 0.5.0 | 
celeborn.shuffle.forceFallback.numPartitionsThreshold,celeborn.client.spark.shuffle.forceFallback.numPartitionsThreshold
 | 
 | celeborn.client.spark.shuffle.fallback.policy | AUTO | false | Celeborn 
supports the following kind of fallback policies. 1. ALWAYS: always use spark 
built-in shuffle implementation; 2. AUTO: prefer to use celeborn shuffle 
implementation, and fallback to use spark built-in shuffle implementation based 
on certain factors, e.g. availability of enough workers and quota, shuffle 
partition number; 3. NEVER: always use celeborn shuffle implementation, and 
fail fast when it it is concluded th [...]
 | celeborn.client.spark.shuffle.forceFallback.enabled | false | false | Always 
use spark built-in shuffle implementation. This configuration is deprecated, 
consider configuring `celeborn.client.spark.shuffle.fallback.policy` instead. | 
0.3.0 | celeborn.shuffle.forceFallback.enabled | 
+| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled | false 
| false | Whether to leverage Spark broadcast mechanism to send the 
GetReducerFileGroupResponse. If the response size is large and Spark executor 
number is large, the Spark driver network may be exhausted because each 
executor will pull the response from the driver. With broadcasting 
GetReducerFileGroupResponse, it prevents the driver from being the bottleneck 
in sending out multiple copies of the GetReducerFil [...]
+| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize | 512k 
| false | The size at which we use Broadcast to send the 
GetReducerFileGroupResponse to the executors. | 0.6.0 |  | 
 | celeborn.client.spark.shuffle.writer | HASH | false | Celeborn supports the 
following kind of shuffle writers. 1. hash: hash-based shuffle writer works 
fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer 
works fine when memory pressure is high or shuffle partition count is huge. 
This configuration only takes effect when 
celeborn.client.spark.push.dynamicWriteMode.enabled is false. | 0.3.0 | 
celeborn.shuffle.writer | 
 | celeborn.client.spark.stageRerun.enabled | true | false | Whether to enable 
stage rerun. If true, client throws FetchFailedException instead of 
CelebornIOException. | 0.4.0 | celeborn.client.spark.fetch.throwsFetchFailure | 
 | celeborn.identity.provider | 
org.apache.celeborn.common.identity.DefaultIdentityProvider | false | 
IdentityProvider class name. Default class is 
`org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: 
org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will 
be obtained by UserGroupInformation.getUserName; 
org.apache.celeborn.common.identity.DefaultIdentityProvider user name and 
tenant id are default values or user-specific values. | 0.6.0 | [...]
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala
index 8912217ce..afc2956ac 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala
@@ -18,11 +18,13 @@
 package org.apache.celeborn.tests.spark
 
 import org.apache.spark.SparkConf
+import org.apache.spark.shuffle.celeborn.SparkUtils
 import org.apache.spark.sql.SparkSession
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.protocol.ShuffleMode
 
 class CelebornHashSuite extends AnyFunSuite
@@ -64,4 +66,43 @@ class CelebornHashSuite extends AnyFunSuite
 
     celebornSparkSession.stop()
   }
+
+  test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") {
+    SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+    SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
+    val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
+      .set(
+        s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED.key}",
+        "true")
+      .set(
+        s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE.key}",
+        "0")
+    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+    val combineResult = combine(sparkSession)
+    val groupbyResult = groupBy(sparkSession)
+    val repartitionResult = repartition(sparkSession)
+    val sqlResult = runsql(sparkSession)
+    Thread.sleep(3000L)
+    sparkSession.stop()
+    // miniSize = 0 means every serialized response meets the broadcast threshold.
+    val celebornSparkSession = SparkSession.builder()
+      .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+      .getOrCreate()
+    // Always stop the session and reset the shared broadcast counters, even if an assert fails.
+    try {
+      val celebornCombineResult = combine(celebornSparkSession)
+      val celebornGroupbyResult = groupBy(celebornSparkSession)
+      val celebornRepartitionResult = repartition(celebornSparkSession)
+      val celebornSqlResult = runsql(celebornSparkSession)
+      assert(combineResult.equals(celebornCombineResult))
+      assert(groupbyResult.equals(celebornGroupbyResult))
+      assert(repartitionResult.equals(celebornRepartitionResult))
+      assert(sqlResult.equals(celebornSqlResult))
+      assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)
+    } finally {
+      celebornSparkSession.stop()
+      SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+      SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
+    }
+  }
 }
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala
index e4b2dd574..e4f6cc93b 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.celeborn.tests.spark
 
 import org.apache.spark.SparkConf
+import org.apache.spark.shuffle.celeborn.SparkUtils
 import org.apache.spark.sql.SparkSession
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
@@ -66,4 +67,45 @@ class CelebornSortSuite extends AnyFunSuite
 
     celebornSparkSession.stop()
   }
+
+  test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") {
+    SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+    SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
+    val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_SORT_RANDOMIZE_PARTITION_ENABLED.key}", "false")
+      .set(
+        s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED.key}",
+        "true")
+      .set(
+        s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE.key}",
+        "0")
+
+    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+    val combineResult = combine(sparkSession)
+    val groupbyResult = groupBy(sparkSession)
+    val repartitionResult = repartition(sparkSession)
+    val sqlResult = runsql(sparkSession)
+    Thread.sleep(3000L)
+    sparkSession.stop()
+    // miniSize = 0 means every serialized response meets the broadcast threshold.
+    val celebornSparkSession = SparkSession.builder()
+      .config(updateSparkConf(sparkConf, ShuffleMode.SORT))
+      .getOrCreate()
+    // Always stop the session and reset the shared broadcast counters, even if an assert fails.
+    try {
+      val celebornCombineResult = combine(celebornSparkSession)
+      val celebornGroupbyResult = groupBy(celebornSparkSession)
+      val celebornRepartitionResult = repartition(celebornSparkSession)
+      val celebornSqlResult = runsql(celebornSparkSession)
+      assert(combineResult.equals(celebornCombineResult))
+      assert(groupbyResult.equals(celebornGroupbyResult))
+      assert(repartitionResult.equals(celebornRepartitionResult))
+      assert(sqlResult.equals(celebornSqlResult))
+      assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)
+    } finally {
+      celebornSparkSession.stop()
+      SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+      SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
+    }
+  }
 }
diff --git 
a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
index 6b4bc13b8..83a5e12f6 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
@@ -29,7 +29,9 @@ import org.scalatest.funsuite.AnyFunSuite
 import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
 
 import org.apache.celeborn.client.ShuffleClient
-import org.apache.celeborn.common.protocol.ShuffleMode
+import org.apache.celeborn.common.protocol.{PartitionLocation, ShuffleMode}
+import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
+import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.tests.spark.SparkTestBase
 
 class SparkUtilsSuite extends AnyFunSuite
@@ -157,4 +159,61 @@ class SparkUtilsSuite extends AnyFunSuite
       sparkSession.stop()
     }
   }
+
+  test("serialize/deserialize GetReducerFileGroupResponse with broadcast") {
+    val sparkConf = new 
SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+    val sparkSession = SparkSession.builder()
+      .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+      .config("spark.sql.shuffle.partitions", 2)
+      .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+      .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
+      .config(
+        "spark.shuffle.manager",
+        "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+      .getOrCreate()
+
+    try {
+      val shuffleId = Integer.MAX_VALUE // NOTE(review): presumably chosen to avoid clashing with real shuffle ids
+      val getReducerFileGroupResponse = GetReducerFileGroupResponse(
+        StatusCode.SUCCESS,
+        Map(Integer.valueOf(shuffleId) -> Set(new PartitionLocation(
+          0,
+          1,
+          "localhost",
+          1,
+          2,
+          3,
+          4,
+          PartitionLocation.Mode.REPLICA)).asJava).asJava,
+        Array(1),
+        Set(Integer.valueOf(shuffleId)).asJava)
+
+      val serializedBytes =
+        SparkUtils.serializeGetReducerFileGroupResponse(shuffleId, 
getReducerFileGroupResponse)
+      assert(serializedBytes != null && serializedBytes.length > 0)
+      val broadcast = 
SparkUtils.getReducerFileGroupResponseBroadcasts.values().asScala.head._1 // presumably the broadcast created by the serialize call above
+      assert(broadcast.isValid)
+
+      val deserializedGetReducerFileGroupResponse =
+        SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, 
serializedBytes)
+      assert(deserializedGetReducerFileGroupResponse.status == 
getReducerFileGroupResponse.status)
+      assert(
+        deserializedGetReducerFileGroupResponse.fileGroup == 
getReducerFileGroupResponse.fileGroup)
+      assert(java.util.Arrays.equals(
+        deserializedGetReducerFileGroupResponse.attempts,
+        getReducerFileGroupResponse.attempts))
+      assert(deserializedGetReducerFileGroupResponse.partitionIds == 
getReducerFileGroupResponse.partitionIds)
+      assert(
+        deserializedGetReducerFileGroupResponse.pushFailedBatches == 
getReducerFileGroupResponse.pushFailedBatches)
+
+      assert(!SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
+      SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId) // expected to drop the entry and invalidate the broadcast (asserted below)
+      assert(SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
+      assert(!broadcast.isValid)
+    } finally {
+      sparkSession.stop()
+      SparkUtils.getReducerFileGroupResponseBroadcasts.clear() // reset shared static state for later tests
+      SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
+    }
+  }
 }

Reply via email to