peter-toth commented on code in PR #52238:
URL: https://github.com/apache/spark/pull/52238#discussion_r2338014380


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala:
##########
@@ -1188,6 +1237,117 @@ object CollapseProject extends Rule[LogicalPlan] with 
AliasHelper {
         r.copy(child = p.copy(projectList = buildCleanedProjectList(l1, 
p.projectList)))
       case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if 
isRenaming(l1, l2) =>
         s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, 
p2.projectList)))
+      case o => o
+    }
+  }
+
+  private def cheapToInlineProducer(
+      producer: NamedExpression,
+      relatedConsumers: Iterable[Expression]) = trimAliases(producer) match {
+    // These collection creation functions are not cheap as a producer, but we 
have
+    // optimizer rules that can optimize them out if they are only consumed by
+    // ExtractValue (See SimplifyExtractValueOps), so we need to allow to 
inline them to
+    // avoid perf regression. As an example:
+    //   Project(s.a, s.b, Project(create_struct(a, b, c) as s, child))
+    // We should collapse these two projects and eventually get Project(a, b, 
child)
+    case e @ (_: CreateNamedStruct | _: UpdateFields | _: CreateMap | _: 
CreateArray) =>
+      // We can inline the collection creation producer if at most one of its 
access
+      // is non-cheap. Cheap access here means the access can be optimized by
+      // `SimplifyExtractValueOps` and become a cheap expression. For example,
+      // `create_struct(a, b, c).a` is a cheap access as it can be optimized 
to `a`.
+      // For a query:
+      //   Project(s.a, s, Project(create_struct(a, b, c) as s, child))
+      // We should collapse these two projects and eventually get
+      //   Project(a, create_struct(a, b, c) as s, child)
+      var nonCheapAccessSeen = false
+      def nonCheapAccessVisitor(): Boolean = {
+        // Returns true for all calls after the first.
+        try {
+          nonCheapAccessSeen
+        } finally {
+          nonCheapAccessSeen = true
+        }
+      }
+
+      !relatedConsumers
+        .exists(findNonCheapAccesses(_, producer.toAttribute, e, 
nonCheapAccessVisitor))
+
+    case other => isCheap(other)
+  }
+
+  private def mergeProjectExpressions(
+      consumers: Seq[NamedExpression],
+      producers: Seq[NamedExpression],
+      alwaysInline: Boolean,
+      pythonUDFEvalTypesInUpperProjects: Set[Int],
+      pythonUDFArrowFallbackOnUDT: Boolean): (Seq[NamedExpression], 
Seq[NamedExpression]) = {
+    lazy val producerAttributes = AttributeSet(producers.collect { case a: 
Alias => a.toAttribute })
+
+    // A map from producer attributes to tuples of:
+    // - how many times the producer is referenced from consumers and
+    // - the set of consumers that reference the producer.
+    lazy val producerReferences = AttributeMap(consumers
+      .flatMap(e => 
collectReferences(e).filter(producerAttributes.contains).map(_ -> e))
+      .groupMap(_._1)(_._2)
+      .view.mapValues(v => v.size -> ExpressionSet(v)))
+
+    // Split the producers from the lower node to 4 categories:
+    // - `neverInlines` contains producer expressions that shouldn't be 
inlined.
+    //    These include non-deterministic expressions or expensive ones that 
are referenced multiple
+    //    times.
+    // - `mustInlines` contains expressions with Python UDFs that must be 
inlined into the upper
+    //    node to avoid performance issues.
+    // - `maybeInlines` contains expressions that might make sense to inline, 
such as expressions
+    //    that are used only once, or are cheap to inline.
+    //    But we need to take into account the side effect of adding new 
pass-through attributes to
+    //    the lower node, which can make the node much wider than it was 
originally.
+    val neverInlines = ListBuffer.empty[NamedExpression]
+    val mustInlines = ListBuffer.empty[NamedExpression]
+    val maybeInlines = ListBuffer.empty[NamedExpression]
+    val others = ListBuffer.empty[NamedExpression]
+    producers.foreach {
+      case a: Alias =>
+        producerReferences.get(a.toAttribute).foreach { case (count, 
relatedConsumers) =>
+          lazy val containsUDF = a.child.exists {
+            case udf: PythonUDF =>
+              isScalarPythonUDF(udf) &&
+                pythonUDFEvalTypesInUpperProjects.contains(
+                  correctEvalType(udf, pythonUDFArrowFallbackOnUDT))
+            case _ => false
+          }
+
+          if (!a.child.deterministic) {
+            neverInlines += a
+          } else if (alwaysInline || containsUDF) {
+            mustInlines += a
+          } else if (count == 1 || cheapToInlineProducer(a, relatedConsumers)) 
{
+            maybeInlines += a
+          } else {
+            neverInlines += a
+          }
+        }
+
+      case o => others += o

Review Comment:
   Fixed in 
https://github.com/apache/spark/pull/52238/commits/bc9df4f50ff83865c62a439938426d05c42d9aed



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala:
##########
@@ -1188,6 +1237,117 @@ object CollapseProject extends Rule[LogicalPlan] with 
AliasHelper {
         r.copy(child = p.copy(projectList = buildCleanedProjectList(l1, 
p.projectList)))
       case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if 
isRenaming(l1, l2) =>
         s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, 
p2.projectList)))
+      case o => o
+    }
+  }
+
+  private def cheapToInlineProducer(
+      producer: NamedExpression,
+      relatedConsumers: Iterable[Expression]) = trimAliases(producer) match {
+    // These collection creation functions are not cheap as a producer, but we 
have
+    // optimizer rules that can optimize them out if they are only consumed by
+    // ExtractValue (See SimplifyExtractValueOps), so we need to allow to 
inline them to
+    // avoid perf regression. As an example:
+    //   Project(s.a, s.b, Project(create_struct(a, b, c) as s, child))
+    // We should collapse these two projects and eventually get Project(a, b, 
child)
+    case e @ (_: CreateNamedStruct | _: UpdateFields | _: CreateMap | _: 
CreateArray) =>
+      // We can inline the collection creation producer if at most one of its 
access
+      // is non-cheap. Cheap access here means the access can be optimized by
+      // `SimplifyExtractValueOps` and become a cheap expression. For example,
+      // `create_struct(a, b, c).a` is a cheap access as it can be optimized 
to `a`.
+      // For a query:
+      //   Project(s.a, s, Project(create_struct(a, b, c) as s, child))
+      // We should collapse these two projects and eventually get
+      //   Project(a, create_struct(a, b, c) as s, child)
+      var nonCheapAccessSeen = false
+      def nonCheapAccessVisitor(): Boolean = {
+        // Returns true for all calls after the first.
+        try {
+          nonCheapAccessSeen
+        } finally {
+          nonCheapAccessSeen = true
+        }
+      }
+
+      !relatedConsumers
+        .exists(findNonCheapAccesses(_, producer.toAttribute, e, 
nonCheapAccessVisitor))
+
+    case other => isCheap(other)
+  }
+
+  private def mergeProjectExpressions(
+      consumers: Seq[NamedExpression],
+      producers: Seq[NamedExpression],
+      alwaysInline: Boolean,
+      pythonUDFEvalTypesInUpperProjects: Set[Int],
+      pythonUDFArrowFallbackOnUDT: Boolean): (Seq[NamedExpression], 
Seq[NamedExpression]) = {
+    lazy val producerAttributes = AttributeSet(producers.collect { case a: 
Alias => a.toAttribute })
+
+    // A map from producer attributes to tuples of:
+    // - how many times the producer is referenced from consumers and
+    // - the set of consumers that reference the producer.
+    lazy val producerReferences = AttributeMap(consumers
+      .flatMap(e => 
collectReferences(e).filter(producerAttributes.contains).map(_ -> e))
+      .groupMap(_._1)(_._2)
+      .view.mapValues(v => v.size -> ExpressionSet(v)))
+
+    // Split the producers from the lower node to 4 categories:
+    // - `neverInlines` contains producer expressions that shouldn't be 
inlined.
+    //    These include non-deterministic expressions or expensive ones that 
are referenced multiple
+    //    times.
+    // - `mustInlines` contains expressions with Python UDFs that must be 
inlined into the upper
+    //    node to avoid performance issues.
+    // - `maybeInlines` contains expressions that might make sense to inline, 
such as expressions
+    //    that are used only once, or are cheap to inline.
+    //    But we need to take into account the side effect of adding new 
pass-through attributes to
+    //    the lower node, which can make the node much wider than it was 
originally.
+    val neverInlines = ListBuffer.empty[NamedExpression]
+    val mustInlines = ListBuffer.empty[NamedExpression]
+    val maybeInlines = ListBuffer.empty[NamedExpression]
+    val others = ListBuffer.empty[NamedExpression]

Review Comment:
   Added comment in 
https://github.com/apache/spark/pull/52238/commits/bc9df4f50ff83865c62a439938426d05c42d9aed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to