This is an automated email from the ASF dual-hosted git repository.
feiwang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 5e12b7d60 [CELEBORN-1921] Broadcast large GetReducerFileGroupResponse to prevent Spark driver network exhaustion
5e12b7d60 is described below
commit 5e12b7d6075787d2f70382674b5b552c31aa9080
Author: Wang, Fei <[email protected]>
AuthorDate: Tue Apr 1 08:29:21 2025 -0700
[CELEBORN-1921] Broadcast large GetReducerFileGroupResponse to prevent Spark driver network exhaustion
### What changes were proposed in this pull request?
For a Spark Celeborn application, if the GetReducerFileGroupResponse is
larger than the threshold, the Spark driver broadcasts the
GetReducerFileGroupResponse to the executors. This prevents the driver from
being the bottleneck in sending out multiple copies of the
GetReducerFileGroupResponse (one per executor).
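Below is a minimal, self-contained Scala sketch of the idea. The names here (`BroadcastRoundTripSketch`, `Response`, the in-memory `broadcastStore`) are hypothetical stand-ins for Spark's Broadcast machinery; the actual patch goes through `SparkContext.broadcast` and the callbacks registered in `SparkShuffleManager` (see the diff below):
```scala
// Hypothetical simplification: above miniSize, the driver replies with a tiny
// message carrying only a broadcast handle instead of the full payload.
object BroadcastRoundTripSketch {
  final case class Response(payload: Array[Byte], broadcastHandle: Option[Long])

  // Stand-in for Spark's broadcast store: executors fetch by handle instead of
  // pulling the full payload from the driver over the RPC channel.
  private val broadcastStore = scala.collection.concurrent.TrieMap.empty[Long, Array[Byte]]
  private var nextHandle = 0L

  def driverReply(full: Array[Byte], miniSize: Int): Response =
    if (full.length >= miniSize) {
      val handle = synchronized { nextHandle += 1; nextHandle }
      broadcastStore.put(handle, full)
      Response(Array.emptyByteArray, Some(handle)) // tiny RPC reply
    } else Response(full, None) // small enough: send inline

  def executorRead(r: Response): Array[Byte] =
    r.broadcastHandle.map(broadcastStore).getOrElse(r.payload)

  def main(args: Array[String]): Unit = {
    val big = Array.fill[Byte](1 << 20)(1)
    val reply = driverReply(big, miniSize = 512 * 1024)
    assert(reply.payload.isEmpty && reply.broadcastHandle.isDefined)
    assert(executorRead(reply).sameElements(big))
    println("broadcast round trip ok")
  }
}
```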
### Why are the changes needed?
To prevent the driver from being the bottleneck in sending out multiple
copies of the GetReducerFileGroupResponse (one per executor).
### Does this PR introduce _any_ user-facing change?
No, the feature is not enabled by default.
### How was this patch tested?
UT.
Cluster testing with
`spark.celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled=true`.
The broadcast response size should always be about 1 KB.
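For reference, this is how the integration tests in this patch enable the feature (illustrative; the tests lower `miniSize` to 0 so even small responses take the broadcast path):
```scala
val sparkConf = new org.apache.spark.SparkConf()
  .set("spark.celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled", "true")
  .set("spark.celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize", "0")
```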


The application succeeded.

Closes #3158 from turboFei/broadcast_rgf.
Authored-by: Wang, Fei <[email protected]>
Signed-off-by: Wang, Fei <[email protected]>
---
LICENSE | 1 +
client-spark/spark-2-shaded/pom.xml | 1 +
.../shuffle/celeborn/SparkShuffleManager.java | 9 ++
.../apache/spark/shuffle/celeborn/SparkUtils.java | 129 ++++++++++++++++++++
.../shuffle/celeborn/CelebornShuffleReader.scala | 11 ++
client-spark/spark-3-shaded/pom.xml | 1 +
.../shuffle/celeborn/SparkShuffleManager.java | 9 ++
.../apache/spark/shuffle/celeborn/SparkUtils.java | 130 +++++++++++++++++++++
.../shuffle/celeborn/CelebornShuffleReader.scala | 13 ++-
.../org/apache/celeborn/client/ShuffleClient.java | 24 ++++
.../apache/celeborn/client/ShuffleClientImpl.java | 8 ++
.../org/apache/celeborn/client/CommitManager.scala | 3 +-
.../apache/celeborn/client/LifecycleManager.scala | 32 +++++
.../commit/ReducePartitionCommitHandler.scala | 52 ++++++++-
.../celeborn/client/ShuffleClientSuiteJ.java | 24 ++--
common/src/main/proto/TransportMessages.proto | 2 +
.../org/apache/celeborn/common/CelebornConf.scala | 25 ++++
.../common/protocol/message/ControlMessages.scala | 21 +++-
.../org/apache/celeborn/common/util/KeyLock.scala | 70 +++++++++++
docs/configuration/client.md | 2 +
.../celeborn/tests/spark/CelebornHashSuite.scala | 41 +++++++
.../celeborn/tests/spark/CelebornSortSuite.scala | 42 +++++++
.../spark/shuffle/celeborn/SparkUtilsSuite.scala | 61 +++++++++-
23 files changed, 690 insertions(+), 21 deletions(-)
diff --git a/LICENSE b/LICENSE
index 402c728a7..f757490f9 100644
--- a/LICENSE
+++ b/LICENSE
@@ -223,6 +223,7 @@ Apache Spark
./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java
./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
+./common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DBIterator.java
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/LevelDB.java
diff --git a/client-spark/spark-2-shaded/pom.xml b/client-spark/spark-2-shaded/pom.xml
index 3551b1caa..cfbc02079 100644
--- a/client-spark/spark-2-shaded/pom.xml
+++ b/client-spark/spark-2-shaded/pom.xml
@@ -73,6 +73,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
+ <include>commons-io:commons-io</include>
</includes>
</artifactSet>
<filters>
diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 67f29eced..41fe777b7 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -109,6 +109,15 @@ public class SparkShuffleManager implements ShuffleManager {
       lifecycleManager.registerShuffleTrackerCallback(
           shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
}
+
+      if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
+        lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
+            (shuffleId, getReducerFileGroupResponse) ->
+                SparkUtils.serializeGetReducerFileGroupResponse(
+                    shuffleId, getReducerFileGroupResponse));
+        lifecycleManager.registerInvalidatedBroadcastCallback(
+            shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
+      }
}
}
}
diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 7ac0f6583..c708c1e84 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -17,7 +17,10 @@
package org.apache.spark.shuffle.celeborn;
+import java.io.ByteArrayInputStream;
import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashSet;
@@ -25,6 +28,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
@@ -37,7 +41,12 @@ import org.apache.spark.BarrierTaskContext;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkContext$;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.SparkEnv$;
import org.apache.spark.TaskContext;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
@@ -54,7 +63,10 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.util.KeyLock;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynFields;
@@ -346,4 +358,121 @@ public class SparkUtils {
sparkContext.addSparkListener(listener);
}
}
+
+  /**
+   * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the
+   * broadcast belonging to the shuffle id at a time.
+   */
+  private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock<>();
+
+  @VisibleForTesting
+  public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
+
+  @VisibleForTesting
+  public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>>
+      getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
+
+  public static byte[] serializeGetReducerFileGroupResponse(
+      Integer shuffleId, GetReducerFileGroupResponse response) {
+    SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
+    if (sparkContext == null) {
+      logger.error("Can not get active SparkContext.");
+      return null;
+    }
+
+    return shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
+              getReducerFileGroupResponseBroadcasts.get(shuffleId);
+          if (cachedSerializeGetReducerFileGroupResponse != null) {
+            return cachedSerializeGetReducerFileGroupResponse._2;
+          }
+
+          try {
+            logger.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId);
+            TransportMessage transportMessage =
+                (TransportMessage) Utils.toTransportMessage(response);
+            Broadcast<TransportMessage> broadcast =
+                sparkContext.broadcast(
+                    transportMessage,
+                    scala.reflect.ClassManifestFactory.fromClass(TransportMessage.class));
+
+            CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
+            // Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
+            // one. This implementation doesn't reallocate the whole memory block but allocates
+            // additional buffers. This way no buffers need to be garbage collected and
+            // the contents don't have to be copied to the new buffer.
+            org.apache.commons.io.output.ByteArrayOutputStream out =
+                new org.apache.commons.io.output.ByteArrayOutputStream();
+            try (ObjectOutputStream oos =
+                new ObjectOutputStream(codec.compressedOutputStream(out))) {
+              oos.writeObject(broadcast);
+            }
+            byte[] _serializeResult = out.toByteArray();
+            getReducerFileGroupResponseBroadcasts.put(
+                shuffleId, Tuple2.apply(broadcast, _serializeResult));
+            getReducerFileGroupResponseBroadcastNum.incrementAndGet();
+            return _serializeResult;
+          } catch (Throwable e) {
+            logger.error(
+                "Failed to serialize GetReducerFileGroupResponse for shuffle: {}", shuffleId, e);
+            return null;
+          }
+        });
+  }
+
+  public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
+      Integer shuffleId, byte[] bytes) {
+    SparkEnv sparkEnv = SparkEnv$.MODULE$.get();
+    if (sparkEnv == null) {
+      logger.error("Can not get SparkEnv.");
+      return null;
+    }
+
+    return shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          GetReducerFileGroupResponse response = null;
+          logger.info(
+              "Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
+
+          try {
+            CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkEnv.conf());
+            try (ObjectInputStream objIn =
+                new ObjectInputStream(
+                    codec.compressedInputStream(new ByteArrayInputStream(bytes)))) {
+              Broadcast<TransportMessage> broadcast =
+                  (Broadcast<TransportMessage>) objIn.readObject();
+              response =
+                  (GetReducerFileGroupResponse) Utils.fromTransportMessage(broadcast.value());
+            }
+          } catch (Throwable e) {
+            logger.error(
+                "Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
+          }
+          return response;
+        });
+  }
+
+  public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
+    shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          try {
+            Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
+                getReducerFileGroupResponseBroadcasts.remove(shuffleId);
+            if (cachedSerializeGetReducerFileGroupResponse != null) {
+              cachedSerializeGetReducerFileGroupResponse._1().destroy();
+            }
+          } catch (Throwable e) {
+            logger.error(
+                "Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
+                    + shuffleId,
+                e);
+          }
+          return null;
+        });
+  }
}
diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index a60c68f3d..2269aaf68 100644
--- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
import java.io.IOException
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
+import java.util.function.BiFunction
import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
import org.apache.spark.internal.Logging
@@ -33,6 +34,7 @@ import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
+import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
class CelebornShuffleReader[K, C](
@@ -254,4 +256,13 @@ class CelebornShuffleReader[K, C](
object CelebornShuffleReader {
var streamCreatorPool: ThreadPoolExecutor = null
+  // Register the deserializer for GetReducerFileGroupResponse broadcast
+  ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new BiFunction[
+    Integer,
+    Array[Byte],
+    GetReducerFileGroupResponse] {
+    override def apply(shuffleId: Integer, broadcast: Array[Byte]): GetReducerFileGroupResponse = {
+      SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
+    }
+  })
}
diff --git a/client-spark/spark-3-shaded/pom.xml b/client-spark/spark-3-shaded/pom.xml
index d3d59cb87..8cce8577a 100644
--- a/client-spark/spark-3-shaded/pom.xml
+++ b/client-spark/spark-3-shaded/pom.xml
@@ -77,6 +77,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
+ <include>commons-io:commons-io</include>
</includes>
</artifactSet>
<filters>
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 234fba1fe..a3e75cd10 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -156,6 +156,15 @@ public class SparkShuffleManager implements ShuffleManager {
SparkUtils::isCelebornSkewShuffleOrChildShuffle);
}
}
+
+      if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
+        lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
+            (shuffleId, getReducerFileGroupResponse) ->
+                SparkUtils.serializeGetReducerFileGroupResponse(
+                    shuffleId, getReducerFileGroupResponse));
+        lifecycleManager.registerInvalidatedBroadcastCallback(
+            shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
+      }
}
}
}
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 6443c2163..4efc29a90 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -17,11 +17,15 @@
package org.apache.spark.shuffle.celeborn;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
@@ -35,7 +39,12 @@ import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkContext$;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.SparkEnv$;
import org.apache.spark.TaskContext;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
@@ -57,7 +66,11 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
+import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.common.util.KeyLock;
+import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynConstructors;
import org.apache.celeborn.reflect.DynFields;
import org.apache.celeborn.reflect.DynMethods;
@@ -476,4 +489,121 @@ public class SparkUtils {
return false;
}
}
+
+  /**
+   * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the
+   * broadcast belonging to the shuffle id at a time.
+   */
+  private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock<>();
+
+  @VisibleForTesting
+  public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
+
+  @VisibleForTesting
+  public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>>
+      getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
+
+  public static byte[] serializeGetReducerFileGroupResponse(
+      Integer shuffleId, GetReducerFileGroupResponse response) {
+    SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
+    if (sparkContext == null) {
+      LOG.error("Can not get active SparkContext.");
+      return null;
+    }
+
+    return shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
+              getReducerFileGroupResponseBroadcasts.get(shuffleId);
+          if (cachedSerializeGetReducerFileGroupResponse != null) {
+            return cachedSerializeGetReducerFileGroupResponse._2;
+          }
+
+          try {
+            LOG.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId);
+            TransportMessage transportMessage =
+                (TransportMessage) Utils.toTransportMessage(response);
+            Broadcast<TransportMessage> broadcast =
+                sparkContext.broadcast(
+                    transportMessage,
+                    scala.reflect.ClassManifestFactory.fromClass(TransportMessage.class));
+
+            CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
+            // Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
+            // one. This implementation doesn't reallocate the whole memory block but allocates
+            // additional buffers. This way no buffers need to be garbage collected and
+            // the contents don't have to be copied to the new buffer.
+            org.apache.commons.io.output.ByteArrayOutputStream out =
+                new org.apache.commons.io.output.ByteArrayOutputStream();
+            try (ObjectOutputStream oos =
+                new ObjectOutputStream(codec.compressedOutputStream(out))) {
+              oos.writeObject(broadcast);
+            }
+            byte[] _serializeResult = out.toByteArray();
+            getReducerFileGroupResponseBroadcasts.put(
+                shuffleId, Tuple2.apply(broadcast, _serializeResult));
+            getReducerFileGroupResponseBroadcastNum.incrementAndGet();
+            return _serializeResult;
+          } catch (Throwable e) {
+            LOG.error(
+                "Failed to serialize GetReducerFileGroupResponse for shuffle: {}", shuffleId, e);
+            return null;
+          }
+        });
+  }
+
+  public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
+      Integer shuffleId, byte[] bytes) {
+    SparkEnv sparkEnv = SparkEnv$.MODULE$.get();
+    if (sparkEnv == null) {
+      LOG.error("Can not get SparkEnv.");
+      return null;
+    }
+
+    return shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          GetReducerFileGroupResponse response = null;
+          LOG.info(
+              "Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
+
+          try {
+            CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkEnv.conf());
+            try (ObjectInputStream objIn =
+                new ObjectInputStream(
+                    codec.compressedInputStream(new ByteArrayInputStream(bytes)))) {
+              Broadcast<TransportMessage> broadcast =
+                  (Broadcast<TransportMessage>) objIn.readObject();
+              response =
+                  (GetReducerFileGroupResponse) Utils.fromTransportMessage(broadcast.value());
+            }
+          } catch (Throwable e) {
+            LOG.error(
+                "Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
+          }
+          return response;
+        });
+  }
+
+  public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
+    shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          try {
+            Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
+                getReducerFileGroupResponseBroadcasts.remove(shuffleId);
+            if (cachedSerializeGetReducerFileGroupResponse != null) {
+              cachedSerializeGetReducerFileGroupResponse._1().destroy();
+            }
+          } catch (Throwable e) {
+            LOG.error(
+                "Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
+                    + shuffleId,
+                e);
+          }
+          return null;
+        });
+  }
}
diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index af29b57dd..3e296b310 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -21,6 +21,7 @@ import java.io.IOException
 import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet}
 import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
+import java.util.function.BiFunction
import scala.collection.JavaConverters._
@@ -43,7 +44,8 @@ import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRet
 import org.apache.celeborn.common.network.client.TransportClient
 import org.apache.celeborn.common.network.protocol.TransportMessage
 import org.apache.celeborn.common.protocol._
-import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
+import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils}
class CelebornShuffleReader[K, C](
@@ -465,4 +467,13 @@ class CelebornShuffleReader[K, C](
object CelebornShuffleReader {
var streamCreatorPool: ThreadPoolExecutor = null
+  // Register the deserializer for GetReducerFileGroupResponse broadcast
+  ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new BiFunction[
+    Integer,
+    Array[Byte],
+    GetReducerFileGroupResponse] {
+    override def apply(shuffleId: Integer, broadcast: Array[Byte]): GetReducerFileGroupResponse = {
+      SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
+    }
+  })
}
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 8690f5ceb..dde2b36c4 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -20,9 +20,11 @@ package org.apache.celeborn.client;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
+import java.util.function.BiFunction;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.fs.FileSystem;
@@ -38,6 +40,7 @@ import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.protocol.StorageInfo;
+import org.apache.celeborn.common.protocol.message.ControlMessages;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.util.CelebornHadoopUtils;
import org.apache.celeborn.common.util.ExceptionMaker;
@@ -56,6 +59,10 @@ public abstract class ShuffleClient {
private static LongAdder totalReadCounter = new LongAdder();
private static LongAdder localShuffleReadCounter = new LongAdder();
+  private static volatile Optional<
+          BiFunction<Integer, byte[], ControlMessages.GetReducerFileGroupResponse>>
+      deserializeReducerFileGroupResponseFunction = Optional.empty();
+
// for testing
public static void reset() {
_instance = null;
@@ -297,4 +304,21 @@ public abstract class ShuffleClient {
public abstract TransportClientFactory getDataClientFactory();
   public abstract void excludeFailedFetchLocation(String hostAndFetchPort, Exception e);
+
+  public static void registerDeserializeReducerFileGroupResponseFunction(
+      BiFunction<Integer, byte[], ControlMessages.GetReducerFileGroupResponse> function) {
+    if (!deserializeReducerFileGroupResponseFunction.isPresent()) {
+      deserializeReducerFileGroupResponseFunction = Optional.ofNullable(function);
+    }
+  }
+
+  public static ControlMessages.GetReducerFileGroupResponse deserializeReducerFileGroupResponse(
+      int shuffleId, byte[] bytes) {
+    if (!deserializeReducerFileGroupResponseFunction.isPresent()) {
+      // Should never happen
+      logger.warn("DeserializeReducerFileGroupResponseFunction is not registered.");
+      return null;
+    }
+    return deserializeReducerFileGroupResponseFunction.get().apply(shuffleId, bytes);
+  }
}
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index be6b5aa3f..e7f4a9083 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -1814,6 +1814,14 @@ public class ShuffleClientImpl extends ShuffleClient {
ClassTag$.MODULE$.apply(GetReducerFileGroupResponse.class));
switch (response.status()) {
case SUCCESS:
+          if (response.broadcast() != null && response.broadcast().length > 0) {
+            response =
+                ShuffleClient.deserializeReducerFileGroupResponse(shuffleId, response.broadcast());
+            if (response == null) {
+              throw new CelebornIOException(
+                  "Failed to get GetReducerFileGroupResponse broadcast for shuffle: " + shuffleId);
+            }
+          }
           logger.info(
               "Shuffle {} request reducer file group success using {} ms, result partition size {}.",
               shuffleId,
diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index aaeb4462f..ad0d66e3a 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -292,7 +292,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
lifecycleManager.shuffleAllocatedWorkers,
committedPartitionInfo,
lifecycleManager.workerStatusTracker,
- lifecycleManager.rpcSharedThreadPool)
+ lifecycleManager.rpcSharedThreadPool,
+ lifecycleManager)
case PartitionType.MAP => new MapPartitionCommitHandler(
appUniqueId,
conf,
diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index f706eeb90..35d1945e1 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -1677,6 +1677,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
} else {
batchRemoveShuffleIds += shuffleId
}
+ invalidatedBroadcastGetReducerFileGroupResponse(shuffleId)
}
}
if (batchRemoveShuffleIds.nonEmpty) {
@@ -1848,6 +1849,22 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
cancelShuffleCallback = Some(callback)
}
+  @volatile private var broadcastGetReducerFileGroupResponseCallback
+      : Option[java.util.function.BiFunction[Integer, GetReducerFileGroupResponse, Array[Byte]]] =
+    None
+  def registerBroadcastGetReducerFileGroupResponseCallback(call: java.util.function.BiFunction[
+    Integer,
+    GetReducerFileGroupResponse,
+    Array[Byte]]): Unit = {
+    broadcastGetReducerFileGroupResponseCallback = Some(call)
+  }
+
+  @volatile private var invalidatedBroadcastCallback: Option[Consumer[Integer]] =
+    None
+  def registerInvalidatedBroadcastCallback(call: Consumer[Integer]): Unit = {
+    invalidatedBroadcastCallback = Some(call)
+  }
+
def invalidateLatestMaxLocsCache(shuffleId: Int): Unit = {
registerShuffleResponseRpcCache.invalidate(shuffleId)
}
@@ -1889,4 +1906,19 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
case _ =>
}
+  def broadcastGetReducerFileGroupResponse(
+      shuffleId: Int,
+      response: GetReducerFileGroupResponse): Option[Array[Byte]] = {
+    broadcastGetReducerFileGroupResponseCallback match {
+      case Some(c) => Option(c.apply(shuffleId, response))
+      case _ => None
+    }
+  }
+
+  private def invalidatedBroadcastGetReducerFileGroupResponse(shuffleId: Int): Unit = {
+    invalidatedBroadcastCallback match {
+      case Some(c) => c.accept(shuffleId)
+      case _ =>
+    }
+  }
}
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 55639764c..98fe624fb 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -28,7 +28,7 @@ import scala.collection.mutable
import com.google.common.cache.{Cache, CacheBuilder}
import com.google.common.collect.Sets
-import org.apache.celeborn.client.{ClientUtils, ShuffleCommittedInfo, WorkerStatusTracker}
+import org.apache.celeborn.client.{ClientUtils, LifecycleManager, ShuffleCommittedInfo, WorkerStatusTracker}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
 import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
import org.apache.celeborn.common.CelebornConf
@@ -55,7 +55,8 @@ class ReducePartitionCommitHandler(
shuffleAllocatedWorkers: ShuffleAllocatedWorkers,
committedPartitionInfo: CommittedPartitionInfo,
workerStatusTracker: WorkerStatusTracker,
- sharedRpcPool: ThreadPoolExecutor)
+ sharedRpcPool: ThreadPoolExecutor,
+ lifecycleManager: LifecycleManager)
extends CommitHandler(
appUniqueId,
conf,
@@ -78,6 +79,10 @@ class ReducePartitionCommitHandler(
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime
+  private val getReducerFileGroupResponseBroadcastEnabled = conf.getReducerFileGroupBroadcastEnabled
+  private val getReducerFileGroupResponseBroadcastMiniSize =
+    conf.getReducerFileGroupBroadcastMiniSize
+
// noinspection UnstableApiUsage
   private val getReducerFileGroupRpcCache: Cache[Int, ByteBuffer] = CacheBuilder.newBuilder()
.concurrencyLevel(rpcCacheConcurrencyLevel)
@@ -320,10 +325,17 @@ class ReducePartitionCommitHandler(
} else {
// LocalNettyRpcCallContext is for the UTs
if (context.isInstanceOf[LocalNettyRpcCallContext]) {
-        context.reply(GetReducerFileGroupResponse(
+        var response = GetReducerFileGroupResponse(
           StatusCode.SUCCESS,
           reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()),
-          getMapperAttempts(shuffleId)))
+          getMapperAttempts(shuffleId))
+
+        // only check whether broadcast enabled for the UTs
+        if (getReducerFileGroupResponseBroadcastEnabled) {
+          response = broadcastGetReducerFileGroup(shuffleId, response)
+        }
+
+        context.reply(response)
} else {
val cachedMsg = getReducerFileGroupRpcCache.get(
shuffleId,
@@ -337,7 +349,27 @@ class ReducePartitionCommitHandler(
shufflePushFailedBatches.getOrDefault(
shuffleId,
new util.HashMap[String, util.Set[PushFailedBatch]]()))
-          context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
+
+          val serializedMsg =
+            context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
+
+          if (getReducerFileGroupResponseBroadcastEnabled &&
+            serializedMsg.capacity() >= getReducerFileGroupResponseBroadcastMiniSize) {
+            val broadcastMsg = broadcastGetReducerFileGroup(shuffleId, returnedMsg)
+            if (broadcastMsg != returnedMsg) {
+              val serializedBroadcastMsg =
+                context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(broadcastMsg)
+              logInfo(s"Shuffle $shuffleId GetReducerFileGroupResponse size" +
+                s" ${serializedMsg.capacity()} reached the broadcast threshold" +
+                s" $getReducerFileGroupResponseBroadcastMiniSize," +
+                s" the broadcast response size is ${serializedBroadcastMsg.capacity()}.")
+              serializedBroadcastMsg
+            } else {
+              serializedMsg
+            }
+          } else {
+            serializedMsg
+          }
}
})
context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
@@ -345,6 +377,16 @@ class ReducePartitionCommitHandler(
}
}
+  private def broadcastGetReducerFileGroup(
+      shuffleId: Int,
+      response: GetReducerFileGroupResponse): GetReducerFileGroupResponse = {
+    lifecycleManager.broadcastGetReducerFileGroupResponse(shuffleId, response) match {
+      case Some(broadcastBytes) if broadcastBytes.nonEmpty =>
+        GetReducerFileGroupResponse(response.status, broadcast = broadcastBytes)
+      case _ => response
+    }
+  }
+
   override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit = {
// Quick return for ended stage, avoid occupy sync lock.
if (isStageEnd(shuffleId)) {
diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
index ade8b1cdb..a5076a59f 100644
--- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java
@@ -427,7 +427,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
- Collections.emptyMap());
+ Collections.emptyMap(),
+ new byte[0]);
});
     when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@@ -439,7 +440,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
- Collections.emptyMap());
+ Collections.emptyMap(),
+ new byte[0]);
});
shuffleClient =
@@ -482,7 +484,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
- Collections.emptyMap());
+ Collections.emptyMap(),
+ new byte[0]);
});
     when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@@ -493,7 +496,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
- Collections.emptyMap());
+ Collections.emptyMap(),
+ new byte[0]);
});
shuffleClient =
@@ -514,7 +518,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
- Collections.emptyMap());
+ Collections.emptyMap(),
+ new byte[0]);
});
     when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@@ -525,7 +530,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
- Collections.emptyMap());
+ Collections.emptyMap(),
+ new byte[0]);
});
shuffleClient =
@@ -546,7 +552,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
- Collections.emptyMap());
+ Collections.emptyMap(),
+ new byte[0]);
});
     when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@@ -557,7 +564,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
- Collections.emptyMap());
+ Collections.emptyMap(),
+ new byte[0]);
});
shuffleClient =
diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto
index 553e95aa7..acf355756 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -390,6 +390,8 @@ message PbGetReducerFileGroupResponse {
repeated int32 partitionIds = 4;
map<string, PbPushFailedBatchSet> pushFailedBatches = 5;
+
+ bytes broadcast = 6;
}
message PbGetShuffleId {
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index f132cd189..b7698d4c4 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1057,6 +1057,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
     get(CLIENT_PUSH_DYNAMIC_WRITE_MODE_ENABLED)
   def dynamicWriteModePartitionNumThreshold =
     get(CLIENT_PUSH_DYNAMIC_WRITE_MODE_PARTITION_NUM_THRESHOLD)
+  def getReducerFileGroupBroadcastEnabled =
+    get(CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED)
+  def getReducerFileGroupBroadcastMiniSize =
+    get(CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE)
   def shufflePartitionType: PartitionType = PartitionType.valueOf(get(SHUFFLE_PARTITION_TYPE))
   def shuffleRangeReadFilterEnabled: Boolean = get(SHUFFLE_RANGE_READ_FILTER_ENABLED)
   def shuffleForceFallbackEnabled: Boolean = get(SPARK_SHUFFLE_FORCE_FALLBACK_ENABLED)
@@ -5213,6 +5217,27 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(2000)
+  val CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED =
+    buildConf("celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled")
+      .categories("client")
+      .doc(
+        "Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. " +
+          "If the response size is large and Spark executor number is large, the Spark driver network " +
+          "may be exhausted because each executor will pull the response from the driver. With broadcasting " +
+          "GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple " +
+          "copies of the GetReducerFileGroupResponse (one per executor).")
+      .version("0.6.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  val CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE =
+    buildConf("celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize")
+      .categories("client")
+      .doc("The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors.")
+      .version("0.6.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("512k")
+
val SPARK_SHUFFLE_WRITER_MODE: ConfigEntry[String] =
buildConf("celeborn.client.spark.shuffle.writer")
.withAlternative("celeborn.shuffle.writer")
diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index e730b9b7a..949b13322 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -22,6 +22,7 @@ import java.util.{Collections, UUID}
import scala.collection.JavaConverters._
+import com.google.protobuf.ByteString
import org.roaringbitmap.RoaringBitmap
import org.apache.celeborn.common.identity.UserIdentifier
@@ -285,10 +286,11 @@ object ControlMessages extends Logging {
// Path can't be serialized
   case class GetReducerFileGroupResponse(
       status: StatusCode,
-      fileGroup: util.Map[Integer, util.Set[PartitionLocation]],
-      attempts: Array[Int],
+      fileGroup: util.Map[Integer, util.Set[PartitionLocation]] = Collections.emptyMap(),
+      attempts: Array[Int] = Array.emptyIntArray,
       partitionIds: util.Set[Integer] = Collections.emptySet[Integer](),
-      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap())
+      pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap(),
+      broadcast: Array[Byte] = Array.emptyByteArray)
     extends MasterMessage
object WorkerExclude {
@@ -752,7 +754,13 @@ object ControlMessages extends Logging {
.build().toByteArray
new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload)
-      case GetReducerFileGroupResponse(status, fileGroup, attempts, partitionIds, failedBatches) =>
+      case GetReducerFileGroupResponse(
+            status,
+            fileGroup,
+            attempts,
+            partitionIds,
+            failedBatches,
+            broadcast) =>
val builder = PbGetReducerFileGroupResponse
.newBuilder()
.setStatus(status.getValue)
@@ -770,6 +778,7 @@ object ControlMessages extends Logging {
case (uniqueId, pushFailedBatchSet) =>
(uniqueId, PbSerDeUtils.toPbPushFailedBatchSet(pushFailedBatchSet))
}.asJava)
+ builder.setBroadcast(ByteString.copyFrom(broadcast))
val payload = builder.build().toByteArray
       new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, payload)
@@ -1198,12 +1207,14 @@ object ControlMessages extends Logging {
case (uniqueId, pushFailedBatchSet) =>
             (uniqueId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
}.toMap.asJava
+ val broadcast = pbGetReducerFileGroupResponse.getBroadcast.toByteArray
GetReducerFileGroupResponse(
Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus),
fileGroup,
attempts,
partitionIds,
- pushFailedBatches)
+ pushFailedBatches,
+ broadcast)
case GET_SHUFFLE_ID_VALUE =>
message.getParsedPayload()
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala b/common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
new file mode 100644
index 000000000..97f7d2c65
--- /dev/null
+++ b/common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.util
+
+import java.util.concurrent.{Callable, ConcurrentHashMap}
+
+/**
+ * This class is copied from Apache Spark.
+ * A special locking mechanism to provide locking with a given key. By providing the same key
+ * (identity is tested using the `equals` method), we ensure there is only one `func` running at
+ * the same time.
+ *
+ * @tparam K the type of key to identify a lock. This type must implement `equals` and `hashCode`
+ *           correctly as it will be the key type of an internal Map.
+ */
+class KeyLock[K] {
+
+ private val lockMap = new ConcurrentHashMap[K, AnyRef]()
+
+ private def acquireLock(key: K): Unit = {
+ while (true) {
+ val lock = lockMap.putIfAbsent(key, new Object)
+ if (lock == null) return
+ lock.synchronized {
+ while (lockMap.get(key) eq lock) {
+ lock.wait()
+ }
+ }
+ }
+ }
+
+ private def releaseLock(key: K): Unit = {
+ val lock = lockMap.remove(key)
+ lock.synchronized {
+ lock.notifyAll()
+ }
+ }
+
+  /**
+   * Run `func` under a lock identified by the given key. Multiple calls with the same key
+   * (identity is tested using the `equals` method) will be locked properly to ensure there is only
+   * one `func` running at the same time.
+   */
+ def withLock[T](key: K)(func: Callable[T]): T = {
+ if (key == null) {
+ throw new NullPointerException("key must not be null")
+ }
+ acquireLock(key)
+ try {
+ func.call()
+ } finally {
+ releaseLock(key)
+ }
+ }
+}
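For reference, a minimal usage sketch of the KeyLock added above (illustrative only; in this patch SparkUtils uses it to serialize access to the per-shuffle broadcast state):
```scala
import java.util.concurrent.Callable
import org.apache.celeborn.common.util.KeyLock

object KeyLockUsageSketch {
  def main(args: Array[String]): Unit = {
    val lock = new KeyLock[Integer]
    // Calls with the same key run one at a time; different keys do not block each other.
    val result = lock.withLock(Integer.valueOf(1))(new Callable[String] {
      override def call(): String = "only one thread per key runs here"
    })
    println(result)
  }
}
```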
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index f42994cb3..5b3a25308 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -122,6 +122,8 @@ license: |
 | celeborn.client.spark.shuffle.fallback.numPartitionsThreshold | 2147483647 | false | Celeborn will only accept shuffle of partition number lower than this configuration value. This configuration only takes effect when `celeborn.client.spark.shuffle.fallback.policy` is `AUTO`. | 0.5.0 | celeborn.shuffle.forceFallback.numPartitionsThreshold,celeborn.client.spark.shuffle.forceFallback.numPartitionsThreshold |
 | celeborn.client.spark.shuffle.fallback.policy | AUTO | false | Celeborn supports the following kind of fallback policies. 1. ALWAYS: always use spark built-in shuffle implementation; 2. AUTO: prefer to use celeborn shuffle implementation, and fallback to use spark built-in shuffle implementation based on certain factors, e.g. availability of enough workers and quota, shuffle partition number; 3. NEVER: always use celeborn shuffle implementation, and fail fast when it is concluded th [...]
 | celeborn.client.spark.shuffle.forceFallback.enabled | false | false | Always use spark built-in shuffle implementation. This configuration is deprecated, consider configuring `celeborn.client.spark.shuffle.fallback.policy` instead. | 0.3.0 | celeborn.shuffle.forceFallback.enabled |
+| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled | false | false | Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. If the response size is large and Spark executor number is large, the Spark driver network may be exhausted because each executor will pull the response from the driver. With broadcasting GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple copies of the GetReducerFil [...]
+| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize | 512k | false | The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors. | 0.6.0 | |
 | celeborn.client.spark.shuffle.writer | HASH | false | Celeborn supports the following kind of shuffle writers. 1. hash: hash-based shuffle writer works fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer works fine when memory pressure is high or shuffle partition count is huge. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is false. | 0.3.0 | celeborn.shuffle.writer |
 | celeborn.client.spark.stageRerun.enabled | true | false | Whether to enable stage rerun. If true, client throws FetchFailedException instead of CelebornIOException. | 0.4.0 | celeborn.client.spark.fetch.throwsFetchFailure |
 | celeborn.identity.provider | org.apache.celeborn.common.identity.DefaultIdentityProvider | false | IdentityProvider class name. Default class is `org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will be obtained by UserGroupInformation.getUserName; org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.6.0 | [...]
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala
index 8912217ce..afc2956ac 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashSuite.scala
@@ -18,11 +18,13 @@
package org.apache.celeborn.tests.spark
import org.apache.spark.SparkConf
+import org.apache.spark.shuffle.celeborn.SparkUtils
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.protocol.ShuffleMode
class CelebornHashSuite extends AnyFunSuite
@@ -64,4 +66,43 @@ class CelebornHashSuite extends AnyFunSuite
celebornSparkSession.stop()
}
+
+ test("celeborn spark integration test - GetReducerFileGroupResponse
broadcast") {
+ SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+ SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
+ val sparkConf = new
SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
+ .set(
+
s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED.key}",
+ "true")
+ .set(
+
s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE.key}",
+ "0")
+ val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+ val combineResult = combine(sparkSession)
+ val groupbyResult = groupBy(sparkSession)
+ val repartitionResult = repartition(sparkSession)
+ val sqlResult = runsql(sparkSession)
+
+ Thread.sleep(3000L)
+ sparkSession.stop()
+
+ val celebornSparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+ .getOrCreate()
+ val celebornCombineResult = combine(celebornSparkSession)
+ val celebornGroupbyResult = groupBy(celebornSparkSession)
+ val celebornRepartitionResult = repartition(celebornSparkSession)
+ val celebornSqlResult = runsql(celebornSparkSession)
+
+ assert(combineResult.equals(celebornCombineResult))
+ assert(groupbyResult.equals(celebornGroupbyResult))
+ assert(repartitionResult.equals(celebornRepartitionResult))
+ assert(combineResult.equals(celebornCombineResult))
+ assert(sqlResult.equals(celebornSqlResult))
+ assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)
+
+ celebornSparkSession.stop()
+ SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+ SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
+ }
}
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala
index e4b2dd574..e4f6cc93b 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornSortSuite.scala
@@ -18,6 +18,7 @@
package org.apache.celeborn.tests.spark
import org.apache.spark.SparkConf
+import org.apache.spark.shuffle.celeborn.SparkUtils
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
@@ -66,4 +67,45 @@ class CelebornSortSuite extends AnyFunSuite
celebornSparkSession.stop()
}
+
+ test("celeborn spark integration test - GetReducerFileGroupResponse
broadcast") {
+ SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+ SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
+ val sparkConf = new
SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
+
.set(s"spark.${CelebornConf.CLIENT_PUSH_SORT_RANDOMIZE_PARTITION_ENABLED.key}",
"false")
+ .set(
+
s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED.key}",
+ "true")
+ .set(
+
s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE.key}",
+ "0")
+
+ val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+ val combineResult = combine(sparkSession)
+ val groupbyResult = groupBy(sparkSession)
+ val repartitionResult = repartition(sparkSession)
+ val sqlResult = runsql(sparkSession)
+
+ Thread.sleep(3000L)
+ sparkSession.stop()
+
+ val celebornSparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.SORT))
+ .getOrCreate()
+ val celebornCombineResult = combine(celebornSparkSession)
+ val celebornGroupbyResult = groupBy(celebornSparkSession)
+ val celebornRepartitionResult = repartition(celebornSparkSession)
+ val celebornSqlResult = runsql(celebornSparkSession)
+
+ assert(combineResult.equals(celebornCombineResult))
+ assert(groupbyResult.equals(celebornGroupbyResult))
+ assert(repartitionResult.equals(celebornRepartitionResult))
+ assert(combineResult.equals(celebornCombineResult))
+ assert(sqlResult.equals(celebornSqlResult))
+ assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)
+
+ celebornSparkSession.stop()
+ SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+ SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
+ }
}
diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
index 6b4bc13b8..83a5e12f6 100644
--- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
+++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
@@ -29,7 +29,9 @@ import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
import org.apache.celeborn.client.ShuffleClient
-import org.apache.celeborn.common.protocol.ShuffleMode
+import org.apache.celeborn.common.protocol.{PartitionLocation, ShuffleMode}
+import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
+import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.tests.spark.SparkTestBase
class SparkUtilsSuite extends AnyFunSuite
@@ -157,4 +159,61 @@ class SparkUtilsSuite extends AnyFunSuite
sparkSession.stop()
}
}
+
+ test("serialize/deserialize GetReducerFileGroupResponse with broadcast") {
+ val sparkConf = new
SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+ val sparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+ .config("spark.sql.shuffle.partitions", 2)
+ .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+ .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
+ .config(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+ .getOrCreate()
+
+ try {
+ val shuffleId = Integer.MAX_VALUE
+ val getReducerFileGroupResponse = GetReducerFileGroupResponse(
+ StatusCode.SUCCESS,
+ Map(Integer.valueOf(shuffleId) -> Set(new PartitionLocation(
+ 0,
+ 1,
+ "localhost",
+ 1,
+ 2,
+ 3,
+ 4,
+ PartitionLocation.Mode.REPLICA)).asJava).asJava,
+ Array(1),
+ Set(Integer.valueOf(shuffleId)).asJava)
+
+      val serializedBytes =
+        SparkUtils.serializeGetReducerFileGroupResponse(shuffleId, getReducerFileGroupResponse)
+      assert(serializedBytes != null && serializedBytes.length > 0)
+      val broadcast = SparkUtils.getReducerFileGroupResponseBroadcasts.values().asScala.head._1
+      assert(broadcast.isValid)
+
+      val deserializedGetReducerFileGroupResponse =
+        SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, serializedBytes)
+      assert(deserializedGetReducerFileGroupResponse.status == getReducerFileGroupResponse.status)
+      assert(
+        deserializedGetReducerFileGroupResponse.fileGroup == getReducerFileGroupResponse.fileGroup)
+      assert(java.util.Arrays.equals(
+        deserializedGetReducerFileGroupResponse.attempts,
+        getReducerFileGroupResponse.attempts))
+      assert(
+        deserializedGetReducerFileGroupResponse.partitionIds == getReducerFileGroupResponse.partitionIds)
+      assert(
+        deserializedGetReducerFileGroupResponse.pushFailedBatches == getReducerFileGroupResponse.pushFailedBatches)
+
+ assert(!SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
+ SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId)
+ assert(SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
+ assert(!broadcast.isValid)
+ } finally {
+ sparkSession.stop()
+ SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
+ SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
+ }
+ }
}