This is an automated email from the ASF dual-hosted git repository.

loneylee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 9c79bd664e [GLUTEN-9646][CH] Fix coalesce project union when has subquery (#9662)
9c79bd664e is described below

commit 9c79bd664e7c7590edebe3dade94da05ccb545cf
Author: Shuai li <[email protected]>
AuthorDate: Fri May 16 14:31:11 2025 +0800

    [GLUTEN-9646][CH] Fix coalesce project union when has subquery (#9662)
    
    * [GLUTEN-9646][CH] Fix coalesce project union when has subquery
    
    * fix ci
---
 .../extension/CoalesceAggregationUnion.scala       | 57 +++++++++++++---------
 .../GlutenCoalesceAggregationUnionSuite.scala      | 55 +++++++++++++++++++++
 2 files changed, 89 insertions(+), 23 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
index 331dfdc2db..d5faa48d07 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
@@ -17,7 +17,6 @@
 package org.apache.gluten.extension
 
 import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
-import org.apache.gluten.exception.GlutenNotSupportException
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
@@ -107,9 +106,9 @@ object CoalesceUnionUtil extends Logging {
 
   def buildAttributesMap(
       attributes: Seq[Attribute],
-      expressions: Seq[Expression]): Map[ExprId, Expression] = {
+      expressions: Seq[NamedExpression]): Map[ExprId, NamedExpression] = {
     assert(attributes.length == expressions.length)
-    val map = new mutable.HashMap[ExprId, Expression]()
+    val map = new mutable.HashMap[ExprId, NamedExpression]()
     attributes.zip(expressions).foreach {
       case (attr, expr) =>
         map.put(attr.exprId, expr)
@@ -117,19 +116,38 @@ object CoalesceUnionUtil extends Logging {
     map.toMap
   }
 
-  def replaceAttributes(e: Expression, replaceMap: Map[ExprId, Expression]): 
Expression = {
+  def replaceAttributes(e: Expression, replaceMap: Map[ExprId, 
NamedExpression]): Expression = {
     e match {
       case attr: Attribute =>
+        replaceMap.getOrElse(attr.exprId, attr)
+      case attr: OuterReference =>
         replaceMap.get(attr.exprId) match {
-          case Some(replaceAttr) => replaceAttr
-          case None =>
-            throw new GlutenNotSupportException(s"Not found attribute: $attr 
${attr.qualifiedName}")
+          case Some(replaceAttr) =>
+            OuterReference.apply(replaceAttr)
+          case _ => attr
         }
+      case subquery: ScalarSubquery =>
+        val plan = replaceSubqueryAttributes(subquery.plan, replaceMap)
+        val outerAttrs = subquery.outerAttrs.map(replaceAttributes(_, 
replaceMap))
+        subquery.copy(plan, outerAttrs)
       case _ =>
         e.withNewChildren(e.children.map(replaceAttributes(_, replaceMap)))
     }
   }
 
+  def replaceSubqueryAttributes(
+      plan: LogicalPlan,
+      replaceMap: Map[ExprId, NamedExpression]): LogicalPlan = {
+    plan match {
+      case filter: Filter =>
+        filter.copy(
+          replaceAttributes(filter.condition, replaceMap),
+          replaceSubqueryAttributes(filter.child, replaceMap))
+      case _ =>
+        plan.withNewChildren(plan.children.map(replaceSubqueryAttributes(_, 
replaceMap)))
+    }
+  }
+
   // Plans and expressions are the same.
   def areStrictMatchedRelation(leftRelation: LogicalPlan, rightRelation: 
LogicalPlan): Boolean = {
     (leftRelation, rightRelation) match {
@@ -396,9 +414,8 @@ case class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPla
 
           val newFilter = innerProject match {
             case Some(project) =>
-              val replaceMap = CoalesceUnionUtil.buildAttributesMap(
-                project.output,
-                project.projectList.map(_.asInstanceOf[Expression]))
+              val replaceMap =
+                CoalesceUnionUtil.buildAttributesMap(project.output, 
project.projectList)
               val newCondition = 
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
               Filter(newCondition, sourcePlan)
             case None => filter.withNewChildren(Seq(sourcePlan))
@@ -422,9 +439,8 @@ case class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPla
 
         val newAggregate = innerProject match {
           case Some(project) =>
-            val replaceMap = CoalesceUnionUtil.buildAttributesMap(
-              project.output,
-              project.projectList.map(_.asInstanceOf[Expression]))
+            val replaceMap =
+              CoalesceUnionUtil.buildAttributesMap(project.output, 
project.projectList)
             val newGroupExpressions = 
originalAggregate.groupingExpressions.map {
               e => CoalesceUnionUtil.replaceAttributes(e, replaceMap)
             }
@@ -688,7 +704,7 @@ case class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPla
   def collectClauseStructFields(
       analyzedPlan: AnalyzedPlan,
       clauseIndex: Int,
-      attributeReplaceMap: Map[ExprId, Expression]): Seq[Expression] = {
+      attributeReplaceMap: Map[ExprId, NamedExpression]): Seq[Expression] = {
 
     val planAnalyzer = analyzedPlan.planAnalyzer.get.asInstanceOf[PlanAnalyzer]
     val aggregate = 
planAnalyzer.constructedAggregatePlan.get.asInstanceOf[Aggregate]
@@ -876,14 +892,14 @@ case class CoalesceProjectionUnion(spark: SparkSession) extends Rule[LogicalPlan
               val replaceMap =
                 CoalesceUnionUtil.buildAttributesMap(
                   project.output,
-                  project.projectList.map(_.asInstanceOf[Expression]))
+                  project.projectList.map(_.asInstanceOf[NamedExpression]))
               val newCondition = 
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
               Some(Filter(newCondition, source))
             case subquery @ SubqueryAlias(_, project: Project) =>
               val replaceMap =
                 CoalesceUnionUtil.buildAttributesMap(
                   project.output,
-                  project.projectList.map(_.asInstanceOf[Expression]))
+                  project.projectList.map(_.asInstanceOf[NamedExpression]))
               val newCondition = 
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
               Some(Filter(newCondition, source))
             case _ => Some(filter)
@@ -898,11 +914,8 @@ case class CoalesceProjectionUnion(spark: SparkSession) extends Rule[LogicalPlan
           val originalFilter = extractFilter().get
           originalFilter.child match {
             case project: Project =>
-              None
               val replaceMap =
-                CoalesceUnionUtil.buildAttributesMap(
-                  project.output,
-                  project.projectList.map(_.asInstanceOf[Expression]))
+                CoalesceUnionUtil.buildAttributesMap(project.output, 
project.projectList)
               val originalProject = originalPlan.asInstanceOf[Project]
               val newProjectList =
                 originalProject.projectList
@@ -912,9 +925,7 @@ case class CoalesceProjectionUnion(spark: SparkSession) extends Rule[LogicalPlan
               Some(newProject)
             case subquery @ SubqueryAlias(_, project: Project) =>
               val replaceMap =
-                CoalesceUnionUtil.buildAttributesMap(
-                  project.output,
-                  project.projectList.map(_.asInstanceOf[Expression]))
+                CoalesceUnionUtil.buildAttributesMap(project.output, 
project.projectList)
               val originalProject = originalPlan.asInstanceOf[Project]
               val newProjectList =
                 originalProject.projectList
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
index e5374bebb4..6e73d12230 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
@@ -515,4 +515,59 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran
         |""".stripMargin
     compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true)
   }
+
+  test("GLUTEN-9646: fix coalesce project union when has subquery") {
+    val schema_fact = StructType(
+      Array(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", IntegerType, nullable = true)
+      ))
+
+    val schema_order = StructType(
+      Array(
+        StructField("c", IntegerType, nullable = true),
+        StructField("b", IntegerType, nullable = true)
+      ))
+
+    val data_fact = sparkContext.parallelize(
+      Seq(
+        Row(2, 1),
+        Row(3, 2),
+        Row(4, 3),
+        Row(5, 4)
+      ))
+
+    val data_order = sparkContext.parallelize(
+      Seq(
+        Row(1, 1),
+        Row(2, 2),
+        Row(3, 3),
+        Row(4, 4)
+      ))
+
+    val dataFrame1 = spark.createDataFrame(data_fact, schema_fact)
+    val dataFrame2 = spark.createDataFrame(data_order, schema_order)
+    createTestTable("fact", dataFrame1)
+    createTestTable("order", dataFrame2)
+
+    val sql =
+      """
+        |SELECT a
+        |FROM fact
+        |WHERE a =
+        |    (SELECT sum(c) + 2
+        |     FROM order
+        |     WHERE order.b = fact.b
+        |     GROUP BY order.b)
+        |UNION ALL
+        |SELECT a
+        |FROM fact
+        |WHERE a =
+        |    (SELECT sum(c) + 1
+        |     FROM order
+        |     WHERE order.b = fact.b
+        |     GROUP BY order.b)
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, compareResult = true, checkNoUnion)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to