This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 81a641f  fix: CometExecRule should handle ShuffleQueryStage and ReusedExchange (#186)
81a641f is described below

commit 81a641f30844d76b417a600392a0cc97a74a919b
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Mar 13 11:14:45 2024 -0700

    fix: CometExecRule should handle ShuffleQueryStage and ReusedExchange (#186)
    
    * fix: CometExecRule should handle ShuffleQueryStage and ReusedExchange
    
    * fix
    
    * Add comment and move tests
    
    * Remove unused table in test.
---
 .../apache/comet/CometSparkSessionExtensions.scala | 44 ++++++++++++++++----
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  6 ++-
 .../comet/exec/CometColumnarShuffleSuite.scala     | 48 +++++++++++++++++++++-
 .../comet/exec/CometNativeShuffleSuite.scala       |  5 ++-
 4 files changed, 92 insertions(+), 11 deletions(-)

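For readers less familiar with AQE internals, the plan shapes this commit teaches CometExecRule to recognize can be illustrated with a small sketch. The helper below is hypothetical (not part of the commit) and simply mirrors the two pattern matches added to QueryPlanSerde, assuming Spark 3.4-style extractors for ShuffleQueryStageExec and ReusedExchangeExec:

    import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
    import org.apache.spark.sql.execution.exchange.ReusedExchangeExec

    // Hypothetical helper, not part of the commit: returns true when an AQE shuffle
    // stage sits on top of a Comet shuffle exchange, directly or through exchange reuse.
    def isCometShuffleStage(plan: SparkPlan): Boolean = plan match {
      // AQE stage created directly from a Comet shuffle exchange
      case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
      // AQE stage whose exchange was deduplicated into a ReusedExchangeExec by AQE
      case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
      case _ => false
    }

Under AQE, each materialized shuffle is wrapped in a ShuffleQueryStageExec, and a deduplicated exchange shows up as a ReusedExchangeExec inside such a stage; both shapes must be matched so Comet native operators can continue past the stage boundary.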
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 5720b69..39c83ae 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -31,12 +31,13 @@ import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -221,12 +222,16 @@ class CometSparkSessionExtensions
      */
     // spotless:on
     private def transform(plan: SparkPlan): SparkPlan = {
-      def transform1(op: UnaryExecNode): Option[Operator] = {
-        op.child match {
-          case childNativeOp: CometNativeExec =>
-            QueryPlanSerde.operator2Proto(op, childNativeOp.nativeOp)
-          case _ =>
-            None
+      def transform1(op: SparkPlan): Option[Operator] = {
+        val allNativeExec = op.children.map {
+          case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp)
+          case _ => None
+        }
+
+        if (allNativeExec.forall(_.isDefined)) {
+          QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*)
+        } else {
+          None
         }
       }
 
@@ -377,6 +382,31 @@ class CometSparkSessionExtensions
             case None => b
           }
 
+        // For AQE shuffle stage on a Comet shuffle exchange
+        case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
+          val newOp = transform1(s)
+          newOp match {
+            case Some(nativeOp) =>
+              CometSinkPlaceHolder(nativeOp, s, s)
+            case None =>
+              s
+          }
+
+        // For AQE shuffle stage on a reused Comet shuffle exchange
+        // Note that we don't need to handle `ReusedExchangeExec` for non-AQE case, because
+        // the query plan won't be re-optimized/planned in non-AQE mode.
+        case s @ ShuffleQueryStageExec(
+              _,
+              ReusedExchangeExec(_, _: CometShuffleExchangeExec),
+              _) =>
+          val newOp = transform1(s)
+          newOp match {
+            case Some(nativeOp) =>
+              CometSinkPlaceHolder(nativeOp, s, s)
+            case None =>
+              s
+          }
+
         // Native shuffle for Comet operators
         case s: ShuffleExchangeExec
             if isCometShuffleEnabled(conf) &&
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 902f703..5da926e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,10 +29,12 @@ import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
 import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -1883,6 +1885,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       case _: CollectLimitExec => true
       case _: UnionExec => true
       case _: ShuffleExchangeExec => true
+      case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
+      case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
       case _: TakeOrderedAndProjectExec => true
       case _: BroadcastExchangeExec => true
       case _ => false
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index 1a92f71..216b690 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.{Partitioner, SparkConf}
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.functions.col
@@ -933,6 +933,52 @@ class CometShuffleSuite extends CometColumnarShuffleSuite {
   override protected val asyncShuffleEnable: Boolean = false
 
   protected val adaptiveExecutionEnabled: Boolean = true
+
+  import testImplicits._
+
+  test("Comet native operator after ShuffleQueryStage") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      CometConf.COMET_EXEC_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        val df = sql("SELECT * FROM tbl_a")
+        val shuffled = df
+          .select($"_1" + 1 as ("a"))
+          .filter($"a" > 4)
+          .repartition(10)
+          .sortWithinPartitions($"a")
+        checkSparkAnswerAndOperator(shuffled, classOf[ShuffleQueryStageExec])
+      }
+    }
+  }
+
+  test("Comet native operator after ShuffleQueryStage + ReusedExchange") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      CometConf.COMET_EXEC_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          val df = sql("SELECT * FROM tbl_a")
+          val left = df
+            .select($"_1" + 1 as ("a"))
+            .filter($"a" > 4)
+          val right = left.select($"a" as ("b"))
+          val join = left.join(right, $"a" === $"b")
+          checkSparkAnswerAndOperator(
+            join,
+            classOf[ShuffleQueryStageExec],
+            classOf[SortMergeJoinExec],
+            classOf[AQEShuffleReadExec])
+        }
+      }
+    }
+  }
 }
 
 class DisableAQECometShuffleSuite extends CometColumnarShuffleSuite {
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
index c35763c..59e27fd 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
@@ -64,8 +64,9 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper
           val path = new Path(dir.toURI.toString, "test.parquet")
           makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 1000)
           var allTypes: Seq[Int] = (1 to 20)
-          if (isSpark34Plus) {
-            allTypes = allTypes.filterNot(Set(14, 17).contains)
+          if (!isSpark34Plus) {
+            // TODO: Remove this once https://github.com/apache/arrow/issues/40038 is fixed
+            allTypes = allTypes.filterNot(Set(14).contains)
           }
           allTypes.map(i => s"_$i").foreach { c =>
             withSQLConf(
