fqaiser94 commented on a change in pull request #29812:
URL: https://github.com/apache/spark/pull/29812#discussion_r491929282
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
##########
@@ -17,16 +17,29 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.expressions.WithFields
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, WithFields}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
/**
- * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
+ * Optimizes [[WithFields]] expression chains.
*/
-object CombineWithFields extends Rule[LogicalPlan] {
+object OptimizeWithFields extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case WithFields(structExpr, names, values) if names.distinct.length != names.length =>
+ val newNames = mutable.ArrayBuffer.empty[String]
+ val newValues = mutable.ArrayBuffer.empty[Expression]
+ names.zip(values).reverse.foreach { case (name, value) =>
+ if (!newNames.contains(name)) {
Review comment:
We should use `resolver` here rather than exact `String` equality for the name comparison; otherwise I think we will have correctness issues (e.g. under case-insensitive resolution).
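To illustrate the suggestion, here is a standalone sketch (not the PR's actual code); `Resolver` below stands in for Spark's resolver function, which is just a `(String, String) => Boolean` name-equality check:

```scala
object ResolverDedup {
  // Stand-in for Spark's Resolver: under case-insensitive resolution,
  // "a" and "A" refer to the same struct field.
  type Resolver = (String, String) => Boolean

  val caseInsensitiveResolution: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // Last-wins dedup over (name, value) pairs, comparing names with the
  // resolver instead of the plain `contains` check in the PR.
  def dedupLastWins[V](
      names: Seq[String],
      values: Seq[V],
      resolver: Resolver): (Seq[String], Seq[V]) = {
    val newNames = scala.collection.mutable.ArrayBuffer.empty[String]
    val newValues = scala.collection.mutable.ArrayBuffer.empty[V]
    names.zip(values).reverse.foreach { case (name, value) =>
      // `contains` -> `exists(resolver(_, name))` is the suggested fix
      if (!newNames.exists(existing => resolver(existing, name))) {
        newNames += name
        newValues += value
      }
    }
    (newNames.reverse.toSeq, newValues.reverse.toSeq)
  }
}
```

With `caseInsensitiveResolution`, names differing only in case collapse to the last-written pair, whereas a plain `contains` check would keep both and emit two conflicting fields.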
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
##########
@@ -17,16 +17,29 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.expressions.WithFields
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, WithFields}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
/**
- * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
+ * Optimizes [[WithFields]] expression chains.
*/
-object CombineWithFields extends Rule[LogicalPlan] {
+object OptimizeWithFields extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case WithFields(structExpr, names, values) if names.distinct.length != names.length =>
+ val newNames = mutable.ArrayBuffer.empty[String]
+ val newValues = mutable.ArrayBuffer.empty[Expression]
+ names.zip(values).reverse.foreach { case (name, value) =>
+ if (!newNames.contains(name)) {
+ newNames += name
+ newValues += value
+ }
+ }
+      WithFields(structExpr, names = newNames.reverse.toSeq, valExprs = newValues.reverse.toSeq)
Review comment:
For my understanding, can you explain how we expect to benefit from this optimization?
I ask because we already do this kind of deduplication inside `WithFields` as part of the `foldLeft` operation
[here](https://github.com/apache/spark/blob/d01594e8d186e63a6c3ce361e756565e830d5237/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala#L578),
which keeps only the last `valExpr` for each `name`. So I think the optimized logical plan will be the same with or without this optimization in all scenarios? CMIIW
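For context, a minimal model of that last-write-wins behavior (simplified; the real `foldLeft` in `complexTypeCreator.scala` folds over struct fields, not a flat pair list):

```scala
object LastWins {
  // Simplified model of the foldLeft inside WithFields: each (name, value)
  // pair is applied in order, and a later pair overwrites an earlier one
  // with the same name, so only the last value per name survives.
  def lastWins[V](names: Seq[String], values: Seq[V]): Seq[(String, V)] =
    names.zip(values).foldLeft(Seq.empty[(String, V)]) {
      case (acc, (name, value)) =>
        // drop any earlier entry for this name, then append the new one
        acc.filterNot(_._1 == name) :+ (name -> value)
    }
}
```

Since the fold already discards the earlier duplicates, an upstream dedup pass yields the same final field set.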
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
##########
@@ -17,16 +17,29 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.expressions.WithFields
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, WithFields}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
/**
- * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
+ * Optimizes [[WithFields]] expression chains.
*/
-object CombineWithFields extends Rule[LogicalPlan] {
+object OptimizeWithFields extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case WithFields(structExpr, names, values) if names.distinct.length != names.length =>
Review comment:
Could this `case` statement come after the next `case` statement, so that we combine the chains first before deduplicating?
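As a toy illustration of why order matters (not Spark code): the first matching `case` wins, so running the combine step before the dedup step flattens a chain into one node before its names are deduplicated.

```scala
object CaseOrder {
  // Toy tree: a With node with an optional nested With child and field names.
  case class With(child: Option[With], names: Seq[String])

  // Combine first, then dedup; with the cases swapped, each node's names
  // would be deduplicated before the chain is flattened into one node.
  def rewrite(w: With): With = w match {
    case With(Some(inner), names) =>
      // combine: merge the nested node's names into this one, recursively
      rewrite(With(inner.child, inner.names ++ names))
    case With(None, names) if names.distinct.length != names.length =>
      // dedup: keep the last occurrence of each name
      With(None, names.reverse.distinct.reverse)
    case other => other
  }
}
```

Deduplicating after combining catches duplicates that only appear once the chain is merged (e.g. the same name set by two different `WithFields` nodes).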
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]