This is an automated email from the ASF dual-hosted git repository.
zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new a88d5d5b5f [GLUTEN-8784][CH] Coalesce union of multiple scan-projects
(#8785)
a88d5d5b5f is described below
commit a88d5d5b5fdf7bdad7c1f5b8ea42d41fd1fbbf1a
Author: lgbo <[email protected]>
AuthorDate: Wed Feb 26 16:39:13 2025 +0800
[GLUTEN-8784][CH] Coalesce union of multiple scan-projects (#8785)
[CH] Coalesce union of multiple scan-projects
---
.../gluten/backendsapi/clickhouse/CHBackend.scala | 2 +
.../gluten/backendsapi/clickhouse/CHRuleApi.scala | 1 +
.../extension/CoalesceAggregationUnion.scala | 1428 ++++++++++++--------
.../GlutenCoalesceAggregationUnionSuite.scala | 103 ++
.../apache/spark/sql/GlutenCTEInlineSuite.scala | 16 +-
.../sql/GlutenDataFrameSetOperationsSuite.scala | 10 +-
.../utils/clickhouse/ClickHouseTestSettings.scala | 2 +
.../apache/spark/sql/GlutenCTEInlineSuite.scala | 16 +-
.../sql/GlutenDataFrameSetOperationsSuite.scala | 10 +-
9 files changed, 986 insertions(+), 602 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 73e8bce3fe..c0380452de 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -158,6 +158,8 @@ object CHBackendSettings extends BackendSettingsApi with
Logging {
val GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION: String =
CHConfig.prefixOf("enable.coalesce.aggregation.union")
+ val GLUTEN_ENABLE_COALESCE_PROJECT_UNION: String =
+ CHConfig.prefixOf("enable.coalesce.project.union")
def affinityMode: String = {
SparkEnv.get.conf
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index ecd7e5a241..4573563cc9 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -61,6 +61,7 @@ object CHRuleApi {
injector.injectParser(
(spark, parserInterface) => new GlutenClickhouseSqlParser(spark,
parserInterface))
injector.injectResolutionRule(spark => new CoalesceAggregationUnion(spark))
+ injector.injectResolutionRule(spark => new CoalesceProjectionUnion(spark))
injector.injectResolutionRule(spark => new
RewriteToDateExpresstionRule(spark))
injector.injectResolutionRule(spark => new
RewriteDateTimestampComparisonRule(spark))
injector.injectResolutionRule(spark => new
CollapseGetJsonObjectExpressionRule(spark))
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
index 4a83830e51..cfae1deeaa 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CoalesceAggregationUnion.scala
@@ -35,63 +35,73 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}
-/*
- * Example:
- * Rewrite query
- * SELECT a, b, sum(c) FROM t WHERE d = 1 GROUP BY a,b
- * UNION ALL
- * SELECT a, b, sum(c) FROM t WHERE d = 2 GROUP BY a,b
- * into
- * SELECT a, b, sum(c) FROM (
- * SELECT s.a as a, s.b as b, s.c as c, s.id as group_id FROM (
- * SELECT explode(s) as s FROM (
- * SELECT array(
- * if(d = 1, named_struct('a', a, 'b', b, 'c', c, 'id', 0), null),
- * if(d = 2, named_struct('a', a, 'b', b, 'c', c, 'id', 1), null)) as s
- * FROM t WHERE d = 1 OR d = 2
- * )
- * ) WHERE s is not null
- * ) GROUP BY a,b, group_id
+/** Abstract class representing a plan analyzer. */
+abstract class AbstractPlanAnalyzer() {
+
+ /** Extract the common source subplan from the plan. */
+ def getExtractedSourcePlan(): Option[LogicalPlan] = None
+
+ /**
+ * Construct a plan with filter. The filter condition may be rewritten,
different from the
+ * original filter condition.
+ */
+ def getConstructedFilterPlan(): Option[LogicalPlan] = None
+
+ /**
+ * Construct a plan with aggregate. The aggregate expressions may be
rewritten, different from
+ * the original aggregate.
+ */
+ def getConstructedAggregatePlan(): Option[LogicalPlan] = None
+
+ /** Construct a plan with project. */
+ def getConstructedProjectPlan: Option[LogicalPlan] = None
+
+ /** If the rule cannot be applied, return false, otherwise return true. */
+ def doValidate(): Boolean = false
+}
+
+/**
+ * Case class representing an analyzed logical plan.
*
- * The first query need to scan `t` multiply, when the output of scan is
large, the query is
- * really slow. The rewritten query only scan `t` once, and the performance is
much better.
+ * @param plan
+ * The original plan which is analyzed.
+ * @param planAnalyzer
+ * An optional plan analyzer that provides additional analysis capabilities.
When the rule cannot
+ * apply to the plan, the planAnalyzer is None.
*/
+case class AnalyzedPlan(plan: LogicalPlan, planAnalyzer:
Option[AbstractPlanAnalyzer])
-class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan]
with Logging {
- def removeAlias(e: Expression): Expression = {
- e match {
- case alias: Alias => alias.child
- case _ => e
+object CoalesceUnionUtil extends Logging {
+
+ def isResolvedPlan(plan: LogicalPlan): Boolean = {
+ plan match {
+ case isnert: InsertIntoStatement => isnert.query.resolved
+ case _ => plan.resolved
}
}
- def hasAggregateExpression(e: Expression): Boolean = {
- if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) {
- return false
- }
- e match {
- case _: AggregateExpression => true
- case _ => e.children.exists(hasAggregateExpression(_))
+ // unfold nested unions
+ def collectAllUnionClauses(union: Union): ArrayBuffer[LogicalPlan] = {
+ val unionClauses = ArrayBuffer[LogicalPlan]()
+ union.children.foreach {
+ case u: Union =>
+ unionClauses ++= collectAllUnionClauses(u)
+ case other =>
+ unionClauses += other
}
+ unionClauses
}
- def isAggregateExpression(e: Expression): Boolean = {
- e match {
- case cast: Cast => isAggregateExpression(cast.child)
- case alias: Alias => isAggregateExpression(alias.child)
- case agg: AggregateExpression => true
- case _ => false
- }
+ def isRelation(plan: LogicalPlan): Boolean = {
+ plan.isInstanceOf[MultiInstanceRelation]
}
- def hasAggregateExpressionsWithFilter(e: Expression): Boolean = {
- if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) {
- return false
- }
- e match {
- case aggExpr: AggregateExpression =>
- aggExpr.filter.isDefined
- case _ => e.children.exists(hasAggregateExpressionsWithFilter(_))
+ def validateSource(plan: LogicalPlan): Boolean = {
+ plan match {
+ case relation if isRelation(relation) => true
+ case _: Project | _: Filter | _: SubqueryAlias =>
+ plan.children.forall(validateSource)
+ case _ => false
}
}
@@ -120,7 +130,216 @@ class CoalesceAggregationUnion(spark: SparkSession)
extends Rule[LogicalPlan] wi
}
}
- case class AggregateAnalzyInfo(originalAggregate: Aggregate) {
+ // Plans and expressions are the same.
+ def areStrictMatchedRelation(leftRelation: LogicalPlan, rightRelation:
LogicalPlan): Boolean = {
+ (leftRelation, rightRelation) match {
+ case (leftLogicalRelation: LogicalRelation, rightLogicalRelation:
LogicalRelation) =>
+ val leftTable =
+
leftLogicalRelation.catalogTable.map(_.identifier.unquotedString).getOrElse("")
+ val rightTable =
+
rightLogicalRelation.catalogTable.map(_.identifier.unquotedString).getOrElse("")
+ leftLogicalRelation.output.length ==
rightLogicalRelation.output.length &&
+ leftLogicalRelation.output.zip(rightLogicalRelation.output).forall {
+ case (leftAttr, rightAttr) =>
+ leftAttr.dataType.equals(rightAttr.dataType) &&
leftAttr.name.equals(rightAttr.name)
+ } &&
+ leftTable.equals(rightTable) && leftTable.nonEmpty
+ case (leftCTE: CTERelationRef, rightCTE: CTERelationRef) =>
+ leftCTE.cteId == rightCTE.cteId
+ case (leftHiveTable: HiveTableRelation, rightHiveTable:
HiveTableRelation) =>
+ leftHiveTable.tableMeta.identifier.unquotedString
+ .equals(rightHiveTable.tableMeta.identifier.unquotedString)
+ case (leftSubquery: SubqueryAlias, rightSubquery: SubqueryAlias) =>
+ areStrictMatchedRelation(leftSubquery.child, rightSubquery.child)
+ case (leftProject: Project, rightProject: Project) =>
+ leftProject.projectList.length == rightProject.projectList.length &&
+ leftProject.projectList.zip(rightProject.projectList).forall {
+ case (leftExpr, rightExpr) =>
+ areMatchedExpression(leftExpr, rightExpr)
+ } &&
+ areStrictMatchedRelation(leftProject.child, rightProject.child)
+ case (leftFilter: Filter, rightFilter: Filter) =>
+ areMatchedExpression(leftFilter.condition, rightFilter.condition) &&
+ areStrictMatchedRelation(leftFilter.child, rightFilter.child)
+ case (_, _) => false
+ }
+ }
+
+ // Two projects have the same output schema. Don't need the projectLists are
the same.
+ def areOutputMatchedProject(leftPlan: LogicalPlan, rightPlan: LogicalPlan):
Boolean = {
+ (leftPlan, rightPlan) match {
+ case (leftProject: Project, rightProject: Project) =>
+ val leftOutput = leftProject.output
+ val rightOutput = rightProject.output
+ leftOutput.length == rightOutput.length &&
+ leftOutput.zip(rightOutput).forall {
+ case (leftAttr, rightAttr) =>
+ leftAttr.dataType.equals(rightAttr.dataType) &&
leftAttr.name.equals(rightAttr.name)
+ }
+ case (_, _) =>
+ false
+ }
+ }
+
+ def areMatchedExpression(leftExpression: Expression, rightExpression:
Expression): Boolean = {
+ (leftExpression, rightExpression) match {
+ case (leftLiteral: Literal, rightLiteral: Literal) =>
+ leftLiteral.dataType.equals(rightLiteral.dataType) &&
+ leftLiteral.value == rightLiteral.value
+ case (leftAttr: Attribute, rightAttr: Attribute) =>
+ leftAttr.dataType.equals(rightAttr.dataType) &&
leftAttr.name.equals(rightAttr.name)
+ case (leftAgg: AggregateExpression, rightAgg: AggregateExpression) =>
+ leftAgg.isDistinct == rightAgg.isDistinct &&
+ areMatchedExpression(leftAgg.aggregateFunction,
rightAgg.aggregateFunction)
+ case (_, _) =>
+ leftExpression.getClass == rightExpression.getClass &&
+ leftExpression.children.length == rightExpression.children.length &&
+ leftExpression.children.zip(rightExpression.children).forall {
+ case (leftChild, rightChild) => areMatchedExpression(leftChild,
rightChild)
+ }
+ }
+ }
+
+ // Normalize all the filter conditions to make them fit with the first
plan's source
+ def normalizedClausesFilterCondition(analyzedPlans: Seq[AnalyzedPlan]):
Seq[Expression] = {
+ val valueAttributes =
analyzedPlans.head.planAnalyzer.get.getExtractedSourcePlan.get.output
+ analyzedPlans.map {
+ analyzedPlan =>
+ val keyAttributes =
analyzedPlan.planAnalyzer.get.getExtractedSourcePlan.get.output
+ val replaceMap = buildAttributesMap(keyAttributes, valueAttributes)
+ val filter =
analyzedPlan.planAnalyzer.get.getConstructedFilterPlan.get.asInstanceOf[Filter]
+ CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
+ }
+ }
+
+ def makeAlias(e: Expression, name: String): NamedExpression = {
+ Alias(e, name)(
+ NamedExpression.newExprId,
+ e match {
+ case ne: NamedExpression => ne.qualifier
+ case _ => Seq.empty
+ },
+ None,
+ Seq.empty)
+ }
+
+ def removeAlias(e: Expression): Expression = {
+ e match {
+ case alias: Alias => alias.child
+ case _ => e
+ }
+ }
+
+ def isAggregateExpression(e: Expression): Boolean = {
+ e match {
+ case cast: Cast => isAggregateExpression(cast.child)
+ case alias: Alias => isAggregateExpression(alias.child)
+ case agg: AggregateExpression => true
+ case _ => false
+ }
+ }
+
+ def hasAggregateExpression(e: Expression): Boolean = {
+ if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) {
+ false
+ } else {
+ e match {
+ case _: AggregateExpression => true
+ case _ => e.children.exists(hasAggregateExpression(_))
+ }
+ }
+ }
+
+ def addArrayStep(plan: LogicalPlan): LogicalPlan = {
+ val array =
+
CoalesceUnionUtil.makeAlias(CreateArray(plan.output.map(_.asInstanceOf[Expression])),
"array")
+ Project(Seq(array), plan)
+ }
+
+ def addExplodeStep(plan: LogicalPlan): LogicalPlan = {
+ val arrayOutput = plan.output.head.asInstanceOf[Expression]
+ val explodeExpression = Explode(arrayOutput)
+ val explodeOutput = AttributeReference(
+ "generate_output",
+ arrayOutput.dataType.asInstanceOf[ArrayType].elementType)()
+ val generate = Generate(
+ explodeExpression,
+ unrequiredChildIndex = Seq(0),
+ outer = false,
+ qualifier = None,
+ generatorOutput = Seq(explodeOutput),
+ plan)
+ Filter(IsNotNull(generate.output.head), generate)
+ }
+
+ def addUnfoldStructStep(plan: LogicalPlan): LogicalPlan = {
+ assert(plan.output.length == 1)
+ val structExpression = plan.output.head
+ assert(structExpression.dataType.isInstanceOf[StructType])
+ val structType = structExpression.dataType.asInstanceOf[StructType]
+ val attributes = ArrayBuffer[NamedExpression]()
+ var fieldIndex = 0
+ structType.fields.foreach {
+ field =>
+ attributes += Alias(GetStructField(structExpression, fieldIndex),
field.name)()
+ fieldIndex += 1
+ }
+ Project(attributes.toSeq, plan)
+ }
+
+ def unionClauses(originalUnion: Union, clauses: Seq[LogicalPlan]):
LogicalPlan = {
+ val coalescePlan = if (clauses.length == 1) {
+ clauses.head
+ } else {
+ var firstUnionChild = clauses.head
+ for (i <- 1 until clauses.length - 1) {
+ firstUnionChild = Union(firstUnionChild, clauses(i))
+ }
+ Union(firstUnionChild, clauses.last)
+ }
+ val outputPairs = coalescePlan.output.zip(originalUnion.output)
+ if (outputPairs.forall(pair => pair._1.semanticEquals(pair._2))) {
+ coalescePlan
+ } else {
+ val reprojectOutputs = outputPairs.map {
+ case (newAttr, oldAttr) =>
+ if (newAttr.exprId == oldAttr.exprId) {
+ newAttr
+ } else {
+ Alias(newAttr, oldAttr.name)(oldAttr.exprId, oldAttr.qualifier,
None, Seq.empty)
+ }
+ }
+ Project(reprojectOutputs, coalescePlan)
+ }
+ }
+}
+
+/*
+ * Example:
+ * Rewrite query
+ * SELECT a, b, sum(c) FROM t WHERE d = 1 GROUP BY a,b
+ * UNION ALL
+ * SELECT a, b, sum(c) FROM t WHERE d = 2 GROUP BY a,b
+ * into
+ * SELECT a, b, sum(c) FROM (
+ * SELECT s.a as a, s.b as b, s.c as c, s.id as group_id FROM (
+ * SELECT explode(s) as s FROM (
+ * SELECT array(
+ * if(d = 1, named_struct('a', a, 'b', b, 'c', c, 'id', 0), null),
+ * if(d = 2, named_struct('a', a, 'b', b, 'c', c, 'id', 1), null)) as s
+ * FROM t WHERE d = 1 OR d = 2
+ * )
+ * ) WHERE s is not null
+ * ) GROUP BY a,b, group_id
+ *
+ * The first query needs to scan `t` multiple times; when the output of scan is
large, the query is
+ * really slow. The rewritten query only scans `t` once, and the performance is
much better.
+ */
+
+class CoalesceAggregationUnion(spark: SparkSession) extends Rule[LogicalPlan]
with Logging {
+
+ case class PlanAnalyzer(originalAggregate: Aggregate) extends
AbstractPlanAnalyzer {
+
protected def extractFilter(): Option[Filter] = {
originalAggregate.child match {
case filter: Filter => Some(filter)
@@ -129,9 +348,10 @@ class CoalesceAggregationUnion(spark: SparkSession)
extends Rule[LogicalPlan] wi
subquery.child match {
case filter: Filter => Some(filter)
case project @ Project(_, filter: Filter) => Some(filter)
- case relation if isRelation(relation) =>
+ case relation if CoalesceUnionUtil.isRelation(relation) =>
Some(Filter(Literal(true, BooleanType), subquery))
- case nestedRelation: SubqueryAlias if
(isRelation(nestedRelation.child)) =>
+ case nestedRelation: SubqueryAlias
+ if (CoalesceUnionUtil.isRelation(nestedRelation.child)) =>
Some(Filter(Literal(true, BooleanType), nestedRelation))
case _ => None
}
@@ -139,76 +359,66 @@ class CoalesceAggregationUnion(spark: SparkSession)
extends Rule[LogicalPlan] wi
}
}
- def isValidSource(plan: LogicalPlan): Boolean = {
- plan match {
- case relation if isRelation(relation) => true
- case _: Project | _: Filter | _: SubqueryAlias =>
- plan.children.forall(isValidSource)
- case _ => false
- }
- }
-
// Try to make the plan simple, contain only three steps, source, filter,
aggregate.
lazy val extractedSourcePlan = {
- val filter = extractFilter()
- if (!filter.isDefined) {
- None
- } else {
- filter.get.child match {
- case project: Project if isValidSource(project.child) =>
Some(project.child)
- case other if isValidSource(other) => Some(other)
- case _ => None
- }
+ extractFilter match {
+ case Some(filter) =>
+ filter.child match {
+ case project: Project if
CoalesceUnionUtil.validateSource(project.child) =>
+ Some(project.child)
+ case other if CoalesceUnionUtil.validateSource(other) =>
Some(other)
+ case _ => None
+ }
+ case None => None
}
}
+ override def getExtractedSourcePlan(): Option[LogicalPlan] =
extractedSourcePlan
lazy val constructedFilterPlan = {
- val filter = extractFilter()
- if (!filter.isDefined || !extractedSourcePlan.isDefined) {
- None
- } else {
- val project = filter.get.child match {
- case project: Project => Some(project)
- case other =>
- None
- }
- val newFilter = project match {
- case Some(project) =>
- val replaceMap = buildAttributesMap(
- project.output,
- project.child.output.map(_.asInstanceOf[Expression]))
- val newCondition = replaceAttributes(filter.get.condition,
replaceMap)
- Filter(newCondition, extractedSourcePlan.get)
- case None => filter.get.withNewChildren(Seq(extractedSourcePlan.get))
- }
- Some(newFilter)
+ extractedSourcePlan match {
+ case Some(sourcePlan) =>
+ val filter = extractFilter().get
+ val innerProject = filter.child match {
+ case project: Project => Some(project)
+ case _ => None
+ }
+
+ val newFilter = innerProject match {
+ case Some(project) =>
+ val replaceMap = CoalesceUnionUtil.buildAttributesMap(
+ project.output,
+ project.projectList.map(_.asInstanceOf[Expression]))
+ val newCondition =
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
+ Filter(newCondition, sourcePlan)
+ case None => filter.withNewChildren(Seq(sourcePlan))
+ }
+ Some(newFilter)
+ case None => None
}
}
+ override def getConstructedFilterPlan: Option[LogicalPlan] =
constructedFilterPlan
lazy val constructedAggregatePlan = {
if (!constructedFilterPlan.isDefined) {
None
} else {
- val project = originalAggregate.child match {
- case p: Project => Some(p)
- case subquery: SubqueryAlias =>
- subquery.child match {
- case p: Project => Some(p)
- case _ => None
- }
+ val innerProject = originalAggregate.child match {
+ case project: Project => Some(project)
+ case subquery @ SubqueryAlias(_, project: Project) =>
+ Some(project)
case _ => None
}
- val newAggregate = project match {
- case Some(innerProject) =>
- val replaceMap = buildAttributesMap(
- innerProject.output,
- innerProject.projectList.map(_.asInstanceOf[Expression]))
+ val newAggregate = innerProject match {
+ case Some(project) =>
+ val replaceMap = CoalesceUnionUtil.buildAttributesMap(
+ project.output,
+ project.projectList.map(_.asInstanceOf[Expression]))
val newGroupExpressions =
originalAggregate.groupingExpressions.map {
- e => replaceAttributes(e, replaceMap)
+ e => CoalesceUnionUtil.replaceAttributes(e, replaceMap)
}
val newAggregateExpressions =
originalAggregate.aggregateExpressions.map {
- e => replaceAttributes(e,
replaceMap).asInstanceOf[NamedExpression]
+ e => CoalesceUnionUtil.replaceAttributes(e,
replaceMap).asInstanceOf[NamedExpression]
}
Aggregate(newGroupExpressions, newAggregateExpressions,
constructedFilterPlan.get)
case None =>
originalAggregate.withNewChildren(Seq(constructedFilterPlan.get))
@@ -216,75 +426,96 @@ class CoalesceAggregationUnion(spark: SparkSession)
extends Rule[LogicalPlan] wi
Some(newAggregate)
}
}
+ override def getConstructedAggregatePlan: Option[LogicalPlan] =
constructedAggregatePlan
+ def hasAggregateExpressionsWithFilter(e: Expression): Boolean = {
+ if (e.children.isEmpty && !e.isInstanceOf[AggregateExpression]) {
+ return false
+ }
+ e match {
+ case aggExpr: AggregateExpression =>
+ aggExpr.filter.isDefined
+ case _ => e.children.exists(hasAggregateExpressionsWithFilter(_))
+ }
+ }
lazy val hasAggregateWithFilter =
originalAggregate.aggregateExpressions.exists {
e => hasAggregateExpressionsWithFilter(e)
}
// The output results which are not aggregate expressions.
- lazy val resultGroupingExpressions = constructedAggregatePlan match {
+ lazy val resultRequiredGroupingExpressions = constructedAggregatePlan
match {
case Some(agg) =>
- agg.asInstanceOf[Aggregate].aggregateExpressions.filter(e =>
!hasAggregateExpression(e))
+ agg
+ .asInstanceOf[Aggregate]
+ .aggregateExpressions
+ .filter(e => !CoalesceUnionUtil.hasAggregateExpression(e))
case None => Seq.empty
}
- lazy val positionInGroupingKeys = {
- var i = 0
- // In most cases, the expressions which are not aggregate result could
be matched with one of
- // groupingk keys. There are some exceptions
- // 1. The expression is a literal. The grouping keys do not contain the
literal.
- // 2. The expression is an expression withs gruping keys. For example,
- // `select k1 + k2, count(1) from t group by k1, k2`.
- resultGroupingExpressions.map {
- e =>
- val aggregate = constructedAggregatePlan.get.asInstanceOf[Aggregate]
- e match {
- case literal @ Alias(_: Literal, _) =>
- var idx = aggregate.groupingExpressions.indexOf(e)
- if (idx == -1) {
- idx = aggregate.groupingExpressions.length + i
- i += 1
- }
- idx
- case _ =>
- var idx = aggregate.groupingExpressions.indexOf(removeAlias(e))
- idx = if (idx == -1) {
- aggregate.groupingExpressions.indexOf(e)
- } else {
- idx
- }
- idx
+ // For non-aggregate expressions in the output, they should be matched with
one of the grouping
+ // keys. There are some exceptions
+ // 1. The expression is a literal. The grouping keys do not contain the
literal.
+ // 2. The expression is an expression with grouping keys. For example,
+ // `select k1 + k2, count(1) from t group by k1, k2`.
+ // This is used to judge whether two aggregate plans are matched.
+ lazy val aggregateResultMatchedGroupingKeysPositions = {
+ var extraPosition = 0
+ val aggregate = constructedAggregatePlan.get.asInstanceOf[Aggregate]
+ resultRequiredGroupingExpressions.map {
+ case literal @ Alias(_: Literal, _) =>
+ aggregate.groupingExpressions.indexOf(literal) match {
+ case -1 =>
+ extraPosition += 1
+ aggregate.groupingExpressions.length + extraPosition - 1
+ case position => position
+ }
+ case normalExpression =>
+ aggregate.groupingExpressions.indexOf(
+ CoalesceUnionUtil.removeAlias(normalExpression)) match {
+ case -1 => aggregate.groupingExpressions.indexOf(normalExpression)
+ case position => position
}
}
}
+
+ override def doValidate(): Boolean = {
+ !hasAggregateWithFilter &&
+ constructedAggregatePlan.isDefined &&
+ aggregateResultMatchedGroupingKeysPositions.forall(_ >= 0) &&
+ originalAggregate.aggregateExpressions.forall {
+ e =>
+ val innerExpr = CoalesceUnionUtil.removeAlias(e)
+ // `agg_fun1(x) + agg_fun2(y)` is supported, but `agg_fun1(x) + y`
is not supported.
+ if (CoalesceUnionUtil.hasAggregateExpression(innerExpr)) {
+ innerExpr.isInstanceOf[AggregateExpression] ||
+ innerExpr.children.forall(e =>
CoalesceUnionUtil.isAggregateExpression(e))
+ } else {
+ true
+ }
+ } &&
+ extractedSourcePlan.isDefined
+ }
}
/*
* Case class representing an analyzed plan.
*
* @param plan The logical plan that to be analyzed.
- * @param analyzedInfo Optional information about the aggregate analysis.
+ * @param planAnalyzer Optional information about the aggregate analysis.
*/
- case class AnalyzedPlan(plan: LogicalPlan, analyzedInfo:
Option[AggregateAnalzyInfo])
-
- def isResolvedPlan(plan: LogicalPlan): Boolean = {
- plan match {
- case isnert: InsertIntoStatement => isnert.query.resolved
- case _ => plan.resolved
- }
- }
-
override def apply(plan: LogicalPlan): LogicalPlan = {
if (
spark.conf
.get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_AGGREGATION_UNION,
"true")
- .toBoolean && isResolvedPlan(plan)
+ .toBoolean && CoalesceUnionUtil.isResolvedPlan(plan)
) {
Try {
visitPlan(plan)
} match {
- case Success(res) => res
- case Failure(e) => plan
+ case Success(newPlan) => newPlan
+ case Failure(e) =>
+ logError(s"$e")
+ plan
}
} else {
plan
@@ -293,526 +524,531 @@ class CoalesceAggregationUnion(spark: SparkSession)
extends Rule[LogicalPlan] wi
def visitPlan(plan: LogicalPlan): LogicalPlan = {
plan match {
- case union: Union =>
- val planGroups = groupStructureMatchedAggregate(union)
- if (planGroups.forall(group => group.length == 1)) {
- plan.withNewChildren(plan.children.map(visitPlan))
- } else {
- val newUnionClauses = planGroups.map {
- groupedPlans =>
- if (groupedPlans.length == 1) {
- groupedPlans.head.plan
- } else {
- val firstAggregateAnalzyInfo =
groupedPlans.head.analyzedInfo.get
- val aggregates =
groupedPlans.map(_.analyzedInfo.get.constructedAggregatePlan.get)
- val filterConditions =
buildAggregateCasesConditions(groupedPlans)
- val firstAggregateFilter =
-
firstAggregateAnalzyInfo.constructedFilterPlan.get.asInstanceOf[Filter]
-
- // Add a filter step with condition `cond1 or cond2 or ...`,
`cond_i` comes from
- // each union clause. Apply this filter on the source plan.
- val unionFilter = Filter(
- buildUnionConditionForAggregateSource(filterConditions),
- firstAggregateAnalzyInfo.extractedSourcePlan.get)
-
- // Wrap all the attributes into a single structure attribute.
- val wrappedAttributesProject =
- buildProjectFoldIntoStruct(unionFilter, groupedPlans,
filterConditions)
-
- // Build an array which element are response to each union
clause.
- val arrayProject =
- buildProjectBranchArray(wrappedAttributesProject,
filterConditions)
+ case union @ Union(_, false, false) =>
+ val groupedUnionClauses = groupUnionClauses(union)
+ val newUnionClauses = groupedUnionClauses.map {
+ clauses => coalesceMatchedUnionClauses(clauses)
+ }
+ CoalesceUnionUtil.unionClauses(union, newUnionClauses)
+ case _ => plan.withNewChildren(plan.children.map(visitPlan))
+ }
+ }
- // Explode the array
- val explode = buildExplodeBranchArray(arrayProject)
+ def groupUnionClauses(union: Union): Seq[Seq[AnalyzedPlan]] = {
+ val unionClauses = CoalesceUnionUtil.collectAllUnionClauses(union)
+ val groups = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]()
+ unionClauses.foreach {
+ clause =>
+ val innerClause = clause match {
+ case project @ Project(projectList, aggregate: Aggregate) =>
+ if (projectList.forall(_.isInstanceOf[Alias])) {
+ Some(aggregate)
+ } else {
+ None
+ }
+ case aggregate: Aggregate =>
+ Some(aggregate)
+ case _ => None
+ }
+ innerClause match {
+ case Some(aggregate) =>
+ val planAnalyzer = PlanAnalyzer(aggregate)
+ if (planAnalyzer.doValidate()) {
+ val analyzedPlan = AnalyzedPlan(clause, Some(planAnalyzer))
+ val matchedGroup = findMatchedGroup(analyzedPlan, groups)
+ if (matchedGroup != -1) {
+ groups(matchedGroup) += analyzedPlan
+ } else {
+ groups += ArrayBuffer(analyzedPlan)
+ }
+ } else {
+ val newClause = visitPlan(clause)
+ groups += ArrayBuffer(AnalyzedPlan(newClause, None))
+ }
+ case None =>
+ val newClause = visitPlan(clause)
+ groups += ArrayBuffer(AnalyzedPlan(newClause, None))
+ }
+ }
+ groups.map(_.toSeq).toSeq
+ }
- // Null value means that the union clause does not have the
corresponding data.
- val notNullFilter = Filter(IsNotNull(explode.output.head),
explode)
+ def findMatchedGroup(
+ analyedPlan: AnalyzedPlan,
+ groups: ArrayBuffer[ArrayBuffer[AnalyzedPlan]]): Int = {
+ groups.zipWithIndex.find {
+ case (group, groupIndex) =>
+ val checkedAnalyzedPlan = group.head
+ if (checkedAnalyzedPlan.planAnalyzer.isDefined &&
analyedPlan.planAnalyzer.isDefined) {
+ val leftPlanAnalyzer =
checkedAnalyzedPlan.planAnalyzer.get.asInstanceOf[PlanAnalyzer]
+ val rightPlanAnalyzer =
analyedPlan.planAnalyzer.get.asInstanceOf[PlanAnalyzer]
+ val leftAggregate =
+
leftPlanAnalyzer.getConstructedAggregatePlan.get.asInstanceOf[Aggregate]
+ val rightAggregate =
+
rightPlanAnalyzer.getConstructedAggregatePlan.get.asInstanceOf[Aggregate]
- // Destruct the struct attribute.
- val destructStructProject =
buildProjectUnfoldStruct(notNullFilter)
+ var isMatched = CoalesceUnionUtil
+ .areStrictMatchedRelation(
+ leftPlanAnalyzer.extractedSourcePlan.get,
+ rightPlanAnalyzer.extractedSourcePlan.get)
- buildAggregateWithGroupId(destructStructProject, groupedPlans)
- }
- }
- val coalesePlan = if (newUnionClauses.length == 1) {
- newUnionClauses.head
- } else {
- var firstUnionChild = newUnionClauses.head
- for (i <- 1 until newUnionClauses.length - 1) {
- firstUnionChild = Union(firstUnionChild, newUnionClauses(i))
+ isMatched = isMatched &&
+ leftAggregate.groupingExpressions.length ==
rightAggregate.groupingExpressions.length &&
+
leftAggregate.groupingExpressions.zip(rightAggregate.groupingExpressions).forall
{
+ case (leftExpr, rightExpr) =>
leftExpr.dataType.equals(rightExpr.dataType)
}
- Union(firstUnionChild, newUnionClauses.last)
- }
- // We need to keep the output atrributes same as the original plan.
- val outputAttrPairs = coalesePlan.output.zip(union.output)
- if (outputAttrPairs.forall(pair => pair._1.semanticEquals(pair._2)))
{
- coalesePlan
- } else {
- val reprejectOutputs = outputAttrPairs.map {
- case (newAttr, oldAttr) =>
- if (newAttr.exprId == oldAttr.exprId) {
- newAttr
+ isMatched = isMatched &&
+
leftPlanAnalyzer.aggregateResultMatchedGroupingKeysPositions.length ==
+
rightPlanAnalyzer.aggregateResultMatchedGroupingKeysPositions.length &&
+ leftPlanAnalyzer.aggregateResultMatchedGroupingKeysPositions
+
.zip(rightPlanAnalyzer.aggregateResultMatchedGroupingKeysPositions)
+ .forall { case (leftPos, rightPos) => leftPos == rightPos }
+
+ isMatched = isMatched && leftAggregate.aggregateExpressions.length ==
+ rightAggregate.aggregateExpressions.length &&
+
leftAggregate.aggregateExpressions.zip(rightAggregate.aggregateExpressions).forall
{
+ case (leftExpr, rightExpr) =>
+ if (leftExpr.dataType.equals(rightExpr.dataType)) {
+ (
+ CoalesceUnionUtil.hasAggregateExpression(leftExpr),
+ CoalesceUnionUtil.hasAggregateExpression(rightExpr)) match
{
+ case (true, true) =>
+ CoalesceUnionUtil.areMatchedExpression(leftExpr,
rightExpr)
+ case (false, true) => false
+ case (true, false) => false
+ case (false, false) => true
+ }
} else {
- Alias(newAttr, oldAttr.name)(oldAttr.exprId,
oldAttr.qualifier, None, Seq.empty)
+ false
}
}
- Project(reprejectOutputs, coalesePlan)
- }
+ isMatched
+ } else {
+ false
}
- case _ => plan.withNewChildren(plan.children.map(visitPlan))
+ } match {
+ case Some((_, i)) => i
+ case None => -1
}
}
- def isRelation(plan: LogicalPlan): Boolean = {
- plan.isInstanceOf[MultiInstanceRelation]
+ def coalesceMatchedUnionClauses(clauses: Seq[AnalyzedPlan]): LogicalPlan = {
+ if (clauses.length == 1) {
+ clauses.head.plan
+ } else {
+ val normalizedConditions =
CoalesceUnionUtil.normalizedClausesFilterCondition(clauses)
+ val newSource = clauses.head.planAnalyzer.get.getExtractedSourcePlan.get
+ val newFilter = Filter(normalizedConditions.reduce(Or), newSource)
+ val foldStructStep = addFoldStructStep(newFilter, normalizedConditions,
clauses)
+ val arrayStep = CoalesceUnionUtil.addArrayStep(foldStructStep)
+ val explodeStep = CoalesceUnionUtil.addExplodeStep(arrayStep)
+ val unfoldStructStep = CoalesceUnionUtil.addUnfoldStructStep(explodeStep)
+ addAggregateStep(unfoldStructStep,
clauses.head.planAnalyzer.get.asInstanceOf[PlanAnalyzer])
+ }
}
- def areSameRelation(l: LogicalPlan, r: LogicalPlan): Boolean = {
- (l, r) match {
- case (lRelation: LogicalRelation, rRelation: LogicalRelation) =>
- val lTable =
lRelation.catalogTable.map(_.identifier.unquotedString).getOrElse("")
- val rTable =
rRelation.catalogTable.map(_.identifier.unquotedString).getOrElse("")
- lRelation.output.length == rRelation.output.length &&
- lRelation.output.zip(rRelation.output).forall {
- case (lAttr, rAttr) =>
- lAttr.dataType.equals(rAttr.dataType) &&
lAttr.name.equals(rAttr.name)
- } &&
- lTable.equals(rTable) && lTable.nonEmpty
- case (lCTE: CTERelationRef, rCTE: CTERelationRef) =>
- lCTE.cteId == rCTE.cteId
- case (lHiveTable: HiveTableRelation, rHiveTable: HiveTableRelation) =>
- lHiveTable.tableMeta.identifier.unquotedString
- .equals(rHiveTable.tableMeta.identifier.unquotedString)
- case (_, _) =>
- logInfo(s"xxx unknow relation: ${l.getClass}, ${r.getClass}")
- false
+ def addFoldStructStep(
+ plan: LogicalPlan,
+ conditions: Seq[Expression],
+ analyzedPlans: Seq[AnalyzedPlan]): LogicalPlan = {
+ val replaceAttributes =
analyzedPlans.head.planAnalyzer.get.getExtractedSourcePlan.get.output
+
+ val structAttributes = analyzedPlans.zipWithIndex.map {
+ case (analyzedPlan, clauseIndex) =>
+ val attributeReplaceMap = CoalesceUnionUtil.buildAttributesMap(
+ analyzedPlan.planAnalyzer.get.getExtractedSourcePlan.get.output,
+ replaceAttributes)
+ val structFields = collectClauseStructFields(analyzedPlan,
clauseIndex, attributeReplaceMap)
+ CoalesceUnionUtil
+ .makeAlias(CreateNamedStruct(structFields), s"clause_$clauseIndex")
+ .asInstanceOf[NamedExpression]
+ }
+
+ val projectList = structAttributes.zip(conditions).map {
+ case (attribute, condition) =>
+ CoalesceUnionUtil
+ .makeAlias(If(condition, attribute, Literal(null,
attribute.dataType)), attribute.name)
+ .asInstanceOf[NamedExpression]
}
+ Project(projectList, plan)
}
- def isSupportedAggregate(info: AggregateAnalzyInfo): Boolean = {
+ def collectClauseStructFields(
+ analyzedPlan: AnalyzedPlan,
+ clauseIndex: Int,
+ attributeReplaceMap: Map[ExprId, Expression]): Seq[Expression] = {
- !info.hasAggregateWithFilter &&
- info.constructedAggregatePlan.isDefined &&
- info.positionInGroupingKeys.forall(_ >= 0) &&
- info.originalAggregate.aggregateExpressions.forall {
+ val planAnalyzer = analyzedPlan.planAnalyzer.get.asInstanceOf[PlanAnalyzer]
+ val aggregate =
planAnalyzer.constructedAggregatePlan.get.asInstanceOf[Aggregate]
+ val structFields = ArrayBuffer[Expression]()
+
+ aggregate.groupingExpressions.foreach {
e =>
- val innerExpr = removeAlias(e)
- // `agg_fun1(x) + agg_fun2(y)` is supported, but `agg_fun1(x) + y` is
not supported.
- if (hasAggregateExpression(innerExpr)) {
- innerExpr.isInstanceOf[AggregateExpression] ||
- innerExpr.children.forall(e => isAggregateExpression(e))
- } else {
- true
+ val fieldCounter = structFields.length / 2
+ structFields += Literal(UTF8String.fromString(s"f$fieldCounter"),
StringType)
+ structFields += CoalesceUnionUtil.replaceAttributes(e,
attributeReplaceMap)
+ }
+
+
planAnalyzer.aggregateResultMatchedGroupingKeysPositions.zipWithIndex.foreach {
+ case (position, index) =>
+ val fieldCounter = structFields.length / 2
+ if (position >= fieldCounter) {
+ val expression =
planAnalyzer.resultRequiredGroupingExpressions(index)
+ structFields += Literal(UTF8String.fromString(s"f$fieldCounter"),
StringType)
+ structFields += CoalesceUnionUtil.replaceAttributes(expression,
attributeReplaceMap)
}
- } &&
- info.extractedSourcePlan.isDefined
- }
+ }
- /**
- * Checks if two AggregateAnalzyInfo instances have the same structure.
- *
- * This method compares the aggregate expressions, grouping expressions, and
the source plans of
- * the two AggregateAnalzyInfo instances to determine if they have the same
structure.
- *
- * @param l
- * The first AggregateAnalzyInfo instance.
- * @param r
- * The second AggregateAnalzyInfo instance.
- * @return
- * True if the two instances have the same structure, false otherwise.
- */
- def areStructureMatchedAggregate(l: AggregateAnalzyInfo, r:
AggregateAnalzyInfo): Boolean = {
- val lAggregate = l.constructedAggregatePlan.get.asInstanceOf[Aggregate]
- val rAggregate = r.constructedAggregatePlan.get.asInstanceOf[Aggregate]
- lAggregate.aggregateExpressions.length ==
rAggregate.aggregateExpressions.length &&
-
lAggregate.aggregateExpressions.zip(rAggregate.aggregateExpressions).forall {
- case (lExpr, rExpr) =>
- if (!lExpr.dataType.equals(rExpr.dataType)) {
- false
- } else {
- (hasAggregateExpression(lExpr), hasAggregateExpression(rExpr)) match
{
- case (true, true) => areStructureMatchedExpressions(lExpr, rExpr)
- case (false, true) => false
- case (true, false) => false
- case (false, false) => true
+ aggregate.aggregateExpressions
+ .filter(e => CoalesceUnionUtil.hasAggregateExpression(e))
+ .foreach {
+ e =>
+ def visitAggregateExpression(expression: Expression): Unit = {
+ expression match {
+ case aggregateExpression: AggregateExpression =>
+ val aggregateFunction = aggregateExpression.aggregateFunction
+ aggregateFunction.children.foreach {
+ argument =>
+ val fieldCounter = structFields.length / 2
+ structFields +=
Literal(UTF8String.fromString(s"f$fieldCounter"), StringType)
+ structFields += CoalesceUnionUtil.replaceAttributes(
+ argument,
+ attributeReplaceMap)
+ }
+ case combindedAggregateExpression
+ if
CoalesceUnionUtil.hasAggregateExpression(combindedAggregateExpression) =>
+
combindedAggregateExpression.children.foreach(visitAggregateExpression)
+ case other =>
+ val fieldCounter = structFields.length / 2
+ structFields +=
Literal(UTF8String.fromString(s"f$fieldCounter"), StringType)
+ structFields += CoalesceUnionUtil.replaceAttributes(other,
attributeReplaceMap)
+ }
}
- }
- } &&
- lAggregate.groupingExpressions.length ==
rAggregate.groupingExpressions.length &&
- l.positionInGroupingKeys.length == r.positionInGroupingKeys.length &&
- l.positionInGroupingKeys.zip(r.positionInGroupingKeys).forall {
- case (lPos, rPos) => lPos == rPos
- } &&
- areSameAggregateSource(l.extractedSourcePlan.get,
r.extractedSourcePlan.get)
+ visitAggregateExpression(e)
+ }
+
+ // Add the clause index to the struct.
+ val fieldCounter = structFields.length / 2
+ structFields += Literal(UTF8String.fromString(s"f$fieldCounter"),
StringType)
+ structFields += Literal(UTF8String.fromString(s"$clauseIndex"), StringType)
+
+ structFields.toSeq
}
- /*
- * Finds the index of the first group in `planGroups` that has the same
structure as the given
- * `analyzedInfo`.
- *
- * This method iterates over the `planGroups` and checks if the first
`AnalyzedPlan` in each group
- * has an `analyzedInfo` that matches the structure of the provided
`analyzedInfo`. If a match is
- * found, the index of the group is returned. If no match is found, -1 is
returned.
- *
- * @param planGroups
- * An ArrayBuffer of ArrayBuffers, where each inner ArrayBuffer contains
`AnalyzedPlan`
- * instances.
- * @param analyzedInfo
- * The `AggregateAnalzyInfo` to match against the groups in `planGroups`.
- * @return
- * The index of the first group with a matching structure, or -1 if no
match is found.
- */
- def findStructureMatchedAggregate(
- planGroups: ArrayBuffer[ArrayBuffer[AnalyzedPlan]],
- analyzedInfo: AggregateAnalzyInfo): Int = {
- planGroups.zipWithIndex.find(
- planWithIndex =>
- planWithIndex._1.head.analyzedInfo.isDefined &&
- areStructureMatchedAggregate(
- planWithIndex._1.head.analyzedInfo.get,
- analyzedInfo)) match {
- case Some((_, i)) => i
- case None => -1
- }
+ def addAggregateStep(plan: LogicalPlan, templatePlanAnalyzer: PlanAnalyzer):
LogicalPlan = {
+ val inputAttributes = plan.output
+ val templateAggregate =
+ templatePlanAnalyzer.constructedAggregatePlan.get.asInstanceOf[Aggregate]
+
+ val totalGroupingExpressionsCount = math.max(
+ templateAggregate.groupingExpressions.length,
+ templatePlanAnalyzer.aggregateResultMatchedGroupingKeysPositions.max + 1)
+
+ // inputAttributes.last is the clause index.
+ val groupingExpressions = inputAttributes
+ .slice(0, totalGroupingExpressionsCount)
+ .map(_.asInstanceOf[Expression]) :+ inputAttributes.last
+
+ var aggregateExpressionCount = totalGroupingExpressionsCount
+ var nonAggregateExpressionCount = 0
+ val aggregateExpressions = ArrayBuffer[NamedExpression]()
+ templateAggregate.aggregateExpressions.foreach {
+ e =>
+ CoalesceUnionUtil.removeAlias(e) match {
+ case aggregateExpression
+ if CoalesceUnionUtil.hasAggregateExpression(aggregateExpression)
=>
+ val (newAggregateExpression, count) = buildNewAggregateExpression(
+ aggregateExpression,
+ inputAttributes,
+ aggregateExpressionCount)
+
+ aggregateExpressionCount += count
+ aggregateExpressions += CoalesceUnionUtil
+ .makeAlias(newAggregateExpression, e.name)
+ .asInstanceOf[NamedExpression]
+ case nonAggregateExpression =>
+ val position =
templatePlanAnalyzer.aggregateResultMatchedGroupingKeysPositions(
+ nonAggregateExpressionCount)
+ val attribute = inputAttributes(position)
+ aggregateExpressions += CoalesceUnionUtil
+ .makeAlias(attribute, e.name)
+ .asInstanceOf[NamedExpression]
+ nonAggregateExpressionCount += 1
+ }
+ }
+ Aggregate(groupingExpressions.toSeq, aggregateExpressions.toSeq, plan)
}
- // Union only has two children. It's children may also be Union.
- def collectAllUnionClauses(union: Union): ArrayBuffer[LogicalPlan] = {
- val unionClauses = ArrayBuffer[LogicalPlan]()
- union.children.foreach {
- case u: Union =>
- unionClauses ++= collectAllUnionClauses(u)
- case other =>
- unionClauses += other
+ def buildNewAggregateExpression(
+ oldExpression: Expression,
+ inputAttributes: Seq[Attribute],
+ attributesOffset: Int): (Expression, Int) = {
+ oldExpression match {
+ case aggregateExpression: AggregateExpression =>
+ val aggregateFunction = aggregateExpression.aggregateFunction
+ val newArguments = aggregateFunction.children.zipWithIndex.map {
+ case (argument, i) => inputAttributes(i + attributesOffset)
+ }
+ val newAggregateFunction =
+
aggregateFunction.withNewChildren(newArguments).asInstanceOf[AggregateFunction]
+ val newAggregateExpression = AggregateExpression(
+ newAggregateFunction,
+ aggregateExpression.mode,
+ aggregateExpression.isDistinct,
+ aggregateExpression.filter,
+ aggregateExpression.resultId)
+ (newAggregateExpression, 1)
+ case combindedAggregateExpression
+ if
CoalesceUnionUtil.hasAggregateExpression(combindedAggregateExpression) =>
+ var count = 0
+ val newChildren = ArrayBuffer[Expression]()
+ combindedAggregateExpression.children.foreach {
+ case child =>
+ val (newChild, n) =
+ buildNewAggregateExpression(child, inputAttributes,
attributesOffset + count)
+ count += n
+ newChildren += newChild
+ }
+ val newExpression =
combindedAggregateExpression.withNewChildren(newChildren.toSeq)
+ (newExpression, count)
+ case _ => (inputAttributes(attributesOffset), 1)
}
- unionClauses
}
+}
- def groupStructureMatchedAggregate(union: Union):
ArrayBuffer[ArrayBuffer[AnalyzedPlan]] = {
+/**
+ * Rewrites a union of simple scan-project clauses over the same source into a
+ * single scan. For example, the query
+ *   select a, b, 1 as c from t where d = 1
+ *   union all
+ *   select a, b, 2 as c from t where d = 2
+ * is rewritten into
+ *   select s.f0 as a, s.f1 as b, s.f2 as c from (
+ *     select explode(s) as s from (
+ *       select array(
+ *         if(d = 1, named_struct('f0', a, 'f1', b, 'f2', 1), null),
+ *         if(d = 2, named_struct('f0', a, 'f1', b, 'f2', 2), null)) as s
+ *       from t where d = 1 or d = 2))
+ *   where s is not null
+ */
+class CoalesceProjectionUnion(spark: SparkSession) extends Rule[LogicalPlan]
with Logging {
- def tryPutToGroup(
- groupResults: ArrayBuffer[ArrayBuffer[AnalyzedPlan]],
- agg: Aggregate): Unit = {
- val analyzedInfo = AggregateAnalzyInfo(agg)
- if (isSupportedAggregate(analyzedInfo)) {
- if (groupResults.isEmpty) {
- groupResults += ArrayBuffer(
- AnalyzedPlan(analyzedInfo.originalAggregate, Some(analyzedInfo)))
- } else {
- val idx = findStructureMatchedAggregate(groupResults, analyzedInfo)
- if (idx != -1) {
- groupResults(idx) += AnalyzedPlan(
- analyzedInfo.constructedAggregatePlan.get,
- Some(analyzedInfo))
- } else {
- groupResults += ArrayBuffer(
- AnalyzedPlan(analyzedInfo.constructedAggregatePlan.get,
Some(analyzedInfo)))
+ case class PlanAnalyzer(originalPlan: LogicalPlan) extends
AbstractPlanAnalyzer {
+ def extractFilter(): Option[Filter] = {
+ originalPlan match {
+ case project @ Project(_, filter: Filter) => Some(filter)
+ case _ => None
+ }
+ }
+
+ lazy val extractedSourcePlan = {
+ extractFilter match {
+ case Some(filter) =>
+ filter.child match {
+ case project: Project =>
+ if (CoalesceUnionUtil.validateSource(project.child))
Some(project.child)
+ else None
+ case subquery @ SubqueryAlias(_, project: Project) =>
+ if (CoalesceUnionUtil.validateSource(project.child))
Some(project.child)
+ else None
+ case _ => Some(filter.child)
}
- }
- } else {
- val rewrittenPlan = visitPlan(agg)
- groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
+ case None => None
}
}
- val groupResults = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]()
- collectAllUnionClauses(union).foreach {
- case project @ Project(projectList, agg: Aggregate) =>
- if (projectList.forall(e => e.isInstanceOf[Alias])) {
- tryPutToGroup(groupResults, agg)
- } else {
- val rewrittenPlan = visitPlan(project)
- groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
- }
- case agg: Aggregate =>
- tryPutToGroup(groupResults, agg)
- case other =>
- val rewrittenPlan = visitPlan(other)
- groupResults += ArrayBuffer(AnalyzedPlan(rewrittenPlan, None))
- }
- groupResults
- }
-
- def areStructureMatchedExpressions(l: Expression, r: Expression): Boolean = {
- if (l.dataType.equals(r.dataType)) {
- (l, r) match {
- case (lAttr: Attribute, rAttr: Attribute) =>
- // The the qualifier may be overwritten by a subquery alias, and
make this check fail.
- lAttr.qualifiedName.equals(rAttr.qualifiedName)
- case (lLiteral: Literal, rLiteral: Literal) =>
- lLiteral.value == rLiteral.value
- case (lagg: AggregateExpression, ragg: AggregateExpression) =>
- lagg.isDistinct == ragg.isDistinct &&
- areStructureMatchedExpressions(lagg.aggregateFunction,
ragg.aggregateFunction)
- case _ =>
- l.children.length == r.children.length &&
- l.getClass == r.getClass &&
- l.children.zip(r.children).forall {
- case (lChild, rChild) => areStructureMatchedExpressions(lChild,
rChild)
+ lazy val constructedFilterPlan = {
+ extractedSourcePlan match {
+ case Some(source) =>
+ val filter = extractFilter().get
+ filter.child match {
+ case project: Project =>
+ val replaceMap =
+ CoalesceUnionUtil.buildAttributesMap(
+ project.output,
+ project.projectList.map(_.asInstanceOf[Expression]))
+ val newCondition =
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
+ Some(Filter(newCondition, source))
+ case subquery @ SubqueryAlias(_, project: Project) =>
+ val replaceMap =
+ CoalesceUnionUtil.buildAttributesMap(
+ project.output,
+ project.projectList.map(_.asInstanceOf[Expression]))
+ val newCondition =
CoalesceUnionUtil.replaceAttributes(filter.condition, replaceMap)
+ Some(Filter(newCondition, source))
+ case _ => Some(filter)
}
+ case None => None
}
- } else {
- false
}
- }
- def areSameAggregateSource(lPlan: LogicalPlan, rPlan: LogicalPlan): Boolean
= {
- if (lPlan.children.length != rPlan.children.length || lPlan.getClass !=
rPlan.getClass) {
- false
- } else {
- lPlan.children.zip(rPlan.children).forall {
- case (lRelation, rRelation) if (isRelation(lRelation) &&
isRelation(rRelation)) =>
- areSameRelation(lRelation, rRelation)
- case (lSubQuery: SubqueryAlias, rSubQuery: SubqueryAlias) =>
- areSameAggregateSource(lSubQuery.child, rSubQuery.child)
- case (lproject: Project, rproject: Project) =>
- lproject.projectList.length == rproject.projectList.length &&
- lproject.projectList.zip(rproject.projectList).forall {
- case (lExpr, rExpr) => areStructureMatchedExpressions(lExpr, rExpr)
- } &&
- areSameAggregateSource(lproject.child, rproject.child)
- case (lFilter: Filter, rFilter: Filter) =>
- areStructureMatchedExpressions(lFilter.condition, rFilter.condition)
&&
- areSameAggregateSource(lFilter.child, rFilter.child)
- case (lChild, rChild) => false
+ lazy val constructedProjectPlan = {
+ constructedFilterPlan match {
+ case Some(filter) =>
+ val originalFilter = extractFilter().get
+ originalFilter.child match {
+ case project: Project =>
+ val replaceMap =
+ CoalesceUnionUtil.buildAttributesMap(
+ project.output,
+ project.projectList.map(_.asInstanceOf[Expression]))
+ val originalProject = originalPlan.asInstanceOf[Project]
+ val newProjectList =
+ originalProject.projectList
+ .map(e => CoalesceUnionUtil.replaceAttributes(e, replaceMap))
+ .map(_.asInstanceOf[NamedExpression])
+ val newProject = Project(newProjectList, filter)
+ Some(newProject)
+ case subquery @ SubqueryAlias(_, project: Project) =>
+ val replaceMap =
+ CoalesceUnionUtil.buildAttributesMap(
+ project.output,
+ project.projectList.map(_.asInstanceOf[Expression]))
+ val originalProject = originalPlan.asInstanceOf[Project]
+ val newProjectList =
+ originalProject.projectList
+ .map(e => CoalesceUnionUtil.replaceAttributes(e, replaceMap))
+ .map(_.asInstanceOf[NamedExpression])
+ val newProject = Project(newProjectList, filter)
+ Some(newProject)
+ case _ => Some(originalPlan)
+ }
+ case None => None
}
}
+
+ override def doValidate(): Boolean = {
+ constructedProjectPlan.isDefined
+ }
+
+ override def getExtractedSourcePlan: Option[LogicalPlan] =
extractedSourcePlan
+
+ override def getConstructedFilterPlan: Option[LogicalPlan] =
constructedFilterPlan
+
+ override def getConstructedProjectPlan: Option[LogicalPlan] =
constructedProjectPlan
}
- def buildAggregateCasesConditions(
- groupedPlans: ArrayBuffer[AnalyzedPlan]): ArrayBuffer[Expression] = {
- val firstPlanSourceOutputAttrs =
- groupedPlans.head.analyzedInfo.get.extractedSourcePlan.get.output
- groupedPlans.map {
- plan =>
- val attrsMap =
- buildAttributesMap(
- plan.analyzedInfo.get.extractedSourcePlan.get.output,
- firstPlanSourceOutputAttrs)
- val filter =
plan.analyzedInfo.get.constructedFilterPlan.get.asInstanceOf[Filter]
- replaceAttributes(filter.condition, attrsMap)
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (
+ spark.conf
+ .get(CHBackendSettings.GLUTEN_ENABLE_COALESCE_PROJECT_UNION, "true")
+ .toBoolean && CoalesceUnionUtil.isResolvedPlan(plan)
+ ) {
+ Try {
+ visitPlan(plan)
+ } match {
+ case Success(newPlan) => newPlan
+ case Failure(e) =>
+ logError(s"$e")
+ plan
+ }
+ } else {
+ plan
}
}
- def buildUnionConditionForAggregateSource(conditions:
ArrayBuffer[Expression]): Expression = {
- conditions.reduce(Or);
+ def visitPlan(plan: LogicalPlan): LogicalPlan = {
+ plan match {
+ case union @ Union(_, false, false) =>
+ val groupedUnionClauses = groupUnionClauses(union)
+ val newUnionClauses =
+ groupedUnionClauses.map(clauses =>
coalesceMatchedUnionClauses(clauses))
+ val newPlan = CoalesceUnionUtil.unionClauses(union, newUnionClauses)
+ newPlan
+ case other =>
+ other.withNewChildren(other.children.map(visitPlan))
+ }
}
- def wrapAggregatesAttributesInStructs(
- groupedPlans: ArrayBuffer[AnalyzedPlan]): Seq[NamedExpression] = {
- val structAttributes = ArrayBuffer[NamedExpression]()
- val casePrefix = "case_"
- val structPrefix = "field_"
- val firstSourceAttrs =
groupedPlans.head.analyzedInfo.get.extractedSourcePlan.get.output
- groupedPlans.zipWithIndex.foreach {
- case (aggregateCase, case_index) =>
- val analyzedInfo = aggregateCase.analyzedInfo.get
- val aggregate =
analyzedInfo.constructedAggregatePlan.get.asInstanceOf[Aggregate]
- val structFields = ArrayBuffer[Expression]()
- var fieldIndex: Int = 0
- val attrReplaceMap = buildAttributesMap(
- aggregateCase.analyzedInfo.get.extractedSourcePlan.get.output,
- firstSourceAttrs)
- aggregate.groupingExpressions.foreach {
- e =>
- structFields +=
Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType)
- structFields += replaceAttributes(e, attrReplaceMap)
- fieldIndex += 1
- }
- for (i <- 0 until analyzedInfo.positionInGroupingKeys.length) {
- val position = analyzedInfo.positionInGroupingKeys(i)
- if (position >= fieldIndex) {
- val expr = analyzedInfo.resultGroupingExpressions(i)
- structFields +=
Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType)
- structFields += replaceAttributes(
- analyzedInfo.resultGroupingExpressions(i),
- attrReplaceMap)
- fieldIndex += 1
+ def groupUnionClauses(union: Union): Seq[Seq[AnalyzedPlan]] = {
+ val unionClauses = CoalesceUnionUtil.collectAllUnionClauses(union)
+ val groups = ArrayBuffer[ArrayBuffer[AnalyzedPlan]]()
+ unionClauses.foreach {
+ clause =>
+ val planAnalyzer = PlanAnalyzer(clause)
+ if (planAnalyzer.doValidate()) {
+ val matchedGroup = findMatchedGroup(AnalyzedPlan(clause,
Some(planAnalyzer)), groups)
+ if (matchedGroup != -1) {
+ groups(matchedGroup) += AnalyzedPlan(clause, Some(planAnalyzer))
+ } else {
+ groups += ArrayBuffer(AnalyzedPlan(clause, Some(planAnalyzer)))
}
+ } else {
+ val newClause = visitPlan(clause)
+ groups += ArrayBuffer(AnalyzedPlan(newClause, None))
}
-
- aggregate.aggregateExpressions
- .filter(e => hasAggregateExpression(e))
- .foreach {
- e =>
- def collectExpressionsInAggregateExpression(aggExpr:
Expression): Unit = {
- aggExpr match {
- case aggExpr: AggregateExpression =>
- val aggFunction =
-
removeAlias(aggExpr).asInstanceOf[AggregateExpression].aggregateFunction
- aggFunction.children.foreach {
- child =>
- structFields += Literal(
- UTF8String.fromString(s"$structPrefix$fieldIndex"),
- StringType)
- structFields += replaceAttributes(child,
attrReplaceMap)
- fieldIndex += 1
- }
- case combineAgg if hasAggregateExpression(combineAgg) =>
- combineAgg.children.foreach {
- combindAggchild =>
collectExpressionsInAggregateExpression(combindAggchild)
- }
- case other =>
- structFields += Literal(
- UTF8String.fromString(s"$structPrefix$fieldIndex"),
- StringType)
- structFields += replaceAttributes(other, attrReplaceMap)
- fieldIndex += 1
- }
- }
- collectExpressionsInAggregateExpression(e)
- }
- structFields +=
Literal(UTF8String.fromString(s"$structPrefix$fieldIndex"), StringType)
- structFields += Literal(case_index, IntegerType)
- structAttributes += makeAlias(
- CreateNamedStruct(structFields.toSeq),
- s"$casePrefix$case_index")
- }
- structAttributes.toSeq
- }
-
- def buildProjectFoldIntoStruct(
- child: LogicalPlan,
- groupedPlans: ArrayBuffer[AnalyzedPlan],
- conditions: ArrayBuffer[Expression]): LogicalPlan = {
- val wrappedAttributes = wrapAggregatesAttributesInStructs(groupedPlans)
- val ifAttributes = wrappedAttributes.zip(conditions).map {
- case (attr, condition) =>
- makeAlias(If(condition, attr, Literal(null, attr.dataType)), attr.name)
- .asInstanceOf[NamedExpression]
}
- Project(ifAttributes, child)
+ groups.map(_.toSeq).toSeq
}
- def buildProjectBranchArray(
- child: LogicalPlan,
- conditions: ArrayBuffer[Expression]): LogicalPlan = {
- assert(
- child.output.length == conditions.length,
- s"Expected same length of output and conditions")
- val array = makeAlias(CreateArray(child.output), "array")
- Project(Seq(array), child)
- }
-
- def buildExplodeBranchArray(child: LogicalPlan): LogicalPlan = {
- assert(child.output.length == 1, s"Expected single output from $child")
- val array = child.output.head.asInstanceOf[Expression]
- assert(array.dataType.isInstanceOf[ArrayType], s"Expected ArrayType from
$array")
- val explodeExpr = Explode(array)
- val exploadOutput =
- AttributeReference("generate_output",
array.dataType.asInstanceOf[ArrayType].elementType)()
- Generate(
- explodeExpr,
- unrequiredChildIndex = Seq(0),
- outer = false,
- qualifier = None,
- generatorOutput = Seq(exploadOutput),
- child)
- }
+ def findMatchedGroup(
+ analyzedPlan: AnalyzedPlan,
+ groupedClauses: ArrayBuffer[ArrayBuffer[AnalyzedPlan]]): Int = {
+ groupedClauses.zipWithIndex.find {
+ groupWithIndex =>
+ if (groupWithIndex._1.head.planAnalyzer.isDefined) {
+ val checkedPlanAnalyzer = groupWithIndex._1.head.planAnalyzer.get
+ val checkPlanAnalyzer = analyzedPlan.planAnalyzer.get
+ CoalesceUnionUtil.areOutputMatchedProject(
+ checkedPlanAnalyzer.getConstructedProjectPlan.get,
+ checkPlanAnalyzer.getConstructedProjectPlan.get) &&
+ CoalesceUnionUtil.areStrictMatchedRelation(
+ checkedPlanAnalyzer.getExtractedSourcePlan.get,
+ checkPlanAnalyzer.getExtractedSourcePlan.get)
+ } else {
+ false
+ }
- def makeAlias(e: Expression, name: String): NamedExpression = {
- Alias(e, name)(
- NamedExpression.newExprId,
- e match {
- case ne: NamedExpression => ne.qualifier
- case _ => Seq.empty
- },
- None,
- Seq.empty)
+ } match {
+ case Some((_, i)) => i
+ case None => -1
+ }
}
- def buildProjectUnfoldStruct(child: LogicalPlan): LogicalPlan = {
- assert(child.output.length == 1, s"Expected single output from $child")
- val structedData = child.output.head
- assert(
- structedData.dataType.isInstanceOf[StructType],
- s"Expected StructType from $structedData")
- val structType = structedData.dataType.asInstanceOf[StructType]
- val attributes = ArrayBuffer[NamedExpression]()
- var index = 0
- structType.fields.foreach {
- field =>
- attributes += Alias(GetStructField(structedData, index), field.name)()
- index += 1
+ def coalesceMatchedUnionClauses(clauses: Seq[AnalyzedPlan]): LogicalPlan = {
+ if (clauses.length == 1) {
+ clauses.head.plan
+ } else {
+ val normalizedConditions =
CoalesceUnionUtil.normalizedClausesFilterCondition(clauses)
+ val newSource = clauses.head.planAnalyzer.get.getExtractedSourcePlan.get
+ val newFilter = Filter(normalizedConditions.reduce(Or), newSource)
+ val foldStructStep = addFoldStructStep(newFilter, normalizedConditions,
clauses)
+ val arrayStep = CoalesceUnionUtil.addArrayStep(foldStructStep)
+ val explodeStep = CoalesceUnionUtil.addExplodeStep(arrayStep)
+ CoalesceUnionUtil.addUnfoldStructStep(explodeStep)
}
- Project(attributes.toSeq, child)
}
- def buildAggregateWithGroupId(
- child: LogicalPlan,
- groupedPlans: ArrayBuffer[AnalyzedPlan]): LogicalPlan = {
- val attributes = child.output
- val firstAggregateAnalzyInfo = groupedPlans.head.analyzedInfo.get
- val aggregateTemplate =
-
firstAggregateAnalzyInfo.constructedAggregatePlan.get.asInstanceOf[Aggregate]
- val analyzedInfo = groupedPlans.head.analyzedInfo.get
-
- val totalGroupingExpressionsCount =
- math.max(
- aggregateTemplate.groupingExpressions.length,
- analyzedInfo.positionInGroupingKeys.max + 1)
-
- val groupingExpressions = attributes
- .slice(0, totalGroupingExpressionsCount)
- .map(_.asInstanceOf[Expression]) :+ attributes.last
+ def buildClausesStructs(analyzedPlans: Seq[AnalyzedPlan]):
Seq[NamedExpression] = {
+ val valueAttributes =
analyzedPlans.head.planAnalyzer.get.getExtractedSourcePlan.get.output
+ analyzedPlans.zipWithIndex.map {
+ case (analyzedPlan, clauseIndex) =>
+ val planAnalyzer = analyzedPlan.planAnalyzer.get
+ val keyAttributes = planAnalyzer.getExtractedSourcePlan.get.output
+ val replaceMap = CoalesceUnionUtil.buildAttributesMap(keyAttributes,
valueAttributes)
+ val projectPlan =
planAnalyzer.getConstructedProjectPlan.get.asInstanceOf[Project]
+ val newProjectList = projectPlan.projectList.map {
+ e => CoalesceUnionUtil.replaceAttributes(e, replaceMap)
+ }
- val normalExpressionPosition = analyzedInfo.positionInGroupingKeys
- var normalExpressionCount = 0
- var aggregateExpressionIndex = totalGroupingExpressionsCount
- val aggregateExpressions = ArrayBuffer[NamedExpression]()
- aggregateTemplate.aggregateExpressions.foreach {
- e =>
- removeAlias(e) match {
- case aggExpr if hasAggregateExpression(aggExpr) =>
- val (newAggExpr, count) =
- constructAggregateExpression(aggExpr, attributes,
aggregateExpressionIndex)
- aggregateExpressions += makeAlias(newAggExpr,
e.name).asInstanceOf[NamedExpression]
- aggregateExpressionIndex += count
- case other =>
- val position = normalExpressionPosition(normalExpressionCount)
- val attr = attributes(position)
- normalExpressionCount += 1
- aggregateExpressions += makeAlias(attr, e.name)
- .asInstanceOf[NamedExpression]
+ val structFields = ArrayBuffer[Expression]()
+ newProjectList.zipWithIndex.foreach {
+ case (e, fieldIndex) =>
+ structFields += Literal(UTF8String.fromString(s"f$fieldIndex"),
StringType)
+ structFields += e
}
+ CoalesceUnionUtil.makeAlias(CreateNamedStruct(structFields.toSeq),
s"clause_$clauseIndex")
}
- Aggregate(groupingExpressions.toSeq, aggregateExpressions.toSeq, child)
}
- def constructAggregateExpression(
- aggExpr: Expression,
- attributes: Seq[Attribute],
- index: Int): (Expression, Int) = {
- aggExpr match {
- case singleAggExpr: AggregateExpression =>
- val aggFunc = singleAggExpr.aggregateFunction
- val newAggFuncArgs = aggFunc.children.zipWithIndex.map {
- case (arg, i) =>
- attributes(index + i)
- }
- val newAggFunc =
-
aggFunc.withNewChildren(newAggFuncArgs).asInstanceOf[AggregateFunction]
- val res = AggregateExpression(
- newAggFunc,
- singleAggExpr.mode,
- singleAggExpr.isDistinct,
- singleAggExpr.filter,
- singleAggExpr.resultId)
- (res, 1)
- case combineAggExpr if hasAggregateExpression(combineAggExpr) =>
- val childrenExpressions = ArrayBuffer[Expression]()
- var totalCount = 0
- combineAggExpr.children.foreach {
- child =>
- val (expr, count) = constructAggregateExpression(child,
attributes, totalCount + index)
- childrenExpressions += expr
- totalCount += count
- }
- (combineAggExpr.withNewChildren(childrenExpressions.toSeq), totalCount)
- case _ => (attributes(index), 1)
+ def addFoldStructStep(
+ plan: LogicalPlan,
+ conditions: Seq[Expression],
+ analyzedPlans: Seq[AnalyzedPlan]): LogicalPlan = {
+ val structAttributes = buildClausesStructs(analyzedPlans)
+ assert(structAttributes.length == conditions.length)
+ val structAttributesWithCondition = structAttributes.zip(conditions).map {
+ case (struct, condition) =>
+ CoalesceUnionUtil
+ .makeAlias(If(condition, struct, Literal(null, struct.dataType)),
struct.name)
+ .asInstanceOf[NamedExpression]
}
+ Project(structAttributesWithCondition, plan)
}
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
index 23c6022727..8479d9f41e 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenCoalesceAggregationUnionSuite.scala
@@ -47,6 +47,8 @@ class GlutenCoalesceAggregationUnionSuite extends
GlutenClickHouseWholeStageTran
.set("spark.io.compression.codec", "snappy")
.set("spark.sql.shuffle.partitions", "5")
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
+
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union",
"true")
+
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.aggregation.union",
"true")
}
def createTestTable(tableName: String, data: DataFrame): Unit = {
@@ -392,4 +394,105 @@ class GlutenCoalesceAggregationUnionSuite extends
GlutenClickHouseWholeStageTran
compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true)
}
+ test("coalesce project union. case 1") {
+
+ val sql =
+ """
+ |select a, x, y from (
+ | select a, x, y from coalesce_union_t1 where b % 2 = 0
+ | union all
+ | select a, x, y from coalesce_union_t1 where b % 3 = 1
+ |) order by a, x, y
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+ }
+
+ test("coalesce project union. case 2") {
+ val sql =
+ """
+ |select a, x, y from (
+ | select concat(a, 'x') as a , x, y from coalesce_union_t1 where b % 2
= 0
+ | union all
+ | select a, x, y + 2 as y from coalesce_union_t1 where b % 3 = 1
+ |) order by a, x, y
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+ }
+
+ test("coalesce project union. case 3") {
+ val sql =
+ """
+ |select a, x, y from (
+ | select concat(a, 'x') as a , x, y, 1 as t from coalesce_union_t1
where b % 2 = 0
+ | union all
+ | select a, x, y + 2 as y, 2 as t from coalesce_union_t1 where b % 3 =
1
+ |) order by a, x, y, t
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+ }
+
+ test("coalesce project union. case 4") {
+ val sql =
+ """
+ |select a, x, y from (
+ | select concat(a, 'x') as a , x, 1 as y from coalesce_union_t1 where
b % 2 = 0
+ | union all
+ | select a, x, y + 2 as y from coalesce_union_t1 where b % 3 = 1
+ |) order by a, x, y
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+ }
+
+ test("coalesce project union. case 5") {
+ val sql =
+ """
+ |select a, x, y from (
+ | select a, x, y from (select a, x, y, b + 4 as b from
coalesce_union_t1) where b % 2 = 0
+ | union all
+ | select a, x, y + 2 as y from coalesce_union_t1 where b % 3 = 1
+ |) order by a, x, y
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+ }
+
+ test("coalesce project union. case 6") {
+
+ val sql =
+ """
+ |select a, x, y, t from (
+ | select a, x, y, 1 as t from coalesce_union_t1 where b % 2 = 0
+ | union all
+ | select a, x, y, 2 as t from coalesce_union_t1 where b % 3 = 1
+ | union all
+ | select a, x, y, 3 as t from coalesce_union_t1 where b % 4 = 1
+ | union all
+ | select a, x, y, 4 as t from coalesce_union_t1 where b % 5 = 1
+ |) order by a, x, y, t
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql, true, checkNoUnion, true)
+ }
+
+ test("no coalesce project union. case 1") {
+ val sql =
+ """
+ |select a, x, y from (
+ | select a, x, y from coalesce_union_t1
+ | union all
+ | select a, x, y from coalesce_union_t1 where b % 3 = 1
+ |) order by a, x, y
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true)
+ }
+
+ test("no coalesce project union. case 2") {
+ val sql =
+ """
+ |select a, x, y from (
+ | select a , x, y from coalesce_union_t2 where b % 2 = 0
+ | union all
+ | select a, x, y from coalesce_union_t1 where b % 3 = 1
+ |) order by a, x, y
+ |""".stripMargin
+ compareResultsAgainstVanillaSpark(sql, true, checkHasUnion, true)
+ }
}
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
index f5bdb254b2..3a05eda711 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
@@ -16,6 +16,18 @@
*/
package org.apache.spark.sql
-class GlutenCTEInlineSuiteAEOff extends CTEInlineSuiteAEOff with
GlutenSQLTestsTrait
+import org.apache.spark.SparkConf
-class GlutenCTEInlineSuiteAEOn extends CTEInlineSuiteAEOn with
GlutenSQLTestsTrait
+class GlutenCTEInlineSuiteAEOff extends CTEInlineSuiteAEOff with
GlutenSQLTestsTrait {
+ override def sparkConf: SparkConf =
+ super.sparkConf
+
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union",
"false")
+
+}
+
+class GlutenCTEInlineSuiteAEOn extends CTEInlineSuiteAEOn with
GlutenSQLTestsTrait {
+ override def sparkConf: SparkConf =
+ super.sparkConf
+
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union",
"false")
+
+}
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
index d51d1034b0..fe7958b677 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
@@ -16,6 +16,14 @@
*/
package org.apache.spark.sql
+import org.apache.spark.SparkConf
+
class GlutenDataFrameSetOperationsSuite
extends DataFrameSetOperationsSuite
- with GlutenSQLTestsTrait {}
+ with GlutenSQLTestsTrait {
+ override def sparkConf: SparkConf =
+ super.sparkConf
+
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union",
"false")
+
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.aggregation.union",
"false")
+
+}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 3c1429a34f..75e97b48a9 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -622,6 +622,8 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("ordering and partitioning reporting")
enableSuite[GlutenDatasetAggregatorSuite]
enableSuite[GlutenDatasetCacheSuite]
+ // Disable this since the coalesce union clauses rule will rewrite the query.
+ .exclude("SPARK-44653: non-trivial DataFrame unions should not break
caching")
enableSuite[GlutenDatasetOptimizationSuite]
enableSuite[GlutenDatasetPrimitiveSuite]
enableSuite[GlutenDatasetSerializerRegistratorSuite]
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
index f5bdb254b2..3a05eda711 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenCTEInlineSuite.scala
@@ -16,6 +16,18 @@
*/
package org.apache.spark.sql
-class GlutenCTEInlineSuiteAEOff extends CTEInlineSuiteAEOff with
GlutenSQLTestsTrait
+import org.apache.spark.SparkConf
-class GlutenCTEInlineSuiteAEOn extends CTEInlineSuiteAEOn with
GlutenSQLTestsTrait
+class GlutenCTEInlineSuiteAEOff extends CTEInlineSuiteAEOff with
GlutenSQLTestsTrait {
+ override def sparkConf: SparkConf =
+ super.sparkConf
+
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union",
"false")
+
+}
+
+class GlutenCTEInlineSuiteAEOn extends CTEInlineSuiteAEOn with
GlutenSQLTestsTrait {
+ override def sparkConf: SparkConf =
+ super.sparkConf
+
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union",
"false")
+
+}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
index d51d1034b0..fe7958b677 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSetOperationsSuite.scala
@@ -16,6 +16,14 @@
*/
package org.apache.spark.sql
+import org.apache.spark.SparkConf
+
class GlutenDataFrameSetOperationsSuite
extends DataFrameSetOperationsSuite
- with GlutenSQLTestsTrait {}
+ with GlutenSQLTestsTrait {
+ override def sparkConf: SparkConf =
+ super.sparkConf
+
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.project.union",
"false")
+
.set("spark.gluten.sql.columnar.backend.ch.enable.coalesce.aggregation.union",
"false")
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]