This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new c2a578c337 enable lazy expand for avg and sum(decimal) (#7840)
c2a578c337 is described below

commit c2a578c3379adafb79da9c06df3df4e41eb9139a
Author: lgbo <[email protected]>
AuthorDate: Fri Nov 8 17:48:33 2024 +0800

    enable lazy expand for avg and sum(decimal) (#7840)
---
 .../gluten/extension/LazyAggregateExpandRule.scala | 15 +++++++---
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 35 ++++++++++++++++++++++
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala
index e06503a5e1..86b28ab1f7 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/LazyAggregateExpandRule.scala
@@ -190,10 +190,12 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
     // 2. if any aggregate function uses attributes which is not from expand's 
child, we don't
     // enable this
     if (
-      !aggregate.aggregateExpressions.forall(
+      !aggregate.aggregateExpressions.forall {
         e =>
           isValidAggregateFunction(e) &&
-            
e.aggregateFunction.references.forall(expandOutputAttributes.contains(_)))
+          e.aggregateFunction.references.forall(
+            attr => 
expandOutputAttributes.find(_.semanticEquals(attr)).isDefined)
+      }
     ) {
       logDebug(s"xxx Some aggregate functions are not supported")
       return false
@@ -267,7 +269,8 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
       case _: Count => true
       case _: Max => true
       case _: Min => true
-      case sum: Sum => !sum.dataType.isInstanceOf[DecimalType]
+      case _: Average => true
+      case _: Sum => true
       case _ => false
     }
   }
@@ -275,7 +278,11 @@ case class LazyAggregateExpandRule(session: SparkSession) extends Rule[SparkPlan
   def getReplaceAttribute(
       toReplace: Attribute,
       attributesToReplace: Map[Attribute, Attribute]): Attribute = {
-    attributesToReplace.getOrElse(toReplace, toReplace)
+    val kv = attributesToReplace.find(kv => kv._1.semanticEquals(toReplace))
+    kv match {
+      case Some((_, v)) => v
+      case None => toReplace
+    }
   }
 
   def buildReplaceAttributeMap(expand: ExpandExecTransformer): Map[Attribute, 
Attribute] = {
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 12047b300c..9affdeb7f7 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -3068,6 +3068,41 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
     compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
   }
 
+  test("GLUTEN-7647 lazy expand for avg and sum") {
+    val create_table_sql =
+      """
+        |create table test_7647(x bigint, y bigint, z bigint, v decimal(10, 
2)) using parquet
+        |""".stripMargin
+    spark.sql(create_table_sql)
+    val insert_data_sql =
+      """
+        |insert into test_7647 values
+        |(1, 1, 1, 1.0),
+        |(2, 2, 2, 2.0),
+        |(3, 3, 3, 3.0),
+        |(2,2,1, 4.0)
+        |""".stripMargin
+    spark.sql(insert_data_sql)
+
+    def checkLazyExpand(df: DataFrame): Unit = {
+      val expands = collectWithSubqueries(df.queryExecution.executedPlan) {
+        case e: ExpandExecTransformer if 
(e.child.isInstanceOf[HashAggregateExecBaseTransformer]) =>
+          e
+      }
+      assert(expands.size == 1)
+    }
+
+    var sql = "select x, y, avg(z), sum(v) from test_7647 group by x, y with 
cube order by x, y"
+    compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
+    sql =
+      "select x, y, count(distinct z), avg(v) from test_7647 group by x, y 
with cube order by x, y"
+    compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
+    sql =
+      "select x, y, count(distinct z), sum(v) from test_7647 group by x, y 
with cube order by x, y"
+    compareResultsAgainstVanillaSpark(sql, true, checkLazyExpand)
+    spark.sql("drop table if exists test_7647")
+  }
+
   test("GLUTEN-7759: Fix bug of agg pre-project push down") {
     val table_create_sql =
       "create table test_tbl_7759(id bigint, name string, day string) using 
parquet"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to