zhztheplayer commented on code in PR #7451:
URL: https://github.com/apache/incubator-gluten/pull/7451#discussion_r1802267777


##########
backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala:
##########
@@ -37,70 +37,58 @@ import scala.reflect.{classTag, ClassTag}
 case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
   import CollectRewriteRule._
   override def apply(plan: LogicalPlan): LogicalPlan = 
LogicalPlanSelector.maybe(spark, plan) {
-    val out = plan.transformUp {
-      case node =>
-        val out = replaceCollectSet(replaceCollectList(node))
-        out
-    }
-    if (out.fastEquals(plan)) {
+    if (!has[VeloxCollectSet] && !has[VeloxCollectList]) {
       return plan
     }
-    out
-  }
 
-  private def replaceCollectList(node: LogicalPlan): LogicalPlan = {
-    node.transformExpressions {
-      case func @ AggregateExpression(l: CollectList, _, _, _, _) if 
has[VeloxCollectList] =>
-        func.copy(VeloxCollectList(l.child))
+    val newPlan = plan.transformUp {
+      case node =>
+        replaceAggCollect(node)
     }
+    if (newPlan.fastEquals(plan)) {
+      return plan
+    }
+    newPlan
   }
 
-  private def replaceCollectSet(node: LogicalPlan): LogicalPlan = {
-    // 1. Replace null result from VeloxCollectSet with empty array to align 
with
-    //    vanilla Spark.
-    // 2. Filter out null inputs from VeloxCollectSet to align with vanilla 
Spark.
-    //
-    // Since https://github.com/apache/incubator-gluten/pull/4805
+  private def replaceAggCollect(node: LogicalPlan): LogicalPlan = {
     node match {
       case agg: Aggregate =>
-        agg.transformExpressions {
-          case ToVeloxCollectSet(newAggFunc) =>
-            val out = ensureNonNull(newAggFunc)
-            out
+        
agg.transformExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) {
+          case ToVeloxCollect(newAggExpr) =>
+            newAggExpr
         }
       case w: Window =>
-        w.transformExpressions {
-          case func @ WindowExpression(ToVeloxCollectSet(newAggFunc), _) =>
-            val out = ensureNonNull(func.copy(newAggFunc))
-            out
+        w.transformExpressionsWithPruning(
+          _.containsAllPatterns(AGGREGATE_EXPRESSION, WINDOW_EXPRESSION)) {
+          case windowExpr @ WindowExpression(ToVeloxCollect(newAggExpr), _) =>
+            windowExpr.copy(newAggExpr)
         }
       case other => other
     }
   }
 }
 
 object CollectRewriteRule {
-  private def ensureNonNull(expr: Expression): Expression = {
-    val out =
-      Coalesce(List(expr, Literal.create(Seq.empty, expr.dataType)))
-    assert(!out.nullable)
-    assert(!out.dataType.asInstanceOf[ArrayType].containsNull)
-    out
-  }

Review Comment:
   > Because VeloxCollect overrides the `defaultResult` with the default value. 
The Spark optimizer will ensure that it is not null.
   
   I don't see this method used much by Spark in its regular aggregation planning
   routine: https://github.com/search?q=repo%3Aapache%2Fspark%20defaultResult&type=code.
   Do you still see the `Coalesce` in the plan after removing this logic?
   
   BTW, since Rui's Velox PR (https://github.com/facebookincubator/velox/pull/10737)
   was merged recently, @rui-mo, did you make any corresponding changes in Gluten at
   that time?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to