turboFei commented on code in PR #3158:
URL: https://github.com/apache/celeborn/pull/3158#discussion_r2014205885
##########
client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -476,4 +489,121 @@ public static boolean
isCelebornSkewShuffleOrChildShuffle(int appShuffleId) {
return false;
}
}
+
+ /**
+ * A [[KeyLock]] whose key is a shuffle id to ensure there is only one
thread accessing the
+ * broadcast belonging to the shuffle id at a time.
+ */
+ private static KeyLock<Integer> shuffleBroadcastLock = new KeyLock();
Review Comment:
Thanks, adopted for both Spark 3 and Spark 2.
##########
client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -476,4 +489,121 @@ public static boolean
isCelebornSkewShuffleOrChildShuffle(int appShuffleId) {
return false;
}
}
+
+ /**
+ * A [[KeyLock]] whose key is a shuffle id to ensure there is only one
thread accessing the
+ * broadcast belonging to the shuffle id at a time.
+ */
+ private static KeyLock<Integer> shuffleBroadcastLock = new KeyLock();
+
+ @VisibleForTesting
+ public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new
AtomicInteger();
+
+ @VisibleForTesting
+ public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>>
+ getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
+
+ public static byte[] serializeGetReducerFileGroupResponse(
+ Integer shuffleId, GetReducerFileGroupResponse response) {
+ SparkContext sparkContext =
SparkContext$.MODULE$.getActive().getOrElse(null);
+ if (sparkContext == null) {
+ LOG.error("Can not get active SparkContext.");
+ return null;
+ }
+
+ return shuffleBroadcastLock.withLock(
+ shuffleId,
+ () -> {
+ Tuple2<Broadcast<TransportMessage>, byte[]>
cachedSerializeGetReducerFileGroupResponse =
+ getReducerFileGroupResponseBroadcasts.get(shuffleId);
+ if (cachedSerializeGetReducerFileGroupResponse != null) {
+ return cachedSerializeGetReducerFileGroupResponse._2;
+ }
+
+ try {
+ LOG.info("Broadcasting GetReducerFileGroupResponse for shuffle:
{}", shuffleId);
+ TransportMessage transportMessage =
+ (TransportMessage) Utils.toTransportMessage(response);
+ Broadcast<TransportMessage> broadcast =
+ sparkContext.broadcast(
+ transportMessage,
+
scala.reflect.ClassManifestFactory.fromClass(TransportMessage.class));
+
+ CompressionCodec codec =
CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
+ // Using `org.apache.commons.io.output.ByteArrayOutputStream`
instead of the standard
+ // one
+ // This implementation doesn't reallocate the whole memory block
but allocates
+ // additional buffers. This way no buffers need to be garbage
collected and
+ // the contents don't have to be copied to the new buffer.
+ org.apache.commons.io.output.ByteArrayOutputStream out =
+ new org.apache.commons.io.output.ByteArrayOutputStream();
+ try (ObjectOutputStream oos =
+ new ObjectOutputStream(codec.compressedOutputStream(out))) {
+ oos.writeObject(broadcast);
+ }
+ byte[] _serializeResult = out.toByteArray();
+ getReducerFileGroupResponseBroadcasts.put(
+ shuffleId, Tuple2.apply(broadcast, _serializeResult));
+ getReducerFileGroupResponseBroadcastNum.incrementAndGet();
+ return _serializeResult;
+ } catch (Throwable e) {
+ LOG.error(
+ "Failed to serialize GetReducerFileGroupResponse for shuffle:
" + shuffleId, e);
Review Comment:
Thanks, adopted for both Spark 3 and Spark 2.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]