This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 96045b1175 [GLUTEN-7778][CH] Make aggregation output schema same as CH
native (#7811)
96045b1175 is described below
commit 96045b117531a2a0db93bf4f9ea7bac5d7fba117
Author: lgbo <[email protected]>
AuthorDate: Thu Nov 7 09:51:19 2024 +0800
[GLUTEN-7778][CH] Make aggregation output schema same as CH native (#7811)
* unity agg output
* update
* update
* update
* fixed
---
.../clickhouse/CHSparkPlanExecApi.scala | 9 +++--
.../backendsapi/clickhouse/CHTransformerApi.scala | 11 +++---
.../execution/CHHashAggregateExecTransformer.scala | 39 ++++++++++++++++++++++
3 files changed, 53 insertions(+), 6 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index ba165d936e..190fcb13ea 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -158,16 +158,21 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with
Logging {
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
- child: SparkPlan): HashAggregateExecBaseTransformer =
+ child: SparkPlan): HashAggregateExecBaseTransformer = {
+ val replacedResultExpressions =
CHHashAggregateExecTransformer.getCHAggregateResultExpressions(
+ groupingExpressions,
+ aggregateExpressions,
+ resultExpressions)
CHHashAggregateExecTransformer(
requiredChildDistributionExpressions,
groupingExpressions.distinct,
aggregateExpressions,
aggregateAttributes,
initialInputBufferOffset,
- resultExpressions.distinct,
+ replacedResultExpressions.distinct,
child
)
+ }
/** Generate HashAggregateExecPullOutHelper */
override def genHashAggregateExecPullOutHelper(
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
index e5b7182585..bf909c52ac 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHTransformerApi.scala
@@ -177,10 +177,13 @@ class CHTransformerApi extends TransformerApi with
Logging {
// output name will be different from grouping expressions,
// so using output attribute instead of grouping expression
val groupingExpressions =
hash.output.splitAt(hash.groupingExpressions.size)._1
- val aggResultAttributes =
CHHashAggregateExecTransformer.getAggregateResultAttributes(
- groupingExpressions,
- hash.aggregateExpressions
- )
+ val aggResultAttributes = CHHashAggregateExecTransformer
+ .getCHAggregateResultExpressions(
+ groupingExpressions,
+ hash.aggregateExpressions,
+ hash.resultExpressions
+ )
+ .map(_.toAttribute)
if (aggResultAttributes.size == hash.output.size) {
aggResultAttributes
} else {
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index f5e64330cd..fcf6320f8e 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -43,6 +43,45 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
object CHHashAggregateExecTransformer {
+ // The result attributes of aggregate expressions from vanilla may be different from CH native.
+ // For example, the result attributes of `avg(x)` are `sum(x)` and `count(x)`. This could bring
+ // some unexpected issues. So we need to make the result attributes consistent with CH native.
+ def getCHAggregateResultExpressions(
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression],
+ resultExpressions: Seq[NamedExpression]): Seq[NamedExpression] = { // returns the grouping prefix of resultExpressions followed by the entries selected per aggregate below
+ var adjustedResultExpressions = resultExpressions.slice(0,
groupingExpressions.length) // first groupingExpressions.length entries; declared var but never reassigned in this method
+ var resultExpressionIndex = groupingExpressions.length // cursor into resultExpressions, starting just past the grouping prefix
+ adjustedResultExpressions ++ aggregateExpressions.flatMap { // the closure advances resultExpressionIndex, so the result relies on in-order traversal
+ aggExpr =>
+ aggExpr.mode match {
+ case Partial | PartialMerge =>
+ // For partial aggregate, the size of the result expressions of an
aggregate expression
+ // is the same as aggBufferAttributes' length
+ val aggBufferAttributesCount =
aggExpr.aggregateFunction.aggBufferAttributes.length
+ aggExpr.aggregateFunction match {
+ case avg: Average => // binder `avg` is unused; only the type test matters
+ val res = Seq(aggExpr.resultAttribute) // a single attribute taken from aggExpr itself, not from resultExpressions
+ resultExpressionIndex += aggBufferAttributesCount // the cursor still skips every buffer slot of this aggregate
+ res
+ case sum: Sum if (sum.dataType.isInstanceOf[DecimalType]) => // Sum whose dataType is a DecimalType
+ val res = Seq(resultExpressions(resultExpressionIndex)) // keeps only the entry at the cursor
+ resultExpressionIndex += aggBufferAttributesCount // remaining buffer slots of this aggregate are skipped
+ res
+ case _ => // any other function: every buffer slot is passed through unchanged
+ val res = resultExpressions
+ .slice(resultExpressionIndex, resultExpressionIndex +
aggBufferAttributesCount)
+ resultExpressionIndex += aggBufferAttributesCount
+ res
+ }
+ case _ => // every mode other than Partial / PartialMerge
+ val res = Seq(resultExpressions(resultExpressionIndex)) // direct index with no bounds check; NOTE(review): assumes callers pass enough resultExpressions - confirm
+ resultExpressionIndex += 1 // exactly one result expression is consumed per aggregate here
+ res
+ }
+ }
+ }
+
def getAggregateResultAttributes(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression]): Seq[Attribute] = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]