This is an automated email from the ASF dual-hosted git repository.

zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new a88d5d5b5f [GLUTEN-8784][CH] Coalesce union of multiple scan-projects 
(#8785)
a88d5d5b5f is described below

commit a88d5d5b5fdf7bdad7c1f5b8ea42d41fd1fbbf1a
Author: lgbo <[email protected]>
AuthorDate: Wed Feb 26 16:39:13 2025 +0800

    [GLUTEN-8784][CH] Coalesce union of multiple scan-projects (#8785)
    
    [CH] Coalesce union of multiple scan-projects
---
 .../gluten/backendsapi/clickhouse/CHBackend.scala  |    2 +
 .../gluten/backendsapi/clickhouse/CHRuleApi.scala  |    1 +
 .../extension/CoalesceAggregationUnion.scala       | 1428 ++++++++++++--------
 .../GlutenCoalesceAggregationUnionSuite.scala      |  103 ++
 .../apache/spark/sql/GlutenCTEInlineSuite.scala    |   16 +-
 .../sql/GlutenDataFrameSetOperationsSuite.scala    |   10 +-
 .../utils/clickhouse/ClickHouseTestSettings.scala  |    2 +
 .../apache/spark/sql/GlutenCTEInlineSuite.scala    |   16 +-
 .../sql/GlutenDataFrameSetOperationsSuite.scala    |   10 +-
 9 files changed, 986 insertions(+), 602 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 73e8bce3fe..c0380452de 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -158,6 +158,8 @@ object CHBackendSettings extends BackendSettingsApi with 
Logging {
 
   val GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION: String =
     CHConfig.prefixOf("enable.coalesce.aggregation.union")
+  val GLUTEN_ENABLE_COALESCE_PROJECT_UNION: String =
+    CHConfig.prefixOf("enable.coalesce.project.union")
 
   def affinityMode: String = {
     SparkEnv.get.conf
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index ecd7e5a241..4573563cc9 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -61,6 +61,7 @@ object CHRuleApi {
     injector.injectParser(
       (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, 
parserInterface))
     injector.injectResolutionRule(spark => new CoalesceAggregationUnion(spark))
+    injector.injectResolutionRule(spark => new CoalesceProjectionUnion(spark))
     injector.injectResolutionRule(spark => new 
RewriteToDateExpresstionRule(spark))
     injector.injectResolutionRule(spark => new 
RewriteDateTimestampComparisonRule(spark))
     injector.injectResolutionRule(spark => new 
CollapseGetJsonObjectExpressionRule(spark))
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
index 4a83830e51..cfae1deeaa 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
@@ -35,63 +35,73 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Success, Try}
 
-/*
- * Example:
- * Rewrite query
- *  SELECT a, b, sum(c) FROM t WHERE d = 1 GROUP BY a,b
- *  UNION ALL
- *  SELECT a, b, sum(c) FROM t WHERE d = 2 GROUP BY a,b
- * into
- *  SELECT a, b, sum(c) FROM (
- *   SELECT s.a as a, s.b as b, s.c as c, s.id as group_id FROM (
- *    SELECT explode(s) as s FROM (
- *      SELECT array(
- *        if(d = 1, named_struct('a', a, 'b', b, 'c', c, 'id', 0), null),
- *        if(d = 2, named_struct('a', a, 'b', b, 'c', c, 'id', 1), null)) as s
- *      FROM t WHERE d = 1 OR d = 2
- *    )
- *   ) WHERE s is not null
- *  ) GROUP BY a,b, group_id
+/** Abstract class representing a plan analyzer. */
+abstract class AbstractPlanAnalyzer() {
+
+  /** Extract the common source subplan from the plan. */
+  def getExtractedSourcePlan(): Option[LogicalPlan] = None
+
+  /**
+   * Construct a plan with filter. The filter condition may be rewritten, 
different from the
+   * original filter condition.
+   */
+  def getConstructedFilterPlan(): Option[LogicalPlan] = None
+
+  /**
+   * Construct a plan with an aggregate. The aggregate expressions may be 
rewritten, different from
+   * the original aggregate.
+   */
+  def getConstructedAggregatePlan(): Option[LogicalPlan] = None
+
+  /** Construct a plan with project. */
+  def getConstructedProjectPlan: Option[LogicalPlan] = None
+
+  /** Returns true if the rule can be applied to the plan, false otherwise. */
+  def doValidate(): Boolean = false
+}
+
+/**
+ * Case class representing an analyzed logical plan.
  *
- * The first query need to scan `t` multiply, when the output of scan is 
large, the query is
- * really slow. The rewritten query only scan `t` once, and the performance is 
much better.
+ * @param plan
+ *   The original plan which is analyzed.
+ * @param planAnalyzer
+ *   An optional plan analyzer that provides additional analysis capabilities. 
When the rule cannot
+ *   apply to the plan, the planAnalyzer is None.
  */
+case class AnalyzedPlan(plan: LogicalPlan, planAnalyzer: 
Option[AbstractPlanAnalyzer])
 
-class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] 
with Logging {
-  def removeAlias(e: Expression): Expression = {
-    e match {
-      case alias: Alias => alias.child
-      case _ => e
+object CoalesceUnionUtil extends Logging {
+
+  def isResolvedPlan(plan: LogicalPlan): Boolean = {
+    plan match {
+      case isnert: InsertIntoStatement => isnert.query.resolved
+      case _ => plan.resolved
     }
   }
 
-  def hasAggregateExpression(e: Expression): Boolean = {
-    if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) {
-      return false
-    }
-    e match {
-      case _: AggregateExpression => true
-      case _ => e.children.exists(hasAggregateExpression(_))
+  // unfold nested unions
+  def collectAllUnionClauses(union: Union): ArrayBuffer[LogicalPlan] = {
+    val unionClauses = ArrayBuffer[LogicalPlan]()
+    union.children.foreach {
+      case u: Union =>
+        unionClauses ++= collectAllUnionClauses(u)
+      case other =>
+        unionClauses += other
     }
+    unionClauses
   }
 
-  def isAggregateExpression(e: Expression): Boolean = {
-    e match {
-      case cast: Cast => isAggregateExpression(cast.child)
-      case alias: Alias => isAggregateExpression(alias.child)
-      case agg: AggregateExpression => true
-      case _ => false
-    }
+  def isRelation(plan: LogicalPlan): Boolean = {
+    plan.isInstanceOf[MultiInstanceRelation]
   }
 
-  def hasAggregateExpressionsWithFilter(e: Expression): Boolean = {
-    if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) {
-      return false
-    }
-    e match {
-      case aggExpr: AggregateExpression =>
-        aggExpr.filter.isDefined
-      case _ => e.children.exists(hasAggregateExpressionsWithFilter(_))
+  def validateSource(plan: LogicalPlan): Boolean = {
+    plan match {
+      case relation if isRelation(relation) => true
+      case _: Project | _: Filter | _: SubqueryAlias =>
+        plan.children.forall(validateSource)
+      case _ => false
     }
   }
 
@@ -120,7 +130,216 @@ class CoalesceAggregationUnion(spark: SparkSession) 
extends Rule[LogicalPlan] wi
     }
   }
 
-  case class AggregateAnalzyInfo(originalAggregate: Aggregate) {
+  // Strict match: both plans read the same relation through identical Project/Filter expressions.
+  def areStrictMatchedRelation(leftRelation: LogicalPlan, rightRelation: 
LogicalPlan): Boolean = {
+    (leftRelation, rightRelation) match {
+      case (leftLogicalRelation: LogicalRelation, rightLogicalRelation: 
LogicalRelation) =>
+        val leftTable =
+          
leftLogicalRelation.catalogTable.map(_.identifier.unquotedString).getOrElse("")
+        val rightTable =
+          
rightLogicalRelation.catalogTable.map(_.identifier.unquotedString).getOrElse("")
+        leftLogicalRelation.output.length == 
rightLogicalRelation.output.length &&
+        leftLogicalRelation.output.zip(rightLogicalRelation.output).forall {
+          case (leftAttr, rightAttr) =>
+            leftAttr.dataType.equals(rightAttr.dataType) && 
leftAttr.name.equals(rightAttr.name)
+        } &&
+        leftTable.equals(rightTable) && leftTable.nonEmpty
+      case (leftCTE: CTERelationRef, rightCTE: CTERelationRef) =>
+        leftCTE.cteId == rightCTE.cteId
+      case (leftHiveTable: HiveTableRelation, rightHiveTable: 
HiveTableRelation) =>
+        leftHiveTable.tableMeta.identifier.unquotedString
+          .equals(rightHiveTable.tableMeta.identifier.unquotedString)
+      case (leftSubquery: SubqueryAlias, rightSubquery: SubqueryAlias) =>
+        areStrictMatchedRelation(leftSubquery.child, rightSubquery.child)
+      case (leftProject: Project, rightProject: Project) =>
+        leftProject.projectList.length == rightProject.projectList.length &&
+        leftProject.projectList.zip(rightProject.projectList).forall {
+          case (leftExpr, rightExpr) =>
+            areMatchedExpression(leftExpr, rightExpr)
+        } &&
+        areStrictMatchedRelation(leftProject.child, rightProject.child)
+      case (leftFilter: Filter, rightFilter: Filter) =>
+        areMatchedExpression(leftFilter.condition, rightFilter.condition) &&
+        areStrictMatchedRelation(leftFilter.child, rightFilter.child)
+      case (_, _) => false
+    }
+  }
+
+  // Two projects match when their outputs have the same schema (names and data types); their 
projectLists need not be identical.
+  def areOutputMatchedProject(leftPlan: LogicalPlan, rightPlan: LogicalPlan): 
Boolean = {
+    (leftPlan, rightPlan) match {
+      case (leftProject: Project, rightProject: Project) =>
+        val leftOutput = leftProject.output
+        val rightOutput = rightProject.output
+        leftOutput.length == rightOutput.length &&
+        leftOutput.zip(rightOutput).forall {
+          case (leftAttr, rightAttr) =>
+            leftAttr.dataType.equals(rightAttr.dataType) && 
leftAttr.name.equals(rightAttr.name)
+        }
+      case (_, _) =>
+        false
+    }
+  }
+
+  def areMatchedExpression(leftExpression: Expression, rightExpression: 
Expression): Boolean = {
+    (leftExpression, rightExpression) match {
+      case (leftLiteral: Literal, rightLiteral: Literal) =>
+        leftLiteral.dataType.equals(rightLiteral.dataType) &&
+        leftLiteral.value == rightLiteral.value
+      case (leftAttr: Attribute, rightAttr: Attribute) =>
+        leftAttr.dataType.equals(rightAttr.dataType) && 
leftAttr.name.equals(rightAttr.name)
+      case (leftAgg: AggregateExpression, rightAgg: AggregateExpression) =>
+        leftAgg.isDistinct == rightAgg.isDistinct &&
+        areMatchedExpression(leftAgg.aggregateFunction, 
rightAgg.aggregateFunction)
+      case (_, _) =>
+        leftExpression.getClass == rightExpression.getClass &&
+        leftExpression.children.length == rightExpression.children.length &&
+        leftExpression.children.zip(rightExpression.children).forall {
+          case (leftChild, rightChild) => areMatchedExpression(leftChild, 
rightChild)
+        }
+    }
+  }
+
+  // Rewrite every clause's filter condition in terms of the first plan's extracted 
source output attributes, so all conditions can be evaluated over that single source.
+  def normalizedClausesFilterCondition(analyzedPlans: Seq[AnalyzedPlan]): 
Seq[Expression] = {
+    val valueAttributes = 
analyzedPlans.head.planAnalyzer.get.getExtractedSourcePlan.get.output
+    analyzedPlans.map {
+      analyzedPlan =>
+        val keyAttributes = 
analyzedPlan.planAnalyzer.get.getExtractedSourcePlan.get.output
+        val replaceMap = buildAttributesMap(keyAttributes, valueAttributes)
+        val filter = 
analyzedPlan.planAnalyzer.get.getConstructedFilterPlan.get.asInstanceOf[Filter]
+        CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
+    }
+  }
+
+  def makeAlias(e: Expression, name: String): NamedExpression = {
+    Alias(e, name)(
+      NamedExpression.newExprId,
+      e match {
+        case ne: NamedExpression => ne.qualifier
+        case _ => Seq.empty
+      },
+      None,
+      Seq.empty)
+  }
+
+  def removeAlias(e: Expression): Expression = {
+    e match {
+      case alias: Alias => alias.child
+      case _ => e
+    }
+  }
+
+  def isAggregateExpression(e: Expression): Boolean = {
+    e match {
+      case cast: Cast => isAggregateExpression(cast.child)
+      case alias: Alias => isAggregateExpression(alias.child)
+      case agg: AggregateExpression => true
+      case _ => false
+    }
+  }
+
+  def hasAggregateExpression(e: Expression): Boolean = {
+    if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) {
+      false
+    } else {
+      e match {
+        case _: AggregateExpression => true
+        case _ => e.children.exists(hasAggregateExpression(_))
+      }
+    }
+  }
+
+  def addArrayStep(plan: LogicalPlan): LogicalPlan = {
+    val array =
+      
CoalesceUnionUtil.makeAlias(CreateArray(plan.output.map(_.asInstanceOf[Expression])),
 "array")
+    Project(Seq(array), plan)
+  }
+
+  def addExplodeStep(plan: LogicalPlan): LogicalPlan = {
+    val arrayOutput = plan.output.head.asInstanceOf[Expression]
+    val explodeExpression = Explode(arrayOutput)
+    val explodeOutput = AttributeReference(
+      "generate_output",
+      arrayOutput.dataType.asInstanceOf[ArrayType].elementType)()
+    val generate = Generate(
+      explodeExpression,
+      unrequiredChildIndex = Seq(0),
+      outer = false,
+      qualifier = None,
+      generatorOutput = Seq(explodeOutput),
+      plan)
+    Filter(IsNotNull(generate.output.head), generate)
+  }
+
+  def addUnfoldStructStep(plan: LogicalPlan): LogicalPlan = {
+    assert(plan.output.length == 1)
+    val structExpression = plan.output.head
+    assert(structExpression.dataType.isInstanceOf[StructType])
+    val structType = structExpression.dataType.asInstanceOf[StructType]
+    val attributes = ArrayBuffer[NamedExpression]()
+    var fieldIndex = 0
+    structType.fields.foreach {
+      field =>
+        attributes += Alias(GetStructField(structExpression, fieldIndex), 
field.name)()
+        fieldIndex += 1
+    }
+    Project(attributes.toSeq, plan)
+  }
+
+  def unionClauses(originalUnion: Union, clauses: Seq[LogicalPlan]): 
LogicalPlan = {
+    val coalescePlan = if (clauses.length == 1) {
+      clauses.head
+    } else {
+      var firstUnionChild = clauses.head
+      for (i <- 1 until clauses.length - 1) {
+        firstUnionChild = Union(firstUnionChild, clauses(i))
+      }
+      Union(firstUnionChild, clauses.last)
+    }
+    val outputPairs = coalescePlan.output.zip(originalUnion.output)
+    if (outputPairs.forall(pair => pair._1.semanticEquals(pair._2))) {
+      coalescePlan
+    } else {
+      val reprojectOutputs = outputPairs.map {
+        case (newAttr, oldAttr) =>
+          if (newAttr.exprId == oldAttr.exprId) {
+            newAttr
+          } else {
+            Alias(newAttr, oldAttr.name)(oldAttr.exprId, oldAttr.qualifier, 
None, Seq.empty)
+          }
+      }
+      Project(reprojectOutputs, coalescePlan)
+    }
+  }
+}
+
+/*
+ * Example:
+ * Rewrite query
+ *  SELECT a, b, sum(c) FROM t WHERE d = 1 GROUP BY a,b
+ *  UNION ALL
+ *  SELECT a, b, sum(c) FROM t WHERE d = 2 GROUP BY a,b
+ * into
+ *  SELECT a, b, sum(c) FROM (
+ *   SELECT s.a as a, s.b as b, s.c as c, s.id as group_id FROM (
+ *    SELECT explode(s) as s FROM (
+ *      SELECT array(
+ *        if(d = 1, named_struct('a', a, 'b', b, 'c', c, 'id', 0), null),
+ *        if(d = 2, named_struct('a', a, 'b', b, 'c', c, 'id', 1), null)) as s
+ *      FROM t WHERE d = 1 OR d = 2
+ *    )
+ *   ) WHERE s is not null
+ *  ) GROUP BY a,b, group_id
+ *
+ * The original query scans `t` once per union clause; when the scan output is 
large, the query is
+ * really slow. The rewritten query scans `t` only once, so the performance is 
much better.
+ */
+
+class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan] 
with Logging {
+
+  case class PlanAnalyzer(originalAggregate: Aggregate) extends 
AbstractPlanAnalyzer {
+
     protected def extractFilter(): Option[Filter] = {
       originalAggregate.child match {
         case filter: Filter => Some(filter)
@@ -129,9 +348,10 @@ class CoalesceAggregationUnion(spark: SparkSession) 
extends Rule[LogicalPlan] wi
           subquery.child match {
             case filter: Filter => Some(filter)
             case project @ Project(_, filter: Filter) => Some(filter)
-            case relation if isRelation(relation) =>
+            case relation if CoalesceUnionUtil.isRelation(relation) =>
               Some(Filter(Literal(true, BooleanType), subquery))
-            case nestedRelation: SubqueryAlias if 
(isRelation(nestedRelation.child)) =>
+            case nestedRelation: SubqueryAlias
+                if (CoalesceUnionUtil.isRelation(nestedRelation.child)) =>
               Some(Filter(Literal(true, BooleanType), nestedRelation))
             case _ => None
           }
@@ -139,76 +359,66 @@ class CoalesceAggregationUnion(spark: SparkSession) 
extends Rule[LogicalPlan] wi
       }
     }
 
-    def isValidSource(plan: LogicalPlan): Boolean = {
-      plan match {
-        case relation if isRelation(relation) => true
-        case _: Project | _: Filter | _: SubqueryAlias =>
-          plan.children.forall(isValidSource)
-        case _ => false
-      }
-    }
-
     // Try to make the plan simple, contain only three steps, source, filter, 
aggregate.
     lazy val extractedSourcePlan = {
-      val filter = extractFilter()
-      if (!filter.isDefined) {
-        None
-      } else {
-        filter.get.child match {
-          case project: Project if isValidSource(project.child) => 
Some(project.child)
-          case other if isValidSource(other) => Some(other)
-          case _ => None
-        }
+      extractFilter match {
+        case Some(filter) =>
+          filter.child match {
+            case project: Project if 
CoalesceUnionUtil.validateSource(project.child) =>
+              Some(project.child)
+            case other if CoalesceUnionUtil.validateSource(other) => 
Some(other)
+            case _ => None
+          }
+        case None => None
       }
     }
+    override def getExtractedSourcePlan(): Option[LogicalPlan] = 
extractedSourcePlan
 
     lazy val constructedFilterPlan = {
-      val filter = extractFilter()
-      if (!filter.isDefined || !extractedSourcePlan.isDefined) {
-        None
-      } else {
-        val project = filter.get.child match {
-          case project: Project => Some(project)
-          case other =>
-            None
-        }
-        val newFilter = project match {
-          case Some(project) =>
-            val replaceMap = buildAttributesMap(
-              project.output,
-              project.child.output.map(_.asInstanceOf[Expression]))
-            val newCondition = replaceAttributes(filter.get.condition, 
replaceMap)
-            Filter(newCondition, extractedSourcePlan.get)
-          case None => filter.get.withNewChildren(Seq(extractedSourcePlan.get))
-        }
-        Some(newFilter)
+      extractedSourcePlan match {
+        case Some(sourcePlan) =>
+          val filter = extractFilter().get
+          val innerProject = filter.child match {
+            case project: Project => Some(project)
+            case _ => None
+          }
+
+          val newFilter = innerProject match {
+            case Some(project) =>
+              val replaceMap = CoalesceUnionUtil.buildAttributesMap(
+                project.output,
+                project.projectList.map(_.asInstanceOf[Expression]))
+              val newCondition = 
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
+              Filter(newCondition, sourcePlan)
+            case None => filter.withNewChildren(Seq(sourcePlan))
+          }
+          Some(newFilter)
+        case None => None
       }
     }
+    override def getConstructedFilterPlan: Option[LogicalPlan] = 
constructedFilterPlan
 
     lazy val constructedAggregatePlan = {
       if (!constructedFilterPlan.isDefined) {
         None
       } else {
-        val project = originalAggregate.child match {
-          case p: Project => Some(p)
-          case subquery: SubqueryAlias =>
-            subquery.child match {
-              case p: Project => Some(p)
-              case _ => None
-            }
+        val innerProject = originalAggregate.child match {
+          case project: Project => Some(project)
+          case subquery @ SubqueryAlias(_, project: Project) =>
+            Some(project)
           case _ => None
         }
 
-        val newAggregate = project match {
-          case Some(innerProject) =>
-            val replaceMap = buildAttributesMap(
-              innerProject.output,
-              innerProject.projectList.map(_.asInstanceOf[Expression]))
+        val newAggregate = innerProject match {
+          case Some(project) =>
+            val replaceMap = CoalesceUnionUtil.buildAttributesMap(
+              project.output,
+              project.projectList.map(_.asInstanceOf[Expression]))
             val newGroupExpressions = 
originalAggregate.groupingExpressions.map {
-              e => replaceAttributes(e, replaceMap)
+              e => CoalesceUnionUtil.replaceAttributes(e, replaceMap)
             }
             val newAggregateExpressions = 
originalAggregate.aggregateExpressions.map {
-              e => replaceAttributes(e, 
replaceMap).asInstanceOf[NamedExpression]
+              e => CoalesceUnionUtil.replaceAttributes(e, 
replaceMap).asInstanceOf[NamedExpression]
             }
             Aggregate(newGroupExpressions, newAggregateExpressions, 
constructedFilterPlan.get)
           case None => 
originalAggregate.withNewChildren(Seq(constructedFilterPlan.get))
@@ -216,75 +426,96 @@ class CoalesceAggregationUnion(spark: SparkSession) 
extends Rule[LogicalPlan] wi
         Some(newAggregate)
       }
     }
+    override def getConstructedAggregatePlan: Option[LogicalPlan] = 
constructedAggregatePlan
 
+    def hasAggregateExpressionsWithFilter(e: Expression): Boolean = {
+      if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) {
+        return false
+      }
+      e match {
+        case aggExpr: AggregateExpression =>
+          aggExpr.filter.isDefined
+        case _ => e.children.exists(hasAggregateExpressionsWithFilter(_))
+      }
+    }
     lazy val hasAggregateWithFilter = 
originalAggregate.aggregateExpressions.exists {
       e => hasAggregateExpressionsWithFilter(e)
     }
 
     // The output results which are not aggregate expressions.
-    lazy val resultGroupingExpressions = constructedAggregatePlan match {
+    lazy val resultRequiredGroupingExpressions = constructedAggregatePlan 
match {
       case Some(agg) =>
-        agg.asInstanceOf[Aggregate].aggregateExpressions.filter(e => 
!hasAggregateExpression(e))
+        agg
+          .asInstanceOf[Aggregate]
+          .aggregateExpressions
+          .filter(e => !CoalesceUnionUtil.hasAggregateExpression(e))
       case None => Seq.empty
     }
 
-    lazy val positionInGroupingKeys = {
-      var i = 0
-      // In most cases, the expressions which are not aggregate result could 
be matched with one of
-      // groupingk keys. There are some exceptions
-      // 1. The expression is a literal. The grouping keys do not contain the 
literal.
-      // 2. The expression is an expression withs gruping keys. For example,
-      //    `select k1 + k2, count(1) from t group by k1, k2`.
-      resultGroupingExpressions.map {
-        e =>
-          val aggregate = constructedAggregatePlan.get.asInstanceOf[Aggregate]
-          e match {
-            case literal @ Alias(_: Literal, _) =>
-              var idx = aggregate.groupingExpressions.indexOf(e)
-              if (idx == -1) {
-                idx = aggregate.groupingExpressions.length + i
-                i += 1
-              }
-              idx
-            case _ =>
-              var idx = aggregate.groupingExpressions.indexOf(removeAlias(e))
-              idx = if (idx == -1) {
-                aggregate.groupingExpressions.indexOf(e)
-              } else {
-                idx
-              }
-              idx
+    // For non-aggregate expressions in the output, they should be matched with 
one of the grouping
+    // keys. There are some exceptions:
+    // 1. The expression is a literal. The grouping keys do not contain the 
literal.
+    // 2. The expression is an expression over grouping keys. For example,
+    //    `select k1 + k2, count(1) from t group by k1, k2`.
+    // This is used to judge whether two aggregate plans are matched.
+    lazy val aggregateResultMatchedGroupingKeysPositions = {
+      var extraPosition = 0
+      val aggregate = constructedAggregatePlan.get.asInstanceOf[Aggregate]
+      resultRequiredGroupingExpressions.map {
+        case literal @ Alias(_: Literal, _) =>
+          aggregate.groupingExpressions.indexOf(literal) match {
+            case -1 =>
+              extraPosition += 1
+              aggregate.groupingExpressions.length + extraPosition - 1
+            case position => position
+          }
+        case normalExpression =>
+          aggregate.groupingExpressions.indexOf(
+            CoalesceUnionUtil.removeAlias(normalExpression)) match {
+            case -1 => aggregate.groupingExpressions.indexOf(normalExpression)
+            case position => position
           }
       }
     }
+
+    override def doValidate(): Boolean = {
+      !hasAggregateWithFilter &&
+      constructedAggregatePlan.isDefined &&
+      aggregateResultMatchedGroupingKeysPositions.forall(_ >= 0) &&
+      originalAggregate.aggregateExpressions.forall {
+        e =>
+          val innerExpr = CoalesceUnionUtil.removeAlias(e)
+          // `agg_fun1(x) + agg_fun2(y)` is supported, but `agg_fun1(x) + y` 
is not supported.
+          if (CoalesceUnionUtil.hasAggregateExpression(innerExpr)) {
+            innerExpr.isInstanceOf[AggregateExpression] ||
+            innerExpr.children.forall(e => 
CoalesceUnionUtil.isAggregateExpression(e))
+          } else {
+            true
+          }
+      } &&
+      extractedSourcePlan.isDefined
+    }
   }
 
   /*
    * Case class representing an analyzed plan.
    *
    * @param plan The logical plan that to be analyzed.
-   * @param analyzedInfo Optional information about the aggregate analysis.
+   * @param planAnalyzer Optional information about the aggregate analysis.
    */
-  case class AnalyzedPlan(plan: LogicalPlan, analyzedInfo: 
Option[AggregateAnalzyInfo])
-
-  def isResolvedPlan(plan: LogicalPlan): Boolean = {
-    plan match {
-      case isnert: InsertIntoStatement => isnert.query.resolved
-      case _ => plan.resolved
-    }
-  }
-
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (
       spark.conf
         .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION, 
"true")
-        .toBoolean && isResolvedPlan(plan)
+        .toBoolean && CoalesceUnionUtil.isResolvedPlan(plan)
     ) {
       Try {
         visitPlan(plan)
       } match {
-        case Success(res) => res
-        case Failure(e) => plan
+        case Success(newPlan) => newPlan
+        case Failure(e) =>
+          logError(s"$e")
+          plan
       }
     } else {
       plan
@@ -293,526 +524,531 @@ class CoalesceAggregationUnion(spark: SparkSession) 
extends Rule[LogicalPlan] wi
 
   def visitPlan(plan: LogicalPlan): LogicalPlan = {
     plan match {
-      case union: Union =>
-        val planGroups = groupStructureMatchedAggregate(union)
-        if (planGroups.forall(group => group.length == 1)) {
-          plan.withNewChildren(plan.children.map(visitPlan))
-        } else {
-          val newUnionClauses = planGroups.map {
-            groupedPlans =>
-              if (groupedPlans.length == 1) {
-                groupedPlans.head.plan
-              } else {
-                val firstAggregateAnalzyInfo = 
groupedPlans.head.analyzedInfo.get
-                val aggregates = 
groupedPlans.map(_.analyzedInfo.get.constructedAggregatePlan.get)
-                val filterConditions = 
buildAggregateCasesConditions(groupedPlans)
-                val firstAggregateFilter =
-                  
firstAggregateAnalzyInfo.constructedFilterPlan.get.asInstanceOf[Filter]
-
-                // Add a filter step with condition `cond1 or cond2 or ...`, 
`cond_i` comes from
-                // each union clause. Apply this filter on the source plan.
-                val unionFilter = Filter(
-                  buildUnionConditionForAggregateSource(filterConditions),
-                  firstAggregateAnalzyInfo.extractedSourcePlan.get)
-
-                // Wrap all the attributes into a single structure attribute.
-                val wrappedAttributesProject =
-                  buildProjectFoldIntoStruct(unionFilter, groupedPlans, 
filterConditions)
-
-                // Build an array which element are response to each union 
clause.
-                val arrayProject =
-                  buildProjectBranchArray(wrappedAttributesProject, 
filterConditions)
+      case union @ Union(_, false, false) =>
+        val groupedUnionClauses = groupUnionClauses(union)
+        val newUnionClauses = groupedUnionClauses.map {
+          clauses => coalesceMatchedUnionClauses(clauses)
+        }
+        CoalesceUnionUtil.unionClauses(union, newUnionClauses)
+      case _ => plan.withNewChildren(plan.children.map(visitPlan))
+    }
+  }
 
-                // Explode the array
-                val explode = buildExplodeBranchArray(arrayProject)
+  def groupUnionClauses(union: Union): Seq[Seq[AnalyzedPlan]] = {
+    val unionClauses = CoalesceUnionUtil.collectAllUnionClauses(union)
+    val groups = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]()
+    unionClauses.foreach {
+      clause =>
+        val innerClause = clause match {
+          case project @ Project(projectList, aggregate: Aggregate) =>
+            if (projectList.forall(_.isInstanceOf[Alias])) {
+              Some(aggregate)
+            } else {
+              None
+            }
+          case aggregate: Aggregate =>
+            Some(aggregate)
+          case _ => None
+        }
+        innerClause match {
+          case Some(aggregate) =>
+            val planAnalyzer = PlanAnalyzer(aggregate)
+            if (planAnalyzer.doValidate()) {
+              val analyzedPlan = AnalyzedPlan(clause, Some(planAnalyzer))
+              val matchedGroup = findMatchedGroup(analyzedPlan, groups)
+              if (matchedGroup != -1) {
+                groups(matchedGroup) += analyzedPlan
+              } else {
+                groups += ArrayBuffer(analyzedPlan)
+              }
+            } else {
+              val newClause = visitPlan(clause)
+              groups += ArrayBuffer(AnalyzedPlan(newClause, None))
+            }
+          case None =>
+            val newClause = visitPlan(clause)
+            groups += ArrayBuffer(AnalyzedPlan(newClause, None))
+        }
+    }
+    groups.map(_.toSeq).toSeq
+  }
 
-                // Null value means that the union clause does not have the 
corresponding data.
-                val notNullFilter = Filter(IsNotNull(explode.output.head), 
explode)
+  def findMatchedGroup(
+      analyedPlan: AnalyzedPlan,
+      groups: ArrayBuffer[ArrayBuffer[AnalyzedPlan]]): Int = {
+    groups.zipWithIndex.find {
+      case (group, groupIndex) =>
+        val checkedAnalyzedPlan = group.head
+        if (checkedAnalyzedPlan.planAnalyzer.isDefined && 
analyedPlan.planAnalyzer.isDefined) {
+          val leftPlanAnalyzer = 
checkedAnalyzedPlan.planAnalyzer.get.asInstanceOf[PlanAnalyzer]
+          val rightPlanAnalyzer = 
analyedPlan.planAnalyzer.get.asInstanceOf[PlanAnalyzer]
+          val leftAggregate =
+            
leftPlanAnalyzer.getConstructedAggregatePlan.get.asInstanceOf[Aggregate]
+          val rightAggregate =
+            
rightPlanAnalyzer.getConstructedAggregatePlan.get.asInstanceOf[Aggregate]
 
-                // Destruct the struct attribute.
-                val destructStructProject = 
buildProjectUnfoldStruct(notNullFilter)
+          var isMatched = CoalesceUnionUtil
+            .areStrictMatchedRelation(
+              leftPlanAnalyzer.extractedSourcePlan.get,
+              rightPlanAnalyzer.extractedSourcePlan.get)
 
-                buildAggregateWithGroupId(destructStructProject, groupedPlans)
-              }
-          }
-          val coalesePlan = if (newUnionClauses.length == 1) {
-            newUnionClauses.head
-          } else {
-            var firstUnionChild = newUnionClauses.head
-            for (i <- 1 until newUnionClauses.length - 1) {
-              firstUnionChild = Union(firstUnionChild, newUnionClauses(i))
+          isMatched = isMatched &&
+            leftAggregate.groupingExpressions.length == 
rightAggregate.groupingExpressions.length &&
+            
leftAggregate.groupingExpressions.zip(rightAggregate.groupingExpressions).forall
 {
+              case (leftExpr, rightExpr) => 
leftExpr.dataType.equals(rightExpr.dataType)
             }
-            Union(firstUnionChild, newUnionClauses.last)
-          }
 
-          // We need to keep the output atrributes same as the original plan.
-          val outputAttrPairs = coalesePlan.output.zip(union.output)
-          if (outputAttrPairs.forall(pair => pair._1.semanticEquals(pair._2))) 
{
-            coalesePlan
-          } else {
-            val reprejectOutputs = outputAttrPairs.map {
-              case (newAttr, oldAttr) =>
-                if (newAttr.exprId == oldAttr.exprId) {
-                  newAttr
+          isMatched = isMatched &&
+            
leftPlanAnalyzer.aggregateResultMatchedGroupingKeysPositions.length ==
+            
rightPlanAnalyzer.aggregateResultMatchedGroupingKeysPositions.length &&
+            leftPlanAnalyzer.aggregateResultMatchedGroupingKeysPositions
+              
.zip(rightPlanAnalyzer.aggregateResultMatchedGroupingKeysPositions)
+              .forall { case (leftPos, rightPos) => leftPos == rightPos }
+
+          isMatched = isMatched && leftAggregate.aggregateExpressions.length ==
+            rightAggregate.aggregateExpressions.length &&
+            
leftAggregate.aggregateExpressions.zip(rightAggregate.aggregateExpressions).forall
 {
+              case (leftExpr, rightExpr) =>
+                if (leftExpr.dataType.equals(rightExpr.dataType)) {
+                  (
+                    CoalesceUnionUtil.hasAggregateExpression(leftExpr),
+                    CoalesceUnionUtil.hasAggregateExpression(rightExpr)) match 
{
+                    case (true, true) =>
+                      CoalesceUnionUtil.areMatchedExpression(leftExpr, 
rightExpr)
+                    case (false, true) => false
+                    case (true, false) => false
+                    case (false, false) => true
+                  }
                 } else {
-                  Alias(newAttr, oldAttr.name)(oldAttr.exprId, 
oldAttr.qualifier, None, Seq.empty)
+                  false
                 }
             }
-            Project(reprejectOutputs, coalesePlan)
-          }
+          isMatched
+        } else {
+          false
         }
-      case _ => plan.withNewChildren(plan.children.map(visitPlan))
+    } match {
+      case Some((_, i)) => i
+      case None => -1
     }
   }
 
-  def isRelation(plan: LogicalPlan): Boolean = {
-    plan.isInstanceOf[MultiInstanceRelation]
+  def coalesceMatchedUnionClauses(clauses: Seq[AnalyzedPlan]): LogicalPlan = {
+    if (clauses.length == 1) {
+      clauses.head.plan
+    } else {
+      val normalizedConditions = 
CoalesceUnionUtil.normalizedClausesFilterCondition(clauses)
+      val newSource = clauses.head.planAnalyzer.get.getExtractedSourcePlan.get
+      val newFilter = Filter(normalizedConditions.reduce(Or), newSource)
+      val foldStructStep = addFoldStructStep(newFilter, normalizedConditions, 
clauses)
+      val arrayStep = CoalesceUnionUtil.addArrayStep(foldStructStep)
+      val explodeStep = CoalesceUnionUtil.addExplodeStep(arrayStep)
+      val unfoldStructStep = CoalesceUnionUtil.addUnfoldStructStep(explodeStep)
+      addAggregateStep(unfoldStructStep, 
clauses.head.planAnalyzer.get.asInstanceOf[PlanAnalyzer])
+    }
   }
 
-  def areSameRelation(l: LogicalPlan, r: LogicalPlan): Boolean = {
-    (l, r) match {
-      case (lRelation: LogicalRelation, rRelation: LogicalRelation) =>
-        val lTable = 
lRelation.catalogTable.map(_.identifier.unquotedString).getOrElse("")
-        val rTable = 
rRelation.catalogTable.map(_.identifier.unquotedString).getOrElse("")
-        lRelation.output.length == rRelation.output.length &&
-        lRelation.output.zip(rRelation.output).forall {
-          case (lAttr, rAttr) =>
-            lAttr.dataType.equals(rAttr.dataType) && 
lAttr.name.equals(rAttr.name)
-        } &&
-        lTable.equals(rTable) && lTable.nonEmpty
-      case (lCTE: CTERelationRef, rCTE: CTERelationRef) =>
-        lCTE.cteId == rCTE.cteId
-      case (lHiveTable: HiveTableRelation, rHiveTable: HiveTableRelation) =>
-        lHiveTable.tableMeta.identifier.unquotedString
-          .equals(rHiveTable.tableMeta.identifier.unquotedString)
-      case (_, _) =>
-        logInfo(s"xxx unknow relation: ${l.getClass}, ${r.getClass}")
-        false
+  def addFoldStructStep(
+      plan: LogicalPlan,
+      conditions: Seq[Expression],
+      analyzedPlans: Seq[AnalyzedPlan]): LogicalPlan = {
+    val replaceAttributes = 
analyzedPlans.head.planAnalyzer.get.getExtractedSourcePlan.get.output
+
+    val structAttributes = analyzedPlans.zipWithIndex.map {
+      case (analyzedPlan, clauseIndex) =>
+        val attributeReplaceMap = CoalesceUnionUtil.buildAttributesMap(
+          analyzedPlan.planAnalyzer.get.getExtractedSourcePlan.get.output,
+          replaceAttributes)
+        val structFields = collectClauseStructFields(analyzedPlan, 
clauseIndex, attributeReplaceMap)
+        CoalesceUnionUtil
+          .makeAlias(CreateNamedStruct(structFields), s"clause_$clauseIndex")
+          .asInstanceOf[NamedExpression]
+    }
+
+    val projectList = structAttributes.zip(conditions).map {
+      case (attribute, condition) =>
+        CoalesceUnionUtil
+          .makeAlias(If(condition, attribute, Literal(null, 
attribute.dataType)), attribute.name)
+          .asInstanceOf[NamedExpression]
     }
+    Project(projectList, plan)
   }
 
-  def isSupportedAggregate(info: AggregateAnalzyInfo): Boolean = {
+  def collectClauseStructFields(
+      analyzedPlan: AnalyzedPlan,
+      clauseIndex: Int,
+      attributeReplaceMap: Map[ExprId, Expression]): Seq[Expression] = {
 
-    !info.hasAggregateWithFilter &&
-    info.constructedAggregatePlan.isDefined &&
-    info.positionInGroupingKeys.forall(_ >= 0) &&
-    info.originalAggregate.aggregateExpressions.forall {
+    val planAnalyzer = analyzedPlan.planAnalyzer.get.asInstanceOf[PlanAnalyzer]
+    val aggregate = 
planAnalyzer.constructedAggregatePlan.get.asInstanceOf[Aggregate]
+    val structFields = ArrayBuffer[Expression]()
+
+    aggregate.groupingExpressions.foreach {
       e =>
-        val innerExpr = removeAlias(e)
-        // `agg_fun1(x) + agg_fun2(y)` is supported, but `agg_fun1(x) + y` is 
not supported.
-        if (hasAggregateExpression(innerExpr)) {
-          innerExpr.isInstanceOf[AggregateExpression] ||
-          innerExpr.children.forall(e => isAggregateExpression(e))
-        } else {
-          true
+        val fieldCounter = structFields.length / 2
+        structFields += Literal(UTF8String.fromString(s"f$fieldCounter"), 
StringType)
+        structFields += CoalesceUnionUtil.replaceAttributes(e, 
attributeReplaceMap)
+    }
+
+    
planAnalyzer.aggregateResultMatchedGroupingKeysPositions.zipWithIndex.foreach {
+      case (position, index) =>
+        val fieldCounter = structFields.length / 2
+        if (position >= fieldCounter) {
+          val expression = 
planAnalyzer.resultRequiredGroupingExpressions(index)
+          structFields += Literal(UTF8String.fromString(s"f$fieldCounter"), 
StringType)
+          structFields += CoalesceUnionUtil.replaceAttributes(expression, 
attributeReplaceMap)
         }
-    } &&
-    info.extractedSourcePlan.isDefined
-  }
+    }
 
-  /**
-   * Checks if two AggregateAnalzyInfo instances have the same structure.
-   *
-   * This method compares the aggregate expressions, grouping expressions, and 
the source plans of
-   * the two AggregateAnalzyInfo instances to determine if they have the same 
structure.
-   *
-   * @param l
-   *   The first AggregateAnalzyInfo instance.
-   * @param r
-   *   The second AggregateAnalzyInfo instance.
-   * @return
-   *   True if the two instances have the same structure, false otherwise.
-   */
-  def areStructureMatchedAggregate(l: AggregateAnalzyInfo, r: 
AggregateAnalzyInfo): Boolean = {
-    val lAggregate = l.constructedAggregatePlan.get.asInstanceOf[Aggregate]
-    val rAggregate = r.constructedAggregatePlan.get.asInstanceOf[Aggregate]
-    lAggregate.aggregateExpressions.length == 
rAggregate.aggregateExpressions.length &&
-    
lAggregate.aggregateExpressions.zip(rAggregate.aggregateExpressions).forall {
-      case (lExpr, rExpr) =>
-        if (!lExpr.dataType.equals(rExpr.dataType)) {
-          false
-        } else {
-          (hasAggregateExpression(lExpr), hasAggregateExpression(rExpr)) match 
{
-            case (true, true) => areStructureMatchedExpressions(lExpr, rExpr)
-            case (false, true) => false
-            case (true, false) => false
-            case (false, false) => true
+    aggregate.aggregateExpressions
+      .filter(e => CoalesceUnionUtil.hasAggregateExpression(e))
+      .foreach {
+        e =>
+          def visitAggregateExpression(expression: Expression): Unit = {
+            expression match {
+              case aggregateExpression: AggregateExpression =>
+                val fieldCounter = structFields.length / 2
+                val aggregateFunction = aggregateExpression.aggregateFunction
+                aggregateFunction.children.foreach {
+                  argument =>
+                    val fieldCounter = structFields.length / 2
+                    structFields += 
Literal(UTF8String.fromString(s"f$fieldCounter"), StringType)
+                    structFields += CoalesceUnionUtil.replaceAttributes(
+                      argument,
+                      attributeReplaceMap)
+                }
+              case combindedAggregateExpression
+                  if 
CoalesceUnionUtil.hasAggregateExpression(combindedAggregateExpression) =>
+                
combindedAggregateExpression.children.foreach(visitAggregateExpression)
+              case other =>
+                val fieldCounter = structFields.length / 2
+                structFields += 
Literal(UTF8String.fromString(s"f$fieldCounter"), StringType)
+                structFields += CoalesceUnionUtil.replaceAttributes(other, 
attributeReplaceMap)
+            }
           }
-        }
-    } &&
-    lAggregate.groupingExpressions.length == 
rAggregate.groupingExpressions.length &&
-    l.positionInGroupingKeys.length == r.positionInGroupingKeys.length &&
-    l.positionInGroupingKeys.zip(r.positionInGroupingKeys).forall {
-      case (lPos, rPos) => lPos == rPos
-    } &&
-    areSameAggregateSource(l.extractedSourcePlan.get, 
r.extractedSourcePlan.get)
+          visitAggregateExpression(e)
+      }
+
+    // Add the clause index to the struct.
+    val fieldCounter = structFields.length / 2
+    structFields += Literal(UTF8String.fromString(s"f$fieldCounter"), 
StringType)
+    structFields += Literal(UTF8String.fromString(s"$clauseIndex"), StringType)
+
+    structFields.toSeq
   }
 
-  /*
-   * Finds the index of the first group in `planGroups` that has the same 
structure as the given
-   * `analyzedInfo`.
-   *
-   * This method iterates over the `planGroups` and checks if the first 
`AnalyzedPlan` in each group
-   * has an `analyzedInfo` that matches the structure of the provided 
`analyzedInfo`. If a match is
-   * found, the index of the group is returned. If no match is found, -1 is 
returned.
-   *
-   * @param planGroups
-   *   An ArrayBuffer of ArrayBuffers, where each inner ArrayBuffer contains 
`AnalyzedPlan`
-   *   instances.
-   * @param analyzedInfo
-   *   The `AggregateAnalzyInfo` to match against the groups in `planGroups`.
-   * @return
-   *   The index of the first group with a matching structure, or -1 if no 
match is found.
-   */
-  def findStructureMatchedAggregate(
-      planGroups: ArrayBuffer[ArrayBuffer[AnalyzedPlan]],
-      analyzedInfo: AggregateAnalzyInfo): Int = {
-    planGroups.zipWithIndex.find(
-      planWithIndex =>
-        planWithIndex._1.head.analyzedInfo.isDefined &&
-          areStructureMatchedAggregate(
-            planWithIndex._1.head.analyzedInfo.get,
-            analyzedInfo)) match {
-      case Some((_, i)) => i
-      case None => -1
-    }
+  def addAggregateStep(plan: LogicalPlan, templatePlanAnalyzer: PlanAnalyzer): 
LogicalPlan = {
+    val inputAttributes = plan.output
+    val templateAggregate =
+      templatePlanAnalyzer.constructedAggregatePlan.get.asInstanceOf[Aggregate]
+
+    val totalGroupingExpressionsCount = math.max(
+      templateAggregate.groupingExpressions.length,
+      templatePlanAnalyzer.aggregateResultMatchedGroupingKeysPositions.max + 1)
+
+    // inputAttributes.last is the clause index.
+    val groupingExpressions = inputAttributes
+      .slice(0, totalGroupingExpressionsCount)
+      .map(_.asInstanceOf[Expression]) :+ inputAttributes.last
+
+    var aggregateExpressionCount = totalGroupingExpressionsCount
+    var nonAggregateExpressionCount = 0
+    val aggregateExpressions = ArrayBuffer[NamedExpression]()
+    templateAggregate.aggregateExpressions.foreach {
+      e =>
+        CoalesceUnionUtil.removeAlias(e) match {
+          case aggregateExpression
+              if CoalesceUnionUtil.hasAggregateExpression(aggregateExpression) 
=>
+            val (newAggregateExpression, count) = buildNewAggregateExpression(
+              aggregateExpression,
+              inputAttributes,
+              aggregateExpressionCount)
+
+            aggregateExpressionCount += count
+            aggregateExpressions += CoalesceUnionUtil
+              .makeAlias(newAggregateExpression, e.name)
+              .asInstanceOf[NamedExpression]
+          case nonAggregateExpression =>
+            val position = 
templatePlanAnalyzer.aggregateResultMatchedGroupingKeysPositions(
+              nonAggregateExpressionCount)
+            val attribute = inputAttributes(position)
+            aggregateExpressions += CoalesceUnionUtil
+              .makeAlias(attribute, e.name)
+              .asInstanceOf[NamedExpression]
+            nonAggregateExpressionCount += 1
 
+        }
+    }
+    Aggregate(groupingExpressions.toSeq, aggregateExpressions.toSeq, plan)
   }
 
-  // Union only has two children. It's children may also be Union.
-  def collectAllUnionClauses(union: Union): ArrayBuffer[LogicalPlan] = {
-    val unionClauses = ArrayBuffer[LogicalPlan]()
-    union.children.foreach {
-      case u: Union =>
-        unionClauses ++= collectAllUnionClauses(u)
-      case other =>
-        unionClauses += other
+  def buildNewAggregateExpression(
+      oldExpression: Expression,
+      inputAttributes: Seq[Attribute],
+      attributesOffset: Int): (Expression, Int) = {
+    oldExpression match {
+      case aggregateExpression: AggregateExpression =>
+        val aggregateFunction = aggregateExpression.aggregateFunction
+        val newArguments = aggregateFunction.children.zipWithIndex.map {
+          case (argument, i) => inputAttributes(i + attributesOffset)
+        }
+        val newAggregateFunction =
+          
aggregateFunction.withNewChildren(newArguments).asInstanceOf[AggregateFunction]
+        val newAggregateExpression = AggregateExpression(
+          newAggregateFunction,
+          aggregateExpression.mode,
+          aggregateExpression.isDistinct,
+          aggregateExpression.filter,
+          aggregateExpression.resultId)
+        (newAggregateExpression, 1)
+      case combindedAggregateExpression
+          if 
CoalesceUnionUtil.hasAggregateExpression(combindedAggregateExpression) =>
+        var count = 0
+        val newChildren = ArrayBuffer[Expression]()
+        combindedAggregateExpression.children.foreach {
+          case child =>
+            val (newChild, n) =
+              buildNewAggregateExpression(child, inputAttributes, 
attributesOffset + count)
+            count += n
+            newChildren += newChild
+        }
+        val newExpression = 
combindedAggregateExpression.withNewChildren(newChildren.toSeq)
+        (newExpression, count)
+      case _ => (inputAttributes(attributesOffset), 1)
     }
-    unionClauses
   }
+}
 
-  def groupStructureMatchedAggregate(union: Union): 
ArrayBuffer[ArrayBuffer[AnalyzedPlan]] = {
+/**
+ * Rewrite following query select a,b, 1 as c from t where d = 1 union all 
select a,b, 2 as c from t
+ * where d = 2 into select s.f0 as a, s.f1 as b, s.f2 as c from ( select 
explode(s) as s from (
+ * select array(if(d=1, named_struct('f0', a, 'f1', b, 'f2', 1), null), 
if(d=2, named_struct('f0',
+ * a, 'f1', b, 'f2', 2), null)) as s from t where d = 1 or d = 2 ) ) where s 
is not null
+ */
+class CoalesceProjectionUnion(spark: SparkSession) extends Rule[LogicalPlan] 
with Logging {
 
-    def tryPutToGroup(
-        groupResults: ArrayBuffer[ArrayBuffer[AnalyzedPlan]],
-        agg: Aggregate): Unit = {
-      val analyzedInfo = AggregateAnalzyInfo(agg)
-      if (isSupportedAggregate(analyzedInfo)) {
-        if (groupResults.isEmpty) {
-          groupResults += ArrayBuffer(
-            AnalyzedPlan(analyzedInfo.originalAggregate, Some(analyzedInfo)))
-        } else {
-          val idx = findStructureMatchedAggregate(groupResults, analyzedInfo)
-          if (idx != -1) {
-            groupResults(idx) += AnalyzedPlan(
-              analyzedInfo.constructedAggregatePlan.get,
-              Some(analyzedInfo))
-          } else {
-            groupResults += ArrayBuffer(
-              AnalyzedPlan(analyzedInfo.constructedAggregatePlan.get, 
Some(analyzedInfo)))
+  case class PlanAnalyzer(originalPlan: LogicalPlan) extends 
AbstractPlanAnalyzer {
+    def extractFilter(): Option[Filter] = {
+      originalPlan match {
+        case project @ Project(_, filter: Filter) => Some(filter)
+        case _ => None
+      }
+    }
+
+    lazy val extractedSourcePlan = {
+      extractFilter match {
+        case Some(filter) =>
+          filter.child match {
+            case project: Project =>
+              if (CoalesceUnionUtil.validateSource(project.child)) 
Some(project.child)
+              else None
+              Some(project.child)
+            case subquery @ SubqueryAlias(_, project: Project) =>
+              if (CoalesceUnionUtil.validateSource(project.child)) 
Some(project.child)
+              else None
+            case _ => Some(filter.child)
           }
-        }
-      } else {
-        val rewrittenPlan = visitPlan(agg)
-        groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
+        case None => None
       }
     }
 
-    val groupResults = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]()
-    collectAllUnionClauses(union).foreach {
-      case project @ Project(projectList, agg: Aggregate) =>
-        if (projectList.forall(e => e.isInstanceOf[Alias])) {
-          tryPutToGroup(groupResults, agg)
-        } else {
-          val rewrittenPlan = visitPlan(project)
-          groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
-        }
-      case agg: Aggregate =>
-        tryPutToGroup(groupResults, agg)
-      case other =>
-        val rewrittenPlan = visitPlan(other)
-        groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
-    }
-    groupResults
-  }
-
-  def areStructureMatchedExpressions(l: Expression, r: Expression): Boolean = {
-    if (l.dataType.equals(r.dataType)) {
-      (l, r) match {
-        case (lAttr: Attribute, rAttr: Attribute) =>
-          // The the qualifier may be overwritten by a subquery alias, and 
make this check fail.
-          lAttr.qualifiedName.equals(rAttr.qualifiedName)
-        case (lLiteral: Literal, rLiteral: Literal) =>
-          lLiteral.value == rLiteral.value
-        case (lagg: AggregateExpression, ragg: AggregateExpression) =>
-          lagg.isDistinct == ragg.isDistinct &&
-          areStructureMatchedExpressions(lagg.aggregateFunction, 
ragg.aggregateFunction)
-        case _ =>
-          l.children.length == r.children.length &&
-          l.getClass == r.getClass &&
-          l.children.zip(r.children).forall {
-            case (lChild, rChild) => areStructureMatchedExpressions(lChild, 
rChild)
+    lazy val constructedFilterPlan = {
+      extractedSourcePlan match {
+        case Some(source) =>
+          val filter = extractFilter().get
+          filter.child match {
+            case project: Project =>
+              val replaceMap =
+                CoalesceUnionUtil.buildAttributesMap(
+                  project.output,
+                  project.projectList.map(_.asInstanceOf[Expression]))
+              val newCondition = 
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
+              Some(Filter(newCondition, source))
+            case subquery @ SubqueryAlias(_, project: Project) =>
+              val replaceMap =
+                CoalesceUnionUtil.buildAttributesMap(
+                  project.output,
+                  project.projectList.map(_.asInstanceOf[Expression]))
+              val newCondition = 
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
+              Some(Filter(newCondition, source))
+            case _ => Some(filter)
           }
+        case None => None
       }
-    } else {
-      false
     }
-  }
 
-  def areSameAggregateSource(lPlan: LogicalPlan, rPlan: LogicalPlan): Boolean 
= {
-    if (lPlan.children.length != rPlan.children.length || lPlan.getClass != 
rPlan.getClass) {
-      false
-    } else {
-      lPlan.children.zip(rPlan.children).forall {
-        case (lRelation, rRelation) if (isRelation(lRelation) && 
isRelation(rRelation)) =>
-          areSameRelation(lRelation, rRelation)
-        case (lSubQuery: SubqueryAlias, rSubQuery: SubqueryAlias) =>
-          areSameAggregateSource(lSubQuery.child, rSubQuery.child)
-        case (lproject: Project, rproject: Project) =>
-          lproject.projectList.length == rproject.projectList.length &&
-          lproject.projectList.zip(rproject.projectList).forall {
-            case (lExpr, rExpr) => areStructureMatchedExpressions(lExpr, rExpr)
-          } &&
-          areSameAggregateSource(lproject.child, rproject.child)
-        case (lFilter: Filter, rFilter: Filter) =>
-          areStructureMatchedExpressions(lFilter.condition, rFilter.condition) 
&&
-          areSameAggregateSource(lFilter.child, rFilter.child)
-        case (lChild, rChild) => false
+    lazy val constructedProjectPlan = {
+      constructedFilterPlan match {
+        case Some(filter) =>
+          val originalFilter = extractFilter().get
+          originalFilter.child match {
+            case project: Project =>
+              None
+              val replaceMap =
+                CoalesceUnionUtil.buildAttributesMap(
+                  project.output,
+                  project.projectList.map(_.asInstanceOf[Expression]))
+              val originalProject = originalPlan.asInstanceOf[Project]
+              val newProjectList =
+                originalProject.projectList
+                  .map(e => CoalesceUnionUtil.replaceAttributes(e, replaceMap))
+                  .map(_.asInstanceOf[NamedExpression])
+              val newProject = Project(newProjectList, filter)
+              Some(newProject)
+            case subquery @ SubqueryAlias(_, project: Project) =>
+              val replaceMap =
+                CoalesceUnionUtil.buildAttributesMap(
+                  project.output,
+                  project.projectList.map(_.asInstanceOf[Expression]))
+              val originalProject = originalPlan.asInstanceOf[Project]
+              val newProjectList =
+                originalProject.projectList
+                  .map(e => CoalesceUnionUtil.replaceAttributes(e, replaceMap))
+                  .map(_.asInstanceOf[NamedExpression])
+              val newProject = Project(newProjectList, filter)
+              Some(newProject)
+            case _ => Some(originalPlan)
+          }
+        case None => None
       }
     }
+
+    override def doValidate(): Boolean = {
+      constructedProjectPlan.isDefined
+    }
+
+    override def getExtractedSourcePlan: Option[LogicalPlan] = 
extractedSourcePlan
+
+    override def getConstructedFilterPlan: Option[LogicalPlan] = 
constructedFilterPlan
+
+    override def getConstructedProjectPlan: Option[LogicalPlan] = 
constructedProjectPlan
   }
 
-  def buildAggregateCasesConditions(
-      groupedPlans: ArrayBuffer[AnalyzedPlan]): ArrayBuffer[Expression] = {
-    val firstPlanSourceOutputAttrs =
-      groupedPlans.head.analyzedInfo.get.extractedSourcePlan.get.output
-    groupedPlans.map {
-      plan =>
-        val attrsMap =
-          buildAttributesMap(
-            plan.analyzedInfo.get.extractedSourcePlan.get.output,
-            firstPlanSourceOutputAttrs)
-        val filter = 
plan.analyzedInfo.get.constructedFilterPlan.get.asInstanceOf[Filter]
-        replaceAttributes(filter.condition, attrsMap)
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (
+      spark.conf
+        .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_PROJECT_UNION, "true")
+        .toBoolean && CoalesceUnionUtil.isResolvedPlan(plan)
+    ) {
+      Try {
+        visitPlan(plan)
+      } match {
+        case Success(newPlan) => newPlan
+        case Failure(e) =>
+          logError(s"$e")
+          plan
+      }
+    } else {
+      plan
     }
   }
 
-  def buildUnionConditionForAggregateSource(conditions: 
ArrayBuffer[Expression]): Expression = {
-    conditions.reduce(Or);
+  def visitPlan(plan: LogicalPlan): LogicalPlan = {
+    plan match {
+      case union @ Union(_, false, false) =>
+        val groupedUnionClauses = groupUnionClauses(union)
+        val newUnionClauses =
+          groupedUnionClauses.map(clauses => 
coalesceMatchedUnionClauses(clauses))
+        val newPlan = CoalesceUnionUtil.unionClauses(union, newUnionClauses)
+        newPlan
+      case other =>
+        other.withNewChildren(other.children.map(visitPlan))
+    }
   }
 
-  def wrapAggregatesAttributesInStructs(
-      groupedPlans: ArrayBuffer[AnalyzedPlan]): Seq[NamedExpression] = {
-    val structAttributes = ArrayBuffer[NamedExpression]()
-    val casePrefix = "case_"
-    val structPrefix = "field_"
-    val firstSourceAttrs = 
groupedPlans.head.analyzedInfo.get.extractedSourcePlan.get.output
-    groupedPlans.zipWithIndex.foreach {
-      case (aggregateCase, case_index) =>
-        val analyzedInfo = aggregateCase.analyzedInfo.get
-        val aggregate = 
analyzedInfo.constructedAggregatePlan.get.asInstanceOf[Aggregate]
-        val structFields = ArrayBuffer[Expression]()
-        var fieldIndex: Int = 0
-        val attrReplaceMap = buildAttributesMap(
-          aggregateCase.analyzedInfo.get.extractedSourcePlan.get.output,
-          firstSourceAttrs)
-        aggregate.groupingExpressions.foreach {
-          e =>
-            structFields += 
Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType)
-            structFields += replaceAttributes(e, attrReplaceMap)
-            fieldIndex += 1
-        }
-        for (i <- 0 until analyzedInfo.positionInGroupingKeys.length) {
-          val position = analyzedInfo.positionInGroupingKeys(i)
-          if (position >= fieldIndex) {
-            val expr = analyzedInfo.resultGroupingExpressions(i)
-            structFields += 
Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType)
-            structFields += replaceAttributes(
-              analyzedInfo.resultGroupingExpressions(i),
-              attrReplaceMap)
-            fieldIndex += 1
+  /**
+   * Splits the clauses of `union` into groups that can be coalesced together. A clause whose
+   * `PlanAnalyzer` validates joins the first group whose head matches it (see `findMatchedGroup`)
+   * or starts a new group. Any other clause is rewritten recursively with `visitPlan` and kept in
+   * a singleton group without an analyzer. Groups keep the order in which clauses are collected.
+   */
+  def groupUnionClauses(union: Union): Seq[Seq[AnalyzedPlan]] = {
+    val unionClauses = CoalesceUnionUtil.collectAllUnionClauses(union)
+    val groups = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]()
+    unionClauses.foreach {
+      clause =>
+        val planAnalyzer = PlanAnalyzer(clause)
+        if (planAnalyzer.doValidate()) {
+          // Build the analyzed clause once and reuse it for matching and insertion.
+          val analyzedClause = AnalyzedPlan(clause, Some(planAnalyzer))
+          val matchedGroup = findMatchedGroup(analyzedClause, groups)
+          if (matchedGroup != -1) {
+            groups(matchedGroup) += analyzedClause
+          } else {
+            groups += ArrayBuffer(analyzedClause)
           }
+        } else {
+          val newClause = visitPlan(clause)
+          groups += ArrayBuffer(AnalyzedPlan(newClause, None))
         }
-
-        aggregate.aggregateExpressions
-          .filter(e => hasAggregateExpression(e))
-          .foreach {
-            e =>
-              def collectExpressionsInAggregateExpression(aggExpr: 
Expression): Unit = {
-                aggExpr match {
-                  case aggExpr: AggregateExpression =>
-                    val aggFunction =
-                      
removeAlias(aggExpr).asInstanceOf[AggregateExpression].aggregateFunction
-                    aggFunction.children.foreach {
-                      child =>
-                        structFields += Literal(
-                          UTF8String.fromString(s"$structPrefix$fieldIndex"),
-                          StringType)
-                        structFields += replaceAttributes(child, 
attrReplaceMap)
-                        fieldIndex += 1
-                    }
-                  case combineAgg if hasAggregateExpression(combineAgg) =>
-                    combineAgg.children.foreach {
-                      combindAggchild => 
collectExpressionsInAggregateExpression(combindAggchild)
-                    }
-                  case other =>
-                    structFields += Literal(
-                      UTF8String.fromString(s"$structPrefix$fieldIndex"),
-                      StringType)
-                    structFields += replaceAttributes(other, attrReplaceMap)
-                    fieldIndex += 1
-                }
-              }
-              collectExpressionsInAggregateExpression(e)
-          }
-        structFields += 
Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType)
-        structFields += Literal(case_index, IntegerType)
-        structAttributes += makeAlias(
-          CreateNamedStruct(structFields.toSeq),
-          s"$casePrefix$case_index")
-    }
-    structAttributes.toSeq
-  }
-
-  def buildProjectFoldIntoStruct(
-      child: LogicalPlan,
-      groupedPlans: ArrayBuffer[AnalyzedPlan],
-      conditions: ArrayBuffer[Expression]): LogicalPlan = {
-    val wrappedAttributes = wrapAggregatesAttributesInStructs(groupedPlans)
-    val ifAttributes = wrappedAttributes.zip(conditions).map {
-      case (attr, condition) =>
-        makeAlias(If(condition, attr, Literal(null, attr.dataType)), attr.name)
-          .asInstanceOf[NamedExpression]
     }
-    Project(ifAttributes, child)
+    groups.map(_.toSeq).toSeq
   }
 
-  def buildProjectBranchArray(
-      child: LogicalPlan,
-      conditions: ArrayBuffer[Expression]): LogicalPlan = {
-    assert(
-      child.output.length == conditions.length,
-      s"Expected same length of output and conditions")
-    val array = makeAlias(CreateArray(child.output), "array")
-    Project(Seq(array), child)
-  }
-
-  def buildExplodeBranchArray(child: LogicalPlan): LogicalPlan = {
-    assert(child.output.length == 1, s"Expected single output from $child")
-    val array = child.output.head.asInstanceOf[Expression]
-    assert(array.dataType.isInstanceOf[ArrayType], s"Expected ArrayType from 
$array")
-    val explodeExpr = Explode(array)
-    val exploadOutput =
-      AttributeReference("generate_output", 
array.dataType.asInstanceOf[ArrayType].elementType)()
-    Generate(
-      explodeExpr,
-      unrequiredChildIndex = Seq(0),
-      outer = false,
-      qualifier = None,
-      generatorOutput = Seq(exploadOutput),
-      child)
-  }
+  /**
+   * Returns the index of the first group in `groupedClauses` that `analyzedPlan` can join, or -1
+   * when there is none. A group matches when both its head and `analyzedPlan` carry a plan
+   * analyzer, their constructed project plans satisfy `areOutputMatchedProject`, and their
+   * extracted source plans satisfy `areStrictMatchedRelation`.
+   */
+  def findMatchedGroup(
+      analyzedPlan: AnalyzedPlan,
+      groupedClauses: ArrayBuffer[ArrayBuffer[AnalyzedPlan]]): Int = {
+    groupedClauses.indexWhere {
+      group =>
+        // A group headed by a clause without an analyzer (or a candidate without one) never
+        // matches.
+        val matched = for {
+          checked <- group.head.planAnalyzer
+          candidate <- analyzedPlan.planAnalyzer
+        } yield {
+          CoalesceUnionUtil.areOutputMatchedProject(
+            checked.getConstructedProjectPlan.get,
+            candidate.getConstructedProjectPlan.get) &&
+          CoalesceUnionUtil.areStrictMatchedRelation(
+            checked.getExtractedSourcePlan.get,
+            candidate.getExtractedSourcePlan.get)
+        }
+        matched.getOrElse(false)

-  def makeAlias(e: Expression, name: String): NamedExpression = {
-    Alias(e, name)(
-      NamedExpression.newExprId,
-      e match {
-        case ne: NamedExpression => ne.qualifier
-        case _ => Seq.empty
-      },
-      None,
-      Seq.empty)
+    }
   }
 
-  def buildProjectUnfoldStruct(child: LogicalPlan): LogicalPlan = {
-    assert(child.output.length == 1, s"Expected single output from $child")
-    val structedData = child.output.head
-    assert(
-      structedData.dataType.isInstanceOf[StructType],
-      s"Expected StructType from $structedData")
-    val structType = structedData.dataType.asInstanceOf[StructType]
-    val attributes = ArrayBuffer[NamedExpression]()
-    var index = 0
-    structType.fields.foreach {
-      field =>
-        attributes += Alias(GetStructField(structedData, index), field.name)()
-        index += 1
+  /**
+   * Collapses one group of matched clauses into a single plan. A single-clause group is returned
+   * unchanged. Otherwise the clauses' normalized filter conditions are OR-ed into one `Filter`
+   * over the first clause's extracted source. The per-clause structs from `addFoldStructStep` are
+   * then passed through `addArrayStep`, `addExplodeStep` and `addUnfoldStructStep`.
+   */
+  def coalesceMatchedUnionClauses(clauses: Seq[AnalyzedPlan]): LogicalPlan = {
+    clauses match {
+      case Seq(onlyClause) => onlyClause.plan
+      case _ =>
+        val conditions = CoalesceUnionUtil.normalizedClausesFilterCondition(clauses)
+        val sharedSource = clauses.head.planAnalyzer.get.getExtractedSourcePlan.get
+        val filtered = Filter(conditions.reduce(Or), sharedSource)
+        val folded = addFoldStructStep(filtered, conditions, clauses)
+        val arrayed = CoalesceUnionUtil.addArrayStep(folded)
+        val exploded = CoalesceUnionUtil.addExplodeStep(arrayed)
+        CoalesceUnionUtil.addUnfoldStructStep(exploded)
     }
-    Project(attributes.toSeq, child)
   }
 
-  def buildAggregateWithGroupId(
-      child: LogicalPlan,
-      groupedPlans: ArrayBuffer[AnalyzedPlan]): LogicalPlan = {
-    val attributes = child.output
-    val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get
-    val aggregateTemplate =
-      
firstAggregateAnalzyInfo.constructedAggregatePlan.get.asInstanceOf[Aggregate]
-    val analyzedInfo = groupedPlans.head.analyzedInfo.get
-
-    val totalGroupingExpressionsCount =
-      math.max(
-        aggregateTemplate.groupingExpressions.length,
-        analyzedInfo.positionInGroupingKeys.max + 1)
-
-    val groupingExpressions = attributes
-      .slice(0, totalGroupingExpressionsCount)
-      .map(_.asInstanceOf[Expression]) :+ attributes.last
+  /**
+   * Builds one struct column per clause. Each clause's project list is rewritten onto the output
+   * attributes of the first clause's extracted source (via `buildAttributesMap` and
+   * `replaceAttributes`). The rewritten list is packed into a `CreateNamedStruct` with field names
+   * `f0`, `f1`, ... and aliased as `clause_<index>`.
+   */
+  def buildClausesStructs(analyzedPlans: Seq[AnalyzedPlan]): Seq[NamedExpression] = {
+    val targetAttributes = analyzedPlans.head.planAnalyzer.get.getExtractedSourcePlan.get.output
+    analyzedPlans.zipWithIndex.map {
+      case (analyzedPlan, clauseIndex) =>
+        val analyzer = analyzedPlan.planAnalyzer.get
+        val attributeMapping = CoalesceUnionUtil.buildAttributesMap(
+          analyzer.getExtractedSourcePlan.get.output,
+          targetAttributes)
+        val rewrittenProjectList =
+          analyzer.getConstructedProjectPlan.get.asInstanceOf[Project].projectList.map {
+            expr => CoalesceUnionUtil.replaceAttributes(expr, attributeMapping)
+          }

-    val normalExpressionPosition = analyzedInfo.positionInGroupingKeys
-    var normalExpressionCount = 0
-    var aggregateExpressionIndex = totalGroupingExpressionsCount
-    val aggregateExpressions = ArrayBuffer[NamedExpression]()
-    aggregateTemplate.aggregateExpressions.foreach {
-      e =>
-        removeAlias(e) match {
-          case aggExpr if hasAggregateExpression(aggExpr) =>
-            val (newAggExpr, count) =
-              constructAggregateExpression(aggExpr, attributes, 
aggregateExpressionIndex)
-            aggregateExpressions += makeAlias(newAggExpr, 
e.name).asInstanceOf[NamedExpression]
-            aggregateExpressionIndex += count
-          case other =>
-            val position = normalExpressionPosition(normalExpressionCount)
-            val attr = attributes(position)
-            normalExpressionCount += 1
-            aggregateExpressions += makeAlias(attr, e.name)
-              .asInstanceOf[NamedExpression]
+        // Interleave field names and values: "f0", v0, "f1", v1, ...
+        val structFields: Seq[Expression] = rewrittenProjectList.zipWithIndex.flatMap {
+          case (expr, fieldIndex) =>
+            Seq(Literal(UTF8String.fromString(s"f$fieldIndex"), StringType), expr)
         }
+        CoalesceUnionUtil.makeAlias(CreateNamedStruct(structFields), s"clause_$clauseIndex")
     }
-    Aggregate(groupingExpressions.toSeq, aggregateExpressions.toSeq, child)
   }
 
-  def constructAggregateExpression(
-      aggExpr: Expression,
-      attributes: Seq[Attribute],
-      index: Int): (Expression, Int) = {
-    aggExpr match {
-      case singleAggExpr: AggregateExpression =>
-        val aggFunc = singleAggExpr.aggregateFunction
-        val newAggFuncArgs = aggFunc.children.zipWithIndex.map {
-          case (arg, i) =>
-            attributes(index + i)
-        }
-        val newAggFunc =
-          
aggFunc.withNewChildren(newAggFuncArgs).asInstanceOf[AggregateFunction]
-        val res = AggregateExpression(
-          newAggFunc,
-          singleAggExpr.mode,
-          singleAggExpr.isDistinct,
-          singleAggExpr.filter,
-          singleAggExpr.resultId)
-        (res, 1)
-      case combineAggExpr if hasAggregateExpression(combineAggExpr) =>
-        val childrenExpressions = ArrayBuffer[Expression]()
-        var totalCount = 0
-        combineAggExpr.children.foreach {
-          child =>
-            val (expr, count) = constructAggregateExpression(child, 
attributes, totalCount + index)
-            childrenExpressions += expr
-            totalCount += count
-        }
-        (combineAggExpr.withNewChildren(childrenExpressions.toSeq), totalCount)
-      case _ => (attributes(index), 1)
+  /**
+   * Adds a `Project` on top of `plan` that emits one column per clause. Each column is the
+   * clause's struct from `buildClausesStructs` when that clause's condition holds, and a null of
+   * the struct's type otherwise. The columns keep the struct aliases' names.
+   *
+   * @param plan the plan to project on top of
+   * @param conditions one filter condition per clause, in clause order
+   * @param analyzedPlans the matched clauses, in the same order as `conditions`
+   */
+  def addFoldStructStep(
+      plan: LogicalPlan,
+      conditions: Seq[Expression],
+      analyzedPlans: Seq[AnalyzedPlan]): LogicalPlan = {
+    val structAttributes = buildClausesStructs(analyzedPlans)
+    assert(
+      structAttributes.length == conditions.length,
+      s"Expected one condition per union clause, but got ${conditions.length} conditions " +
+        s"for ${structAttributes.length} clauses")
+    val structAttributesWithCondition = structAttributes.zip(conditions).map {
+      case (struct, condition) =>
+        // A clause whose condition does not hold contributes a null struct for this row.
+        CoalesceUnionUtil
+          .makeAlias(If(condition, struct, Literal(null, struct.dataType)), struct.name)
+          .asInstanceOf[NamedExpression]
     }
+    Project(structAttributesWithCondition, plan)
   }
 }
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
index 23c6022727..8479d9f41e 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
@@ -47,6 +47,8 @@ class GlutenCoalesceAggregationUnionSuite extends 
GlutenClickHouseWholeStageTran
       .set("spark.io.compression.codec", "snappy")
       .set("spark.sql.shuffle.partitions", "5")
       .set("spark.sql.autoBroadcastJoinThreshold", "10MB")
+      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union", 
"true")
+      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.aggregation.union", 
"true")
   }
 
   def createTestTable(tableName: String, data: DataFrame): Unit = {
@@ -392,4 +394,105 @@ class GlutenCoalesceAggregationUnionSuite extends 
GlutenClickHouseWholeStageTran
     compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true)
   }
 
+  test("coalesce project union. case 1") {
+    // Same table and projections, different filters; checkNoUnion expects a coalesced plan.
+    val sql =
+      """
+        |select a, x, y from (
+        | select a, x, y from coalesce_union_t1 where b % 2 = 0
+        | union all
+        | select a, x, y from coalesce_union_t1 where b % 3 = 1
+        |) order by a, x, y
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+  }
+
+  test("coalesce project union. case 2") { // clauses project different expressions
+    val sql =
+      """
+        |select a, x, y from (
+        | select concat(a, 'x') as a , x, y from coalesce_union_t1 where b % 2 
= 0
+        | union all
+        | select a, x, y + 2 as y from coalesce_union_t1 where b % 3 = 1
+        |) order by a, x, y
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+  }
+
+  test("coalesce project union. case 3") {
+    // Each clause tags its rows with a literal `t`; select it so the per-clause literal column
+    // produced by the coalesced plan is part of the compared result, not just the sort key.
+    val sql =
+      """
+        |select a, x, y, t from (
+        | select concat(a, 'x') as a , x, y, 1 as t from coalesce_union_t1 where b % 2 = 0
+        | union all
+        | select a, x, y + 2 as y, 2 as t from coalesce_union_t1 where b % 3 = 1
+        |) order by a, x, y, t
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+  }
+
+  test("coalesce project union. case 4") { // literal `y` in one clause, computed in the other
+    val sql =
+      """
+        |select a, x, y from (
+        | select concat(a, 'x') as a , x, 1 as y from coalesce_union_t1 where 
b % 2 = 0
+        | union all
+        | select a, x, y + 2 as y from coalesce_union_t1 where b % 3 = 1
+        |) order by a, x, y
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+  }
+
+  test("coalesce project union. case 5") { // filter on `b` redefined in a subquery
+    val sql =
+      """
+        |select a, x, y from (
+        | select a, x, y from (select a, x, y, b + 4 as b from 
coalesce_union_t1) where b % 2 = 0
+        | union all
+        | select a, x, y + 2 as y from coalesce_union_t1 where b % 3 = 1
+        |) order by a, x, y
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+  }
+
+  test("coalesce project union. case 6") {
+    // Four clauses over the same table, each tagged with a distinct literal `t`.
+    val sql =
+      """
+        |select a, x, y, t from (
+        | select a, x, y, 1 as t from coalesce_union_t1 where b % 2 = 0
+        | union all
+        | select a, x, y, 2 as t from coalesce_union_t1 where b % 3 = 1
+        | union all
+        | select a, x, y, 3 as t from coalesce_union_t1 where b % 4 = 1
+        | union all
+        | select a, x, y, 4 as t from coalesce_union_t1 where b % 5 = 1
+        |) order by a, x, y, t
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+  }
+
+  test("no coalesce project union. case 1") { // first clause has no filter; union is kept
+    val sql =
+      """
+        |select a, x, y from (
+        | select a, x, y from coalesce_union_t1
+        | union all
+        | select a, x, y from coalesce_union_t1 where b % 3 = 1
+        |) order by a, x, y
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true)
+  }
+
+  test("no coalesce project union. case 2") { // clauses read different tables; union is kept
+    val sql =
+      """
+        |select a, x, y from (
+        | select a , x, y from coalesce_union_t2 where b % 2 = 0
+        | union all
+        | select a, x, y from coalesce_union_t1 where b % 3 = 1
+        |) order by a, x, y
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true)
+  }
 }
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
index f5bdb254b2..3a05eda711 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
@@ -16,6 +16,18 @@
  */
 package org.apache.spark.sql
 
-class GlutenCTEInlineSuiteAEOff extends CTEInlineSuiteAEOff with 
GlutenSQLTestsTrait
+import org.apache.spark.SparkConf
 
-class GlutenCTEInlineSuiteAEOn extends CTEInlineSuiteAEOn with 
GlutenSQLTestsTrait
+class GlutenCTEInlineSuiteAEOff extends CTEInlineSuiteAEOff with 
GlutenSQLTestsTrait { // inherited CTE tests with CH project-union coalescing off
+  override def sparkConf: SparkConf =
+    super.sparkConf
+      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union", 
"false")
+  // NOTE(review): presumably the union rewrite changes plans these tests inspect; confirm.
+}
+
+class GlutenCTEInlineSuiteAEOn extends CTEInlineSuiteAEOn with 
GlutenSQLTestsTrait { // inherited CTE tests with CH project-union coalescing off
+  override def sparkConf: SparkConf =
+    super.sparkConf
+      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union", 
"false")
+  // NOTE(review): presumably the union rewrite changes plans these tests inspect; confirm.
+}
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
index d51d1034b0..fe7958b677 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
@@ -16,6 +16,14 @@
  */
 package org.apache.spark.sql
 
+import org.apache.spark.SparkConf
+
 class GlutenDataFrameSetOperationsSuite
   extends DataFrameSetOperationsSuite
-  with GlutenSQLTestsTrait {}
+  with GlutenSQLTestsTrait { // both CH union-coalescing rules are disabled for this suite
+  override def sparkConf: SparkConf =
+    super.sparkConf
+      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union", 
"false")
+      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.aggregation.union", 
"false")
+  // NOTE(review): presumably these rewrites alter the union plans the suite checks; confirm.
+}
diff --git 
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 3c1429a34f..75e97b48a9 100644
--- 
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -622,6 +622,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("ordering and partitioning reporting")
   enableSuite[GlutenDatasetAggregatorSuite]
   enableSuite[GlutenDatasetCacheSuite]
+    // Disabled because the coalesce-union-clauses rule rewrites the query under test.
+    .exclude("SPARK-44653: non-trivial DataFrame unions should not break 
caching")
   enableSuite[GlutenDatasetOptimizationSuite]
   enableSuite[GlutenDatasetPrimitiveSuite]
   enableSuite[GlutenDatasetSerializerRegistratorSuite]
diff --git 
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
 
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
index f5bdb254b2..3a05eda711 100644
--- 
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
+++ 
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
@@ -16,6 +16,18 @@
  */
 package org.apache.spark.sql
 
-class GlutenCTEInlineSuiteAEOff extends CTEInlineSuiteAEOff with 
GlutenSQLTestsTrait
+import org.apache.spark.SparkConf
 
-class GlutenCTEInlineSuiteAEOn extends CTEInlineSuiteAEOn with 
GlutenSQLTestsTrait
+class GlutenCTEInlineSuiteAEOff extends CTEInlineSuiteAEOff with 
GlutenSQLTestsTrait { // inherited CTE tests with CH project-union coalescing off
+  override def sparkConf: SparkConf =
+    super.sparkConf
+      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union", 
"false")
+  // NOTE(review): presumably the union rewrite changes plans these tests inspect; confirm.
+}
+
+class GlutenCTEInlineSuiteAEOn extends CTEInlineSuiteAEOn with 
GlutenSQLTestsTrait { // inherited CTE tests with CH project-union coalescing off
+  override def sparkConf: SparkConf =
+    super.sparkConf
+      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union", 
"false")
+  // NOTE(review): presumably the union rewrite changes plans these tests inspect; confirm.
+}
diff --git 
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
 
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
index d51d1034b0..fe7958b677 100644
--- 
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
+++ 
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
@@ -16,6 +16,14 @@
  */
 package org.apache.spark.sql
 
+import org.apache.spark.SparkConf
+
 class GlutenDataFrameSetOperationsSuite
   extends DataFrameSetOperationsSuite
-  with GlutenSQLTestsTrait {}
+  with GlutenSQLTestsTrait { // both CH union-coalescing rules are disabled for this suite
+  override def sparkConf: SparkConf =
+    super.sparkConf
+      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union", 
"false")
+      
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.aggregation.union", 
"false")
+  // NOTE(review): presumably these rewrites alter the union plans the suite checks; confirm.
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to