turboFei commented on code in PR #3158:
URL: https://github.com/apache/celeborn/pull/3158#discussion_r2004744350


##########
client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java:
##########
@@ -462,4 +470,121 @@ public static void addSparkListener(SparkListener listener) {
       sparkContext.addSparkListener(listener);
     }
   }
+
+  /**
+   * A {@link KeyLock} keyed by shuffle id, ensuring that only one thread at a time accesses the
+   * broadcast belonging to a given shuffle id.
+   */
+  private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock<>();
+
+  /** Number of GetReducerFileGroupResponse broadcasts successfully created and cached. */
+  @VisibleForTesting
+  public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
+
+  /**
+   * Cache keyed by shuffle id. Each value pairs the broadcast of a GetReducerFileGroupResponse
+   * with the compressed, Java-serialized bytes of that broadcast handle.
+   */
+  @VisibleForTesting
+  public static Map<Integer, Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>>
+      getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
+
+  public static byte[] serializeGetReducerFileGroupResponse(
+      Integer shuffleId, GetReducerFileGroupResponse response) {
+    SparkContext sparkContext = 
SparkContext$.MODULE$.getActive().getOrElse(null);
+    if (sparkContext == null) {
+      LOG.error("Can not get active SparkContext.");
+      return null;
+    }
+
+    return shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>
+              cachedSerializeGetReducerFileGroupResponse =
+                  getReducerFileGroupResponseBroadcasts.get(shuffleId);
+          if (cachedSerializeGetReducerFileGroupResponse != null) {
+            return cachedSerializeGetReducerFileGroupResponse._2;
+          }
+
+          try {
+            LOG.info("Broadcasting GetReducerFileGroupResponse for shuffle: 
{}", shuffleId);
+            Broadcast<GetReducerFileGroupResponse> broadcast =
+                sparkContext.broadcast(
+                    response,
+                    scala.reflect.ClassManifestFactory.fromClass(
+                        GetReducerFileGroupResponse.class));
+
+            CompressionCodec codec = 
CompressionCodec.createCodec(sparkContext.conf());
+            // Using `org.apache.commons.io.output.ByteArrayOutputStream` 
instead of the standard
+            // one
+            // This implementation doesn't reallocate the whole memory block 
but allocates
+            // additional buffers. This way no buffers need to be garbage 
collected and
+            // the contents don't have to be copied to the new buffer.
+            org.apache.commons.io.output.ByteArrayOutputStream out =
+                new org.apache.commons.io.output.ByteArrayOutputStream();
+            try (ObjectOutputStream oos =
+                new ObjectOutputStream(codec.compressedOutputStream(out))) {
+              oos.writeObject(broadcast);
+            }
+            byte[] _serializeResult = out.toByteArray();
+            getReducerFileGroupResponseBroadcasts.put(
+                shuffleId, Tuple2.apply(broadcast, _serializeResult));
+            getReducerFileGroupResponseBroadcastNum.incrementAndGet();
+            return _serializeResult;
+          } catch (Throwable e) {
+            LOG.error(
+                "Failed to serialize GetReducerFileGroupResponse for shuffle: 
" + shuffleId, e);
+            return null;
+          }
+        });
+  }
+
+  /**
+   * Turns bytes produced by {@link #serializeGetReducerFileGroupResponse(Integer,
+   * GetReducerFileGroupResponse)} back into a {@link Broadcast} handle and returns its value.
+   *
+   * <p>The compression codec is created from the {@link org.apache.spark.SparkEnv} configuration
+   * rather than from an active SparkContext. A SparkContext only exists on the driver, while
+   * SparkEnv is available on both the driver and executors.
+   *
+   * @param shuffleId the shuffle the bytes belong to (used for locking and logging)
+   * @param bytes the compressed, Java-serialized broadcast handle
+   * @return the broadcast value, or {@code null} if SparkEnv is unavailable or deserialization
+   *     failed
+   */
+  public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
+      Integer shuffleId, byte[] bytes) {
+    org.apache.spark.SparkEnv sparkEnv = org.apache.spark.SparkEnv.get();
+    if (sparkEnv == null) {
+      LOG.error("Can not get SparkEnv.");
+      return null;
+    }
+
+    return shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          LOG.info(
+              "Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
+
+          try {
+            CompressionCodec codec = CompressionCodec.createCodec(sparkEnv.conf());
+            // NOTE(review): Java-native deserialization. This is only safe because these bytes
+            // presumably originate from serializeGetReducerFileGroupResponse on the driver;
+            // confirm they can never come from an untrusted peer.
+            try (ObjectInputStream objIn =
+                new ObjectInputStream(
+                    codec.compressedInputStream(new ByteArrayInputStream(bytes)))) {
+              @SuppressWarnings("unchecked")
+              Broadcast<GetReducerFileGroupResponse> broadcast =
+                  (Broadcast<GetReducerFileGroupResponse>) objIn.readObject();
+              return broadcast.value();
+            }
+          } catch (Throwable e) {
+            LOG.error(
+                "Failed to deserialize GetReducerFileGroupResponse for shuffle: {}",
+                shuffleId,
+                e);
+            return null;
+          }
+        });
+  }
+
+  public static void invalidateSerializedGetReducerFileGroupResponse(Integer 
shuffleId) {
+    shuffleBroadcastLock.withLock(
+        shuffleId,
+        () -> {
+          try {
+            Tuple2<Broadcast<GetReducerFileGroupResponse>, byte[]>
+                cachedSerializeGetReducerFileGroupResponse =
+                    getReducerFileGroupResponseBroadcasts.remove(shuffleId);
+            if (cachedSerializeGetReducerFileGroupResponse != null) {
+              cachedSerializeGetReducerFileGroupResponse._1().destroy();
+            }
+          } catch (Throwable e) {
+            LOG.error(
+                "Failed to invalidate serialized GetReducerFileGroupResponse 
for shuffle: "
+                    + shuffleId,
+                e);
+          }
+          return null;
+        });
+  }

Review Comment:
   Refer to Spark's `MapOutputTracker` for a reference implementation of this pattern:
https://github.com/apache/spark/blob/8d260084b8a50ff59a127c7292c0cdb6737981b0/core/src/main/scala/org/apache/spark/MapOutputTracker.scala#L393
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to