This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 781eb5d19d [GLUTEN-8060][CORE][VL] Various fixes for the 
experimental `GlutenShuffleManager` (#8355)
781eb5d19d is described below

commit 781eb5d19d8512b976eb825136fa00e352cb62f4
Author: Hongze Zhang <[email protected]>
AuthorDate: Fri Dec 27 16:22:00 2024 +0800

    [GLUTEN-8060][CORE][VL] Various fixes for the experimental 
`GlutenShuffleManager` (#8355)
---
 .../gluten/backendsapi/velox/VeloxBackend.scala    |  8 +++--
 .../backendsapi/velox/VeloxListenerApi.scala       | 17 +++++++++-
 .../apache/gluten/execution/VeloxTPCHSuite.scala   | 11 +++++++
 .../spark/shuffle/ShuffleManagerLookup.scala       |  4 +--
 .../spark/shuffle/ShuffleManagerRegistry.scala     | 22 ++++++++++---
 .../spark/shuffle/ShuffleManagerRouter.scala       | 36 ++++++++++++++++++++--
 .../spark/shuffle/GlutenShuffleManagerSuite.scala  |  5 +++
 .../scala/org/apache/gluten/GlutenConfig.scala     | 11 +++++--
 8 files changed, 98 insertions(+), 16 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index daa398d370..dba41b1f8d 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -456,9 +456,11 @@ object VeloxBackendSettings extends BackendSettingsApi {
   }
 
   override def supportColumnarShuffleExec(): Boolean = {
-    GlutenConfig.getConf.enableColumnarShuffle && 
(GlutenConfig.getConf.isUseColumnarShuffleManager
-      || GlutenConfig.getConf.isUseCelebornShuffleManager
-      || GlutenConfig.getConf.isUseUniffleShuffleManager)
+    val conf = GlutenConfig.getConf // Fetch the config once and reuse it for every check below.
+    conf.enableColumnarShuffle && (conf.isUseGlutenShuffleManager
+      || conf.isUseColumnarShuffleManager
+      || conf.isUseCelebornShuffleManager
+      || conf.isUseUniffleShuffleManager)
   }
 
   override def enableJoinKeysRewrite(): Boolean = false
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
index 175e34177a..75ce66f18d 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
@@ -27,10 +27,12 @@ import org.apache.gluten.jni.{JniLibLoader, JniWorkspace}
 import org.apache.gluten.udf.UdfJniWrapper
 import org.apache.gluten.utils._
 
-import org.apache.spark.{HdfsConfGenerator, SparkConf, SparkContext}
+import org.apache.spark.{HdfsConfGenerator, ShuffleDependency, SparkConf, 
SparkContext}
 import org.apache.spark.api.plugin.PluginContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.shuffle.{ColumnarShuffleDependency, LookupKey, 
ShuffleManagerRegistry}
+import org.apache.spark.shuffle.sort.ColumnarShuffleManager
 import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer
 import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules
 import 
org.apache.spark.sql.execution.datasources.velox.{VeloxParquetWriterInjects, 
VeloxRowSplitter}
@@ -128,6 +130,19 @@ class VeloxListenerApi extends ListenerApi with Logging {
     ArrowJavaBatch.ensureRegistered()
     ArrowNativeBatch.ensureRegistered()
 
+    // Register the columnar shuffle manager so it can be considered when
+    // `org.apache.spark.shuffle.GlutenShuffleManager` is set as Spark shuffle 
manager.
+    ShuffleManagerRegistry
+      .get()
+      .register(
+        new LookupKey {
+          override def accepts[K, V, C](dependency: ShuffleDependency[K, V, 
C]): Boolean = {
+            dependency.getClass == classOf[ColumnarShuffleDependency[_, _, _]] // Exact class match only.
+          }
+        },
+        classOf[ColumnarShuffleManager].getName
+      )
+
     // Sets this configuration only once, since not undoable.
     if (conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE, 
defaultValue = false)) {
       val debugDir = conf.get(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR)
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
index 44ffae45ad..0e5af8b875 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
@@ -295,6 +295,17 @@ class VeloxTPCHV1Suite extends VeloxTPCHSuite {
   }
 }
 
+class VeloxTPCHV1GlutenShuffleManagerSuite extends VeloxTPCHSuite {
+  override def subType(): String = "v1-gluten-shuffle-manager"
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf
+      .set("spark.sql.sources.useV1SourceList", "parquet")
+      .set("spark.sql.autoBroadcastJoinThreshold", "-1") // -1 disables broadcast joins, so joins go through shuffles.
+      .set("spark.shuffle.manager", 
"org.apache.spark.shuffle.GlutenShuffleManager")
+  }
+}
+
 class VeloxTPCHV1BhjSuite extends VeloxTPCHSuite {
   override def subType(): String = "v1-bhj"
 
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala
 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala
index 8b060c9818..68d3aa6b5d 100644
--- 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala
+++ 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala
@@ -19,12 +19,10 @@ package org.apache.spark.shuffle
 import org.apache.spark.ShuffleDependency
 
 private class ShuffleManagerLookup(all: Seq[(LookupKey, ShuffleManager)]) {
-  private val allReversed = all.reverse
-
   def findShuffleManager[K, V, C](dependency: ShuffleDependency[K, V, C]): 
ShuffleManager = {
     this.synchronized {
       // The latest shuffle manager registered will be looked up earlier.
-      allReversed.find(_._1.accepts(dependency)).map(_._2).getOrElse {
+      all.find(_._1.accepts(dependency)).map(_._2).getOrElse { // NOTE: relies on `all` being newest-first; verify at construction site.
         throw new IllegalStateException(s"No ShuffleManager found for 
$dependency")
       }
     }
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala
 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala
index 5b621a755d..0e2381bba4 100644
--- 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala
+++ 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala
@@ -16,17 +16,20 @@
  */
 package org.apache.spark.shuffle
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{ShuffleDependency, SparkConf}
+import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.util.{SparkTestUtil, Utils}
 
 import scala.collection.mutable
 
 class ShuffleManagerRegistry private[ShuffleManagerRegistry] {
   import ShuffleManagerRegistry._
-  private val all: mutable.Buffer[(LookupKey, String)] = mutable.Buffer()
+  private val all: mutable.ListBuffer[(LookupKey, String)] = 
mutable.ListBuffer()
   private val routerBuilders: mutable.Buffer[RouterBuilder] = mutable.Buffer()
   private val classDeDup: mutable.Set[String] = mutable.Set()
 
+  // Shuffle managers registered later take precedence during lookup,
+  // because each new registration is prepended to `all`.
   def register(lookupKey: LookupKey, shuffleManagerClass: String): Unit = {
     val clazz = Utils.classForName(shuffleManagerClass)
     require(
@@ -42,7 +45,7 @@ class ShuffleManagerRegistry private[ShuffleManagerRegistry] {
       s"Shuffle manager class already registered: $shuffleManagerClass")
     this.synchronized {
       classDeDup += shuffleManagerClass
-      all += lookupKey -> shuffleManagerClass
+      (lookupKey -> shuffleManagerClass) +=: all // Prepend: the newest entry is matched first.
       // Invalidate all shuffle managers cached in each alive router builder 
instances.
       // Then, once the router builder is accessed, a new router will be 
forced to create.
       routerBuilders.foreach(_.invalidateCache())
@@ -68,7 +71,18 @@ class ShuffleManagerRegistry private[ShuffleManagerRegistry] 
{
 }
 
 object ShuffleManagerRegistry {
-  private val instance = new ShuffleManagerRegistry()
+  private val instance = { // SortShuffleManager is registered first, so it has the lowest precedence.
+    val r = new ShuffleManagerRegistry()
+    r.register(
+      new LookupKey {
+        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): 
Boolean = {
+          dependency.getClass == classOf[ShuffleDependency[_, _, _]] // Exact class only; subclasses are not matched.
+        }
+      },
+      classOf[SortShuffleManager].getName
+    )
+    r
+  }
 
   def get(): ShuffleManagerRegistry = instance
 
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala
 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala
index 80aa9d8047..522cf01eee 100644
--- 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala
+++ 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala
@@ -16,12 +16,15 @@
  */
 package org.apache.spark.shuffle
 import org.apache.spark.{ShuffleDependency, TaskContext}
+import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.MergedBlockMeta
 import org.apache.spark.storage.{BlockId, ShuffleBlockBatchId, ShuffleBlockId, 
ShuffleMergedBlockId}
 
 /** The internal shuffle manager instance used by GlutenShuffleManager. */
-private class ShuffleManagerRouter(lookup: ShuffleManagerLookup) extends 
ShuffleManager {
+private class ShuffleManagerRouter(lookup: ShuffleManagerLookup)
+  extends ShuffleManager
+  with Logging {
   import ShuffleManagerRouter._
   private val cache = new Cache()
   private val resolver = new BlockResolver(cache)
@@ -38,6 +41,7 @@ private class ShuffleManagerRouter(lookup: 
ShuffleManagerLookup) extends Shuffle
       mapId: Long,
       context: TaskContext,
       metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+    ensureShuffleManagerRegistered(handle)
     cache.get(handle.shuffleId).getWriter(handle, mapId, context, metrics)
   }
 
@@ -49,6 +53,7 @@ private class ShuffleManagerRouter(lookup: 
ShuffleManagerLookup) extends Shuffle
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    ensureShuffleManagerRegistered(handle)
     cache
       .get(handle.shuffleId)
       .getReader(handle, startMapIndex, endMapIndex, startPartition, 
endPartition, context, metrics)
@@ -61,8 +66,29 @@ private class ShuffleManagerRouter(lookup: 
ShuffleManagerLookup) extends Shuffle
   override def shuffleBlockResolver: ShuffleBlockResolver = resolver
 
   override def stop(): Unit = {
-    assert(cache.size() == 0)
-    lookup.all().reverse.foreach(_.stop())
+    if (cache.size() != 0) { // Warn rather than fail: leftover entries mean a shuffle was never unregistered.
+      logWarning(
+        "Shuffle router cache is not empty when being stopped. This might be 
because the " +
+          "shuffle is not unregistered.")
+    }
+    lookup.all().foreach(_.stop())
+  }
+
+  private def ensureShuffleManagerRegistered(handle: ShuffleHandle): Unit = { // Binds a shuffle id to its target manager.
+    val baseShuffleHandle = handle match {
+      case b: BaseShuffleHandle[_, _, _] => b
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"${handle.getClass} is not a BaseShuffleHandle so is not supported 
by " +
+            "GlutenShuffleManager")
+    }
+    val shuffleId = baseShuffleHandle.shuffleId
+    // Resolve and cache the target manager only on first access to this shuffle id.
+    if (!cache.has(shuffleId)) {
+      val dependency = baseShuffleHandle.dependency
+      val manager = lookup.findShuffleManager(dependency)
+      cache.store(shuffleId, manager)
+    }
   }
 }
 
@@ -71,6 +97,10 @@ private object ShuffleManagerRouter {
     private val cache: java.util.Map[Int, ShuffleManager] =
       new java.util.concurrent.ConcurrentHashMap()
 
+    def has(shuffleId: Int): Boolean = {
+      cache.containsKey(shuffleId)
+    }
+
     def store(shuffleId: Int, manager: ShuffleManager): ShuffleManager = {
       cache.compute(
         shuffleId,
diff --git 
a/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala
 
b/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala
index 640fc0ab07..4ab81711ef 100644
--- 
a/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala
+++ 
b/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala
@@ -33,6 +33,11 @@ class GlutenShuffleManagerSuite extends SharedSparkSession {
       .set(SHUFFLE_MANAGER.key, classOf[GlutenShuffleManager].getName)
   }
 
+  override protected def beforeEach(): Unit = {
+    super.beforeEach() // Keep the parent fixture chain (SharedSparkSession) intact.
+    ShuffleManagerRegistry.get().clear() // Start every test from an empty registry.
+  }
+
   override protected def afterEach(): Unit = {
     val registry = ShuffleManagerRegistry.get()
     registry.clear()
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala 
b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 296153346e..20ebf4c7aa 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -129,18 +129,25 @@ class GlutenConfig(conf: SQLConf) extends Logging {
   def scanFileSchemeValidationEnabled: Boolean =
     conf.getConf(VELOX_SCAN_FILE_SCHEME_VALIDATION_ENABLED)
 
-  // whether to use ColumnarShuffleManager
+  // Whether to use GlutenShuffleManager (experimental); matched by its fully-qualified class name.
+  def isUseGlutenShuffleManager: Boolean =
+    conf
+      .getConfString("spark.shuffle.manager", "sort")
+      .equals("org.apache.spark.shuffle.GlutenShuffleManager")
+
+  // Whether to use ColumnarShuffleManager.
   def isUseColumnarShuffleManager: Boolean =
     conf
       .getConfString("spark.shuffle.manager", "sort")
       .equals("org.apache.spark.shuffle.sort.ColumnarShuffleManager")
 
-  // whether to use CelebornShuffleManager
+  // Whether to use CelebornShuffleManager.
   def isUseCelebornShuffleManager: Boolean =
     conf
       .getConfString("spark.shuffle.manager", "sort")
       .contains("celeborn")
 
+  // Whether to use UniffleShuffleManager.
   def isUseUniffleShuffleManager: Boolean =
     conf
       .getConfString("spark.shuffle.manager", "sort")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to