This is an automated email from the ASF dual-hosted git repository.

mingliang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new f470973243  [GLUTEN-8229][VL] Don't rewrite collect_list/collect_set in window (#8230)
f470973243 is described below

commit f470973243c7ee541d75ea97ad760ac48bfd08e3
Author: Mingliang Zhu <[email protected]>
AuthorDate: Fri Dec 13 18:42:53 2024 +0800

     [GLUTEN-8229][VL] Don't rewrite collect_list/collect_set in window (#8230)
---
 .../gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala   |  4 +++-
 .../org/apache/gluten/extension/CollectRewriteRule.scala   | 14 ++++----------
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 8984a9551b..5975a20a26 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -35,7 +35,7 @@ import org.apache.spark.shuffle.utils.ShuffleUtil
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectList, CollectSet}
 import org.apache.spark.sql.catalyst.optimizer.BuildSide
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -759,7 +759,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       Sig[UserDefinedAggregateFunction](ExpressionNames.UDAF_PLACEHOLDER),
       Sig[NaNvl](ExpressionNames.NANVL),
       Sig[VeloxCollectList](ExpressionNames.COLLECT_LIST),
+      Sig[CollectList](ExpressionNames.COLLECT_LIST),
       Sig[VeloxCollectSet](ExpressionNames.COLLECT_SET),
+      Sig[CollectSet](ExpressionNames.COLLECT_SET),
       Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
       Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
       // For test purpose.
diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala
index 48541b234e..e76de56374 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala
@@ -20,11 +20,11 @@ import org.apache.gluten.expression.ExpressionMappings
 import org.apache.gluten.expression.aggregate.{VeloxCollectList, VeloxCollectSet}
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.{Expression, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Window}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, AGGREGATE_EXPRESSION, WINDOW, WINDOW_EXPRESSION}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, AGGREGATE_EXPRESSION}
 
 import scala.reflect.{classTag, ClassTag}
 
@@ -40,7 +40,7 @@ case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
       return plan
     }
 
-    val newPlan = plan.transformUpWithPruning(_.containsAnyPattern(WINDOW, AGGREGATE)) {
+    val newPlan = plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) {
       case node =>
         replaceAggCollect(node)
     }
@@ -57,12 +57,6 @@ case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
           case ToVeloxCollect(newAggExpr) =>
             newAggExpr
         }
-      case w: Window =>
-        w.transformExpressionsWithPruning(
-          _.containsAllPatterns(AGGREGATE_EXPRESSION, WINDOW_EXPRESSION)) {
-          case windowExpr @ WindowExpression(ToVeloxCollect(newAggExpr), _) =>
-            windowExpr.copy(newAggExpr)
-        }
       case other => other
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to