This is an automated email from the ASF dual-hosted git repository.
loneylee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 9c79bd664e [GLUTEN-9646][CH] Fix coalesce project union when has subquery (#9662)
9c79bd664e is described below
commit 9c79bd664e7c7590edebe3dade94da05ccb545cf
Author: Shuai li <[email protected]>
AuthorDate: Fri May 16 14:31:11 2025 +0800
[GLUTEN-9646][CH] Fix coalesce project union when has subquery (#9662)
* [GLUTEN-9646][CH] Fix coalesce project union when has subquery
* fix ci
---
.../extension/CoalesceAggregationUnion.scala | 57 +++++++++++++---------
.../GlutenCoalesceAggregationUnionSuite.scala | 55 +++++++++++++++++++++
2 files changed, 89 insertions(+), 23 deletions(-)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
index 331dfdc2db..d5faa48d07 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
@@ -17,7 +17,6 @@
package org.apache.gluten.extension
import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
-import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
@@ -107,9 +106,9 @@ object CoalesceUnionUtil extends Logging {
def buildAttributesMap(
attributes: Seq[Attribute],
- expressions: Seq[Expression]): Map[ExprId, Expression] = {
+ expressions: Seq[NamedExpression]): Map[ExprId, NamedExpression] = {
assert(attributes.length == expressions.length)
- val map = new mutable.HashMap[ExprId, Expression]()
+ val map = new mutable.HashMap[ExprId, NamedExpression]()
attributes.zip(expressions).foreach {
case (attr, expr) =>
map.put(attr.exprId, expr)
@@ -117,19 +116,38 @@ object CoalesceUnionUtil extends Logging {
map.toMap
}
- def replaceAttributes(e: Expression, replaceMap: Map[ExprId, Expression]):
Expression = {
+ def replaceAttributes(e: Expression, replaceMap: Map[ExprId,
NamedExpression]): Expression = {
e match {
case attr: Attribute =>
+ replaceMap.getOrElse(attr.exprId, attr)
+ case attr: OuterReference =>
replaceMap.get(attr.exprId) match {
- case Some(replaceAttr) => replaceAttr
- case None =>
- throw new GlutenNotSupportException(s"Not found attribute: $attr
${attr.qualifiedName}")
+ case Some(replaceAttr) =>
+ OuterReference.apply(replaceAttr)
+ case _ => attr
}
+ case subquery: ScalarSubquery =>
+ val plan = replaceSubqueryAttributes(subquery.plan, replaceMap)
+ val outerAttrs = subquery.outerAttrs.map(replaceAttributes(_,
replaceMap))
+ subquery.copy(plan, outerAttrs)
case _ =>
e.withNewChildren(e.children.map(replaceAttributes(_, replaceMap)))
}
}
+ def replaceSubqueryAttributes(
+ plan: LogicalPlan,
+ replaceMap: Map[ExprId, NamedExpression]): LogicalPlan = {
+ plan match {
+ case filter: Filter =>
+ filter.copy(
+ replaceAttributes(filter.condition, replaceMap),
+ replaceSubqueryAttributes(filter.child, replaceMap))
+ case _ =>
+ plan.withNewChildren(plan.children.map(replaceSubqueryAttributes(_,
replaceMap)))
+ }
+ }
+
// Plans and expressions are the same.
def areStrictMatchedRelation(leftRelation: LogicalPlan, rightRelation:
LogicalPlan): Boolean = {
(leftRelation, rightRelation) match {
@@ -396,9 +414,8 @@ case class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPla
val newFilter = innerProject match {
case Some(project) =>
- val replaceMap = CoalesceUnionUtil.buildAttributesMap(
- project.output,
- project.projectList.map(_.asInstanceOf[Expression]))
+ val replaceMap =
+ CoalesceUnionUtil.buildAttributesMap(project.output,
project.projectList)
val newCondition =
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
Filter(newCondition, sourcePlan)
case None => filter.withNewChildren(Seq(sourcePlan))
@@ -422,9 +439,8 @@ case class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPla
val newAggregate = innerProject match {
case Some(project) =>
- val replaceMap = CoalesceUnionUtil.buildAttributesMap(
- project.output,
- project.projectList.map(_.asInstanceOf[Expression]))
+ val replaceMap =
+ CoalesceUnionUtil.buildAttributesMap(project.output,
project.projectList)
val newGroupExpressions =
originalAggregate.groupingExpressions.map {
e => CoalesceUnionUtil.replaceAttributes(e, replaceMap)
}
@@ -688,7 +704,7 @@ case class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPla
def collectClauseStructFields(
analyzedPlan: AnalyzedPlan,
clauseIndex: Int,
- attributeReplaceMap: Map[ExprId, Expression]): Seq[Expression] = {
+ attributeReplaceMap: Map[ExprId, NamedExpression]): Seq[Expression] = {
val planAnalyzer = analyzedPlan.planAnalyzer.get.asInstanceOf[PlanAnalyzer]
val aggregate =
planAnalyzer.constructedAggregatePlan.get.asInstanceOf[Aggregate]
@@ -876,14 +892,14 @@ case class CoalesceProjectionUnion(spark: SparkSession) extends Rule[LogicalPlan
val replaceMap =
CoalesceUnionUtil.buildAttributesMap(
project.output,
- project.projectList.map(_.asInstanceOf[Expression]))
+ project.projectList.map(_.asInstanceOf[NamedExpression]))
val newCondition =
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
Some(Filter(newCondition, source))
case subquery @ SubqueryAlias(_, project: Project) =>
val replaceMap =
CoalesceUnionUtil.buildAttributesMap(
project.output,
- project.projectList.map(_.asInstanceOf[Expression]))
+ project.projectList.map(_.asInstanceOf[NamedExpression]))
val newCondition =
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
Some(Filter(newCondition, source))
case _ => Some(filter)
@@ -898,11 +914,8 @@ case class CoalesceProjectionUnion(spark: SparkSession) extends Rule[LogicalPlan
val originalFilter = extractFilter().get
originalFilter.child match {
case project: Project =>
- None
val replaceMap =
- CoalesceUnionUtil.buildAttributesMap(
- project.output,
- project.projectList.map(_.asInstanceOf[Expression]))
+ CoalesceUnionUtil.buildAttributesMap(project.output,
project.projectList)
val originalProject = originalPlan.asInstanceOf[Project]
val newProjectList =
originalProject.projectList
@@ -912,9 +925,7 @@ case class CoalesceProjectionUnion(spark: SparkSession) extends Rule[LogicalPlan
Some(newProject)
case subquery @ SubqueryAlias(_, project: Project) =>
val replaceMap =
- CoalesceUnionUtil.buildAttributesMap(
- project.output,
- project.projectList.map(_.asInstanceOf[Expression]))
+ CoalesceUnionUtil.buildAttributesMap(project.output,
project.projectList)
val originalProject = originalPlan.asInstanceOf[Project]
val newProjectList =
originalProject.projectList
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
index e5374bebb4..6e73d12230 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
@@ -515,4 +515,59 @@ class GlutenCoalesceAggregationUnionSuite extends GlutenClickHouseWholeStageTran
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true)
}
+
+ test("GLUTEN-9646: fix coalesce project union when has subquery") {
+ val schema_fact = StructType(
+ Array(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", IntegerType, nullable = true)
+ ))
+
+ val schema_order = StructType(
+ Array(
+ StructField("c", IntegerType, nullable = true),
+ StructField("b", IntegerType, nullable = true)
+ ))
+
+ val data_fact = sparkContext.parallelize(
+ Seq(
+ Row(2, 1),
+ Row(3, 2),
+ Row(4, 3),
+ Row(5, 4)
+ ))
+
+ val data_order = sparkContext.parallelize(
+ Seq(
+ Row(1, 1),
+ Row(2, 2),
+ Row(3, 3),
+ Row(4, 4)
+ ))
+
+ val dataFrame1 = spark.createDataFrame(data_fact, schema_fact)
+ val dataFrame2 = spark.createDataFrame(data_order, schema_order)
+ createTestTable("fact", dataFrame1)
+ createTestTable("order", dataFrame2)
+
+ val sql =
+ """
+ |SELECT a
+ |FROM fact
+ |WHERE a =
+ | (SELECT sum(c) + 2
+ | FROM order
+ | WHERE order.b = fact.b
+ | GROUP BY order.b)
+ |UNION ALL
+ |SELECT a
+ |FROM fact
+ |WHERE a =
+ | (SELECT sum(c) + 1
+ | FROM order
+ | WHERE order.b = fact.b
+ | GROUP BY order.b)
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql, compareResult = true, checkNoUnion)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]