This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 716dd9a8b Revert SortShuffleManager changes in ColumnarShuffleManager (#6149)
716dd9a8b is described below

commit 716dd9a8ba92d46e70405483231c767f4ccf9259
Author: Ankita Victor <[email protected]>
AuthorDate: Thu Jun 20 07:37:21 2024 +0530

    Revert SortShuffleManager changes in ColumnarShuffleManager (#6149)
---
 .../shuffle/sort/ColumnarShuffleManager.scala      | 121 ++++++++++++++-------
 1 file changed, 79 insertions(+), 42 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index 06c6e6c0e..d8ba78cb9 100644
--- a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -20,6 +20,7 @@ import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.collection.OpenHashSet
@@ -27,12 +28,13 @@ import org.apache.spark.util.collection.OpenHashSet
 import java.io.InputStream
 import java.util.concurrent.ConcurrentHashMap
 
+import scala.collection.JavaConverters._
+
 class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with 
Logging {
 
   import ColumnarShuffleManager._
 
-  private[this] lazy val sortShuffleManager: SortShuffleManager = new 
SortShuffleManager(conf)
-
+  private lazy val shuffleExecutorComponents = 
loadShuffleExecutorComponents(conf)
   override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
 
   /** A mapping from shuffle ids to the number of mappers producing output for 
those shuffles. */
@@ -47,9 +49,23 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       new ColumnarShuffleHandle[K, V](
         shuffleId,
         dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]])
+    } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
+      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold 
partitions and we don't
+      // need map-side aggregation, then write numPartitions files directly 
and just concatenate
+      // them at the end. This avoids doing serialization and deserialization 
twice to merge
+      // together the spilled files, which would happen with the normal code 
path. The downside is
+      // having multiple files open at a time and thus more memory allocated 
to buffers.
+      new BypassMergeSortShuffleHandle[K, V](
+        shuffleId,
+        dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+    } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
+      // Otherwise, try to buffer map outputs in a serialized form, since this 
is more efficient:
+      new SerializedShuffleHandle[K, V](
+        shuffleId,
+        dependency.asInstanceOf[ShuffleDependency[K, V, V]])
     } else {
-      // Otherwise call default SortShuffleManager
-      sortShuffleManager.registerShuffle(shuffleId, dependency)
+      // Otherwise, buffer map outputs in a deserialized form:
+      new BaseShuffleHandle(shuffleId, dependency)
     }
   }
 
@@ -59,19 +75,39 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       mapId: Long,
       context: TaskContext,
       metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+    val mapTaskIds =
+      taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new 
OpenHashSet[Long](16))
+    mapTaskIds.synchronized {
+      mapTaskIds.add(context.taskAttemptId())
+    }
+    val env = SparkEnv.get
     handle match {
       case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V 
@unchecked] =>
-        val mapTaskIds =
-          taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new 
OpenHashSet[Long](16))
-        mapTaskIds.synchronized {
-          mapTaskIds.add(context.taskAttemptId())
-        }
         GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
           shuffleBlockResolver,
           columnarShuffleHandle,
           mapId,
           metrics)
-      case _ => sortShuffleManager.getWriter(handle, mapId, context, metrics)
+      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V 
@unchecked] =>
+        new UnsafeShuffleWriter(
+          env.blockManager,
+          context.taskMemoryManager(),
+          unsafeShuffleHandle,
+          mapId,
+          context,
+          env.conf,
+          metrics,
+          shuffleExecutorComponents)
+      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V 
@unchecked] =>
+        new BypassMergeSortShuffleWriter(
+          env.blockManager,
+          bypassMergeSortHandle,
+          mapId,
+          env.conf,
+          metrics,
+          shuffleExecutorComponents)
+      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
+        new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
     }
   }
 
@@ -87,17 +123,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    val (blocksByAddress, canEnableBatchFetch) = {
+      GlutenShuffleUtils.getReaderParam(
+        handle,
+        startMapIndex,
+        endMapIndex,
+        startPartition,
+        endPartition)
+    }
+    val shouldBatchFetch =
+      canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, 
context)
     if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
-      val (blocksByAddress, canEnableBatchFetch) = {
-        GlutenShuffleUtils.getReaderParam(
-          handle,
-          startMapIndex,
-          endMapIndex,
-          startPartition,
-          endPartition)
-      }
-      val shouldBatchFetch =
-        canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, 
context)
       new BlockStoreShuffleReader(
         handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
         blocksByAddress,
@@ -107,43 +143,44 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
         shouldBatchFetch = shouldBatchFetch
       )
     } else {
-      sortShuffleManager.getReader(
-        handle,
-        startMapIndex,
-        endMapIndex,
-        startPartition,
-        endPartition,
+      new BlockStoreShuffleReader(
+        handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+        blocksByAddress,
         context,
-        metrics)
+        metrics,
+        shouldBatchFetch = shouldBatchFetch
+      )
     }
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    if (taskIdMapsForShuffle.contains(shuffleId)) {
-      Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
-        mapTaskIds =>
-          mapTaskIds.iterator.foreach {
-            mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
-          }
-      }
-      true
-    } else {
-      sortShuffleManager.unregisterShuffle(shuffleId)
+    Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
+      mapTaskIds =>
+        mapTaskIds.iterator.foreach {
+          mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+        }
     }
+    true
   }
 
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
-    if (!taskIdMapsForShuffle.isEmpty) {
-      shuffleBlockResolver.stop()
-    } else {
-      sortShuffleManager.stop
-    }
+    shuffleBlockResolver.stop()
   }
 }
 
 object ColumnarShuffleManager extends Logging {
+  private def loadShuffleExecutorComponents(conf: SparkConf): 
ShuffleExecutorComponents = {
+    val executorComponents = 
ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
+    val extraConfigs = 
conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
+    executorComponents.initializeExecutor(
+      conf.getAppId,
+      SparkEnv.get.executorId,
+      extraConfigs.asJava)
+    executorComponents
+  }
+
   private def bypassDecompressionSerializerManger =
     new SerializerManager(
       SparkEnv.get.serializer,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to