This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new afee97a657 [VL] Add some fixes following #8355 (#8373)
afee97a657 is described below

commit afee97a65769b1fb2e8ce8ae7cd19095a20c6fee
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Dec 31 10:42:54 2024 +0800

    [VL] Add some fixes following #8355 (#8373)
---
 .../apache/gluten/execution/VeloxTPCHSuite.scala   | 61 ++++++++++++++--------
 .../execution/ColumnarShuffleExchangeExec.scala    | 21 +-------
 .../org/apache/gluten/config/GlutenConfig.scala    |  2 +-
 3 files changed, 41 insertions(+), 43 deletions(-)

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
index e27c2aa23e..bb40b117a8 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.config.GlutenConfig
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, Row, TestUtils}
-import org.apache.spark.sql.execution.FormattedMode
+import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, FormattedMode}
 
 import org.apache.commons.io.FileUtils
 
@@ -117,133 +117,133 @@ abstract class VeloxTPCHSuite extends VeloxTPCHTableSupport {
   }
 
   test("TPC-H q1") {
-    runTPCHQuery(1, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(1, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 1)
     }
   }
 
   test("TPC-H q2") {
-    runTPCHQuery(2, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(2, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       _ => // due to tpc-h q2 will generate multiple plans, skip checking golden file for now
     }
   }
 
   test("TPC-H q3") {
-    runTPCHQuery(3, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(3, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 3)
     }
   }
 
   test("TPC-H q4") {
-    runTPCHQuery(4, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(4, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 4)
     }
   }
 
   test("TPC-H q5") {
-    runTPCHQuery(5, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(5, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 5)
     }
   }
 
   test("TPC-H q6") {
-    runTPCHQuery(6, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(6, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 6)
     }
   }
 
   test("TPC-H q7") {
-    runTPCHQuery(7, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(7, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 7)
     }
   }
 
   test("TPC-H q8") {
-    runTPCHQuery(8, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(8, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 8)
     }
   }
 
   test("TPC-H q9") {
-    runTPCHQuery(9, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(9, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 9)
     }
   }
 
   test("TPC-H q10") {
-    runTPCHQuery(10, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(10, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 10)
     }
   }
 
   test("TPC-H q11") {
-    runTPCHQuery(11, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(11, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 11)
     }
   }
 
   test("TPC-H q12") {
-    runTPCHQuery(12, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(12, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 12)
     }
   }
 
   test("TPC-H q13") {
-    runTPCHQuery(13, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(13, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 13)
     }
   }
 
   test("TPC-H q14") {
-    runTPCHQuery(14, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(14, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 14)
     }
   }
 
   test("TPC-H q15") {
-    runTPCHQuery(15, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(15, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 15)
     }
   }
 
   test("TPC-H q16") {
-    runTPCHQuery(16, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(16, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 16)
     }
   }
 
   test("TPC-H q17") {
-    runTPCHQuery(17, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(17, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 17)
     }
   }
 
   test("TPC-H q18") {
-    runTPCHQuery(18, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(18, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 18)
     }
   }
 
   test("TPC-H q19") {
-    runTPCHQuery(19, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(19, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 19)
     }
   }
 
   test("TPC-H q20") {
-    runTPCHQuery(20, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(20, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 20)
     }
   }
 
   test("TPC-H q21") {
-    runTPCHQuery(21, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(21, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 21)
     }
   }
 
   test("TPC-H q22") {
-    runTPCHQuery(22, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
+    runTPCHQuery(22, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
       checkGoldenFile(_, 22)
     }
   }
@@ -304,6 +304,21 @@ class VeloxTPCHV1GlutenShuffleManagerSuite extends VeloxTPCHSuite {
       .set("spark.sql.autoBroadcastJoinThreshold", "-1")
       .set("spark.shuffle.manager", "org.apache.spark.shuffle.GlutenShuffleManager")
   }
+
+  override protected def runQueryAndCompare(
+      sqlStr: String,
+      compareResult: Boolean,
+      noFallBack: Boolean,
+      cache: Boolean)(customCheck: DataFrame => Unit): DataFrame = {
+    assert(noFallBack)
+    super.runQueryAndCompare(sqlStr, compareResult, noFallBack, cache) {
+      df =>
+        assert(df.queryExecution.executedPlan.collect {
+          case p if p.isInstanceOf[ColumnarShuffleExchangeExec] => p
+        }.nonEmpty)
+        customCheck(df)
+    }
+  }
 }
 
 class VeloxTPCHV1BhjSuite extends VeloxTPCHSuite {
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
index 70cc764754..66ca97277e 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
@@ -27,7 +27,6 @@ import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.ShuffleHandle
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
@@ -93,22 +92,6 @@ case class ColumnarShuffleExchangeExec(
       useSortBasedShuffle)
   }
 
-  // 'shuffleDependency' is only needed when enable AQE.
-  // Columnar shuffle will use 'columnarShuffleDependency'
-  @transient
-  lazy val shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow] =
-    new ShuffleDependency[Int, InternalRow, InternalRow](
-      _rdd = new ColumnarShuffleExchangeExec.DummyPairRDDWithPartitions(
-        sparkContext,
-        inputColumnarRDD.getNumPartitions),
-      partitioner = columnarShuffleDependency.partitioner
-    ) {
-
-      override val shuffleId: Int = columnarShuffleDependency.shuffleId
-
-      override val shuffleHandle: ShuffleHandle = columnarShuffleDependency.shuffleHandle
-    }
-
   // super.stringArgs ++ Iterator(output.map(o => s"${o}#${o.dataType.simpleString}"))
   val serializer: Serializer = BackendsApiManager.getSparkPlanExecApiInstance
     .createColumnarBatchSerializer(schema, metrics, useSortBasedShuffle)
@@ -128,9 +111,9 @@ case class ColumnarShuffleExchangeExec(
 
   override def nodeName: String = "ColumnarExchange"
 
-  override def numMappers: Int = shuffleDependency.rdd.getNumPartitions
+  override def numMappers: Int = inputColumnarRDD.getNumPartitions
 
-  override def numPartitions: Int = shuffleDependency.partitioner.numPartitions
+  override def numPartitions: Int = columnarShuffleDependency.partitioner.numPartitions
 
   override def runtimeStatistics: Statistics = {
     val dataSize = metrics("dataSize").value
diff --git a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
index ce752b8fb9..98f8e89bf5 100644
--- a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
@@ -127,7 +127,7 @@ class GlutenConfig(conf: SQLConf) extends Logging {
   def isUseGlutenShuffleManager: Boolean =
     conf
       .getConfString("spark.shuffle.manager", "sort")
-      .equals("org.apache.spark.shuffle.sort.GlutenShuffleManager")
+      .equals("org.apache.spark.shuffle.GlutenShuffleManager")
 
   // Whether to use ColumnarShuffleManager.
   def isUseColumnarShuffleManager: Boolean =


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to