This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ae80afbc [CORE] Use SortShuffleManager instance in ColumnarShuffleManager (#6022)
2ae80afbc is described below

commit 2ae80afbca24e92fca5c9c5d0849a37a5b5c15fd
Author: Ankita Victor <[email protected]>
AuthorDate: Thu Jun 13 06:27:10 2024 +0530

    [CORE] Use SortShuffleManager instance in ColumnarShuffleManager (#6022)
---
 .../apache/gluten/execution/FallbackSuite.scala    |  26 ++++-
 .../shuffle/sort/ColumnarShuffleManager.scala      | 121 +++++++--------------
 2 files changed, 67 insertions(+), 80 deletions(-)

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
index 15a71ceb5..27d191b9e 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
@@ -20,8 +20,9 @@ import org.apache.gluten.GlutenConfig
 import org.apache.gluten.extension.GlutenPlan
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 
 class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {
   protected val rootPath: String = getClass.getResource("/").getPath
@@ -71,6 +72,29 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
     collect(plan) { case v: VeloxColumnarToRowExec => v }.size
   }
 
+  private def collectColumnarShuffleExchange(plan: SparkPlan): Int = {
+    collect(plan) { case c: ColumnarShuffleExchangeExec => c }.size
+  }
+
+  private def collectShuffleExchange(plan: SparkPlan): Int = {
+    collect(plan) { case c: ShuffleExchangeExec => c }.size
+  }
+
+  test("fallback with shuffle manager") {
+    withSQLConf(GlutenConfig.COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      runQueryAndCompare("select c1, count(*) from tmp1 group by c1") {
+        df =>
+          val plan = df.queryExecution.executedPlan
+
+          assert(collectColumnarShuffleExchange(plan) == 0)
+          assert(collectShuffleExchange(plan) == 1)
+
+          val wholeQueryColumnarToRow = collectColumnarToRow(plan)
+          assert(wholeQueryColumnarToRow == 2)
+      }
+    }
+  }
+
   test("fallback with collect") {
     withSQLConf(GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1") {
       runQueryAndCompare("SELECT count(*) FROM tmp1") {
diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index d8ba78cb9..06c6e6c0e 100644
--- a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -20,7 +20,6 @@ import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle._
-import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.collection.OpenHashSet
@@ -28,13 +27,12 @@ import org.apache.spark.util.collection.OpenHashSet
 import java.io.InputStream
 import java.util.concurrent.ConcurrentHashMap
 
-import scala.collection.JavaConverters._
-
 class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
 
   import ColumnarShuffleManager._
 
-  private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+  private[this] lazy val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+
   override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
 
   /** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */
@@ -49,23 +47,9 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       new ColumnarShuffleHandle[K, V](
         shuffleId,
         dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]])
-    } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
-      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
-      // need map-side aggregation, then write numPartitions files directly and just concatenate
-      // them at the end. This avoids doing serialization and deserialization twice to merge
-      // together the spilled files, which would happen with the normal code path. The downside is
-      // having multiple files open at a time and thus more memory allocated to buffers.
-      new BypassMergeSortShuffleHandle[K, V](
-        shuffleId,
-        dependency.asInstanceOf[ShuffleDependency[K, V, V]])
-    } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
-      // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
-      new SerializedShuffleHandle[K, V](
-        shuffleId,
-        dependency.asInstanceOf[ShuffleDependency[K, V, V]])
     } else {
-      // Otherwise, buffer map outputs in a deserialized form:
-      new BaseShuffleHandle(shuffleId, dependency)
+      // Otherwise call default SortShuffleManager
+      sortShuffleManager.registerShuffle(shuffleId, dependency)
     }
   }
 
@@ -75,39 +59,19 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       mapId: Long,
       context: TaskContext,
       metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
-    val mapTaskIds =
-      taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
-    mapTaskIds.synchronized {
-      mapTaskIds.add(context.taskAttemptId())
-    }
-    val env = SparkEnv.get
     handle match {
       case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
+        val mapTaskIds =
+          taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
+        mapTaskIds.synchronized {
+          mapTaskIds.add(context.taskAttemptId())
+        }
         GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
           shuffleBlockResolver,
           columnarShuffleHandle,
           mapId,
           metrics)
-      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
-        new UnsafeShuffleWriter(
-          env.blockManager,
-          context.taskMemoryManager(),
-          unsafeShuffleHandle,
-          mapId,
-          context,
-          env.conf,
-          metrics,
-          shuffleExecutorComponents)
-      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
-        new BypassMergeSortShuffleWriter(
-          env.blockManager,
-          bypassMergeSortHandle,
-          mapId,
-          env.conf,
-          metrics,
-          shuffleExecutorComponents)
-      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
+      case _ => sortShuffleManager.getWriter(handle, mapId, context, metrics)
     }
   }
 
@@ -123,17 +87,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
-    val (blocksByAddress, canEnableBatchFetch) = {
-      GlutenShuffleUtils.getReaderParam(
-        handle,
-        startMapIndex,
-        endMapIndex,
-        startPartition,
-        endPartition)
-    }
-    val shouldBatchFetch =
-      canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
     if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
+      val (blocksByAddress, canEnableBatchFetch) = {
+        GlutenShuffleUtils.getReaderParam(
+          handle,
+          startMapIndex,
+          endMapIndex,
+          startPartition,
+          endPartition)
+      }
+      val shouldBatchFetch =
+        canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
       new BlockStoreShuffleReader(
         handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
         blocksByAddress,
@@ -143,44 +107,43 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
         shouldBatchFetch = shouldBatchFetch
       )
     } else {
-      new BlockStoreShuffleReader(
-        handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
-        blocksByAddress,
+      sortShuffleManager.getReader(
+        handle,
+        startMapIndex,
+        endMapIndex,
+        startPartition,
+        endPartition,
         context,
-        metrics,
-        shouldBatchFetch = shouldBatchFetch
-      )
+        metrics)
     }
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
-      mapTaskIds =>
-        mapTaskIds.iterator.foreach {
-          mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
-        }
+    if (taskIdMapsForShuffle.contains(shuffleId)) {
+      Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
+        mapTaskIds =>
+          mapTaskIds.iterator.foreach {
+            mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+          }
+      }
+      true
+    } else {
+      sortShuffleManager.unregisterShuffle(shuffleId)
     }
-    true
   }
 
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
-    shuffleBlockResolver.stop()
+    if (!taskIdMapsForShuffle.isEmpty) {
+      shuffleBlockResolver.stop()
+    } else {
+      sortShuffleManager.stop
+    }
   }
 }
 
 object ColumnarShuffleManager extends Logging {
-  private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
-    val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
-    val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
-    executorComponents.initializeExecutor(
-      conf.getAppId,
-      SparkEnv.get.executorId,
-      extraConfigs.asJava)
-    executorComponents
-  }
-
   private def bypassDecompressionSerializerManger =
     new SerializerManager(
       SparkEnv.get.serializer,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to