This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 0702eb8f7 [GLUTEN-5016][CH] fix simple aggregation sql exchange fallback (#5042)
0702eb8f7 is described below

commit 0702eb8f73051407eb0dd9da495000360f459edd
Author: Wenzheng Liu <[email protected]>
AuthorDate: Fri Mar 22 13:03:43 2024 +0800

    [GLUTEN-5016][CH] fix simple aggregation sql exchange fallback (#5042)
---
 .../GlutenClickHouseTPCHNullableSuite.scala        |  15 ++
 .../extension/columnar/MiscColumnarRules.scala     | 263 ++++++++++++---------
 2 files changed, 163 insertions(+), 115 deletions(-)

diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala
index 72a04f02c..42783babd 100644
--- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala
+++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala
@@ -195,4 +195,19 @@ class GlutenClickHouseTPCHNullableSuite extends GlutenClickHouseTPCHAbstractSuit
                           |""".stripMargin) { _ => }
     assert(result(0).getLong(0) == 227302L)
   }
+
+  test("test 'GLUTEN-5016'") {
+    withSQLConf(("spark.gluten.sql.columnar.preferColumnar", "false")) {
+      val sql =
+        """
+          |SELECT
+          |   sum(l_quantity) AS sum_qty
+          |FROM
+          |   lineitem
+          |WHERE
+          |   l_shipdate <= date'1998-09-02'
+          |""".stripMargin
+      runSql(sql, noFallBack = true) { _ => }
+    }
+  }
 }
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
index 4fbadb0b5..d21956bfc 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
@@ -49,9 +49,12 @@ object MiscColumnarRules {
 
     // Aggregation transformation.
     private case class AggregationTransformRule() extends Rule[SparkPlan] with 
LogLevelUtil {
-      override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+      override def apply(plan: SparkPlan): SparkPlan = plan match {
+        case plan if TransformHints.isNotTransformable(plan) =>
+          plan
         case agg: HashAggregateExec =>
           genHashAggregateExec(agg)
+        case other => other
       }
 
       /**
@@ -105,13 +108,144 @@ object MiscColumnarRules {
       }
     }
 
+    // Exchange transformation.
+    private case class ExchangeTransformRule() extends Rule[SparkPlan] with 
LogLevelUtil {
+      override def apply(plan: SparkPlan): SparkPlan = plan match {
+        case plan if TransformHints.isNotTransformable(plan) =>
+          plan
+        case plan: ShuffleExchangeExec =>
+          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+          val child = plan.child
+          if (
+            (child.supportsColumnar || 
GlutenConfig.getConf.enablePreferColumnar) &&
+            BackendsApiManager.getSettings.supportColumnarShuffleExec()
+          ) {
+            
BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, 
child)
+          } else {
+            plan.withNewChildren(Seq(child))
+          }
+        case plan: BroadcastExchangeExec =>
+          val child = plan.child
+          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+          ColumnarBroadcastExchangeExec(plan.mode, child)
+        case other => other
+      }
+    }
+
+    // Join transformation.
+    private case class JoinTransformRule() extends Rule[SparkPlan] with 
LogLevelUtil {
+
+      /**
+       * Get the build side supported by the execution of vanilla Spark.
+       *
+       * @param plan
+       *   : shuffled hash join plan
+       * @return
+       *   the supported build side
+       */
+      private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): 
BuildSide = {
+        plan.joinType match {
+          case LeftOuter | LeftSemi => BuildRight
+          case RightOuter => BuildLeft
+          case _ => plan.buildSide
+        }
+      }
+
+      override def apply(plan: SparkPlan): SparkPlan = {
+        if (TransformHints.isNotTransformable(plan)) {
+          logDebug(s"Columnar Processing for ${plan.getClass} is under row 
guard.")
+          plan match {
+            case shj: ShuffledHashJoinExec =>
+              if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) 
{
+                // Because we manually removed the build side limitation for 
LeftOuter, LeftSemi and
+                // RightOuter, need to change the build side back if this join 
fallback into vanilla
+                // Spark for execution.
+                return ShuffledHashJoinExec(
+                  shj.leftKeys,
+                  shj.rightKeys,
+                  shj.joinType,
+                  getSparkSupportedBuildSide(shj),
+                  shj.condition,
+                  shj.left,
+                  shj.right,
+                  shj.isSkewJoin
+                )
+              } else {
+                return shj
+              }
+            case p =>
+              return p
+          }
+        }
+        plan match {
+          case plan: ShuffledHashJoinExec =>
+            val left = plan.left
+            val right = plan.right
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            BackendsApiManager.getSparkPlanExecApiInstance
+              .genShuffledHashJoinExecTransformer(
+                plan.leftKeys,
+                plan.rightKeys,
+                plan.joinType,
+                plan.buildSide,
+                plan.condition,
+                left,
+                right,
+                plan.isSkewJoin)
+          case plan: SortMergeJoinExec =>
+            val left = plan.left
+            val right = plan.right
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            SortMergeJoinExecTransformer(
+              plan.leftKeys,
+              plan.rightKeys,
+              plan.joinType,
+              plan.condition,
+              left,
+              right,
+              plan.isSkewJoin)
+          case plan: BroadcastHashJoinExec =>
+            val left = plan.left
+            val right = plan.right
+            BackendsApiManager.getSparkPlanExecApiInstance
+              .genBroadcastHashJoinExecTransformer(
+                plan.leftKeys,
+                plan.rightKeys,
+                plan.joinType,
+                plan.buildSide,
+                plan.condition,
+                left,
+                right,
+                isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
+          case plan: CartesianProductExec =>
+            val left = plan.left
+            val right = plan.right
+            BackendsApiManager.getSparkPlanExecApiInstance
+              .genCartesianProductExecTransformer(left, right, plan.condition)
+          case plan: BroadcastNestedLoopJoinExec =>
+            val left = plan.left
+            val right = plan.right
+            BackendsApiManager.getSparkPlanExecApiInstance
+              .genBroadcastNestedLoopJoinExecTransformer(
+                left,
+                right,
+                plan.buildSide,
+                plan.joinType,
+                plan.condition)
+          case other => other
+        }
+      }
+
+    }
+
     // Filter transformation.
     private case class FilterTransformRule() extends Rule[SparkPlan] with 
LogLevelUtil {
       private val replace = new ReplaceSingleNode()
 
-      override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
+      override def apply(plan: SparkPlan): SparkPlan = plan match {
         case filter: FilterExec =>
           genFilterExec(filter)
+        case other => other
       }
 
       /**
@@ -155,39 +289,18 @@ object MiscColumnarRules {
     private case class RegularTransformRule() extends Rule[SparkPlan] with 
LogLevelUtil {
       private val replace = new ReplaceSingleNode()
 
-      override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-        case plan => replace.replaceWithTransformerPlan(plan)
-      }
+      override def apply(plan: SparkPlan): SparkPlan = 
replace.replaceWithTransformerPlan(plan)
     }
 
     // Utility to replace single node within transformed Gluten node.
     // Children will be preserved as they are as children of the output node.
     class ReplaceSingleNode() extends LogLevelUtil with Logging {
-      private val columnarConf: GlutenConfig = GlutenConfig.getConf
 
       def replaceWithTransformerPlan(p: SparkPlan): SparkPlan = {
         val plan = p
         if (TransformHints.isNotTransformable(plan)) {
           logDebug(s"Columnar Processing for ${plan.getClass} is under row 
guard.")
           plan match {
-            case shj: ShuffledHashJoinExec =>
-              if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) 
{
-                // Because we manually removed the build side limitation for 
LeftOuter, LeftSemi and
-                // RightOuter, need to change the build side back if this join 
fallback into vanilla
-                // Spark for execution.
-                return ShuffledHashJoinExec(
-                  shj.leftKeys,
-                  shj.rightKeys,
-                  shj.joinType,
-                  getSparkSupportedBuildSide(shj),
-                  shj.condition,
-                  shj.left,
-                  shj.right,
-                  shj.isSkewJoin
-                )
-              } else {
-                return shj
-              }
             case plan: BatchScanExec =>
               return applyScanNotTransformable(plan)
             case plan: FileSourceScanExec =>
@@ -283,75 +396,6 @@ object MiscColumnarRules {
               plan.projectList,
               child,
               offset)
-          case plan: ShuffleExchangeExec =>
-            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-            val child = plan.child
-            if (
-              (child.supportsColumnar || columnarConf.enablePreferColumnar) &&
-              BackendsApiManager.getSettings.supportColumnarShuffleExec()
-            ) {
-              
BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, 
child)
-            } else {
-              plan.withNewChildren(Seq(child))
-            }
-          case plan: ShuffledHashJoinExec =>
-            val left = plan.left
-            val right = plan.right
-            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-            BackendsApiManager.getSparkPlanExecApiInstance
-              .genShuffledHashJoinExecTransformer(
-                plan.leftKeys,
-                plan.rightKeys,
-                plan.joinType,
-                plan.buildSide,
-                plan.condition,
-                left,
-                right,
-                plan.isSkewJoin)
-          case plan: SortMergeJoinExec =>
-            val left = plan.left
-            val right = plan.right
-            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-            SortMergeJoinExecTransformer(
-              plan.leftKeys,
-              plan.rightKeys,
-              plan.joinType,
-              plan.condition,
-              left,
-              right,
-              plan.isSkewJoin)
-          case plan: BroadcastExchangeExec =>
-            val child = plan.child
-            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-            ColumnarBroadcastExchangeExec(plan.mode, child)
-          case plan: BroadcastHashJoinExec =>
-            val left = plan.left
-            val right = plan.right
-            BackendsApiManager.getSparkPlanExecApiInstance
-              .genBroadcastHashJoinExecTransformer(
-                plan.leftKeys,
-                plan.rightKeys,
-                plan.joinType,
-                plan.buildSide,
-                plan.condition,
-                left,
-                right,
-                isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
-          case plan: CartesianProductExec =>
-            val left = plan.left
-            val right = plan.right
-            BackendsApiManager.getSparkPlanExecApiInstance
-              .genCartesianProductExecTransformer(left, right, plan.condition)
-          case plan: BroadcastNestedLoopJoinExec =>
-            val left = plan.left
-            val right = plan.right
-            BackendsApiManager.getSparkPlanExecApiInstance
-              .genBroadcastNestedLoopJoinExecTransformer(
-                left,
-                right,
-                plan.buildSide,
-                plan.joinType,
-                plan.condition)
           case plan: WindowExec =>
             WindowExecTransformer(
               plan.windowExpression,
@@ -389,22 +433,6 @@ object MiscColumnarRules {
         }
       }
 
-      /**
-       * Get the build side supported by the execution of vanilla Spark.
-       *
-       * @param plan
-       *   : shuffled hash join plan
-       * @return
-       *   the supported build side
-       */
-      private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): 
BuildSide = {
-        plan.joinType match {
-          case LeftOuter | LeftSemi => BuildRight
-          case RightOuter => BuildLeft
-          case _ => plan.buildSide
-        }
-      }
-
       private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan 
match {
         case plan: FileSourceScanExec =>
           val newPartitionFilters =
@@ -489,18 +517,23 @@ object MiscColumnarRules {
   case class TransformPreOverrides() extends Rule[SparkPlan] with LogLevelUtil 
{
     import TransformPreOverrides._
 
-    private val subRules = List(
-      FilterTransformRule(),
+    private val topdownRules = List(
+      FilterTransformRule()
+    )
+    private val bottomupRules = List(
       RegularTransformRule(),
-      AggregationTransformRule()
+      AggregationTransformRule(),
+      ExchangeTransformRule(),
+      JoinTransformRule()
     )
 
     @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
 
     def apply(plan: SparkPlan): SparkPlan = {
-      val newPlan = subRules.foldLeft(plan)((p, rule) => rule.apply(p))
-      planChangeLogger.logRule(ruleName, plan, newPlan)
-      newPlan
+      val plan0 = topdownRules.foldLeft(plan)((p, rule) => p.transformDown { 
case p => rule(p) })
+      val plan1 = bottomupRules.foldLeft(plan0)((p, rule) => p.transformUp { 
case p => rule(p) })
+      planChangeLogger.logRule(ruleName, plan, plan1)
+      plan1
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to