viirya commented on a change in pull request #29587:
URL: https://github.com/apache/spark/pull/29587#discussion_r493101256



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
##########
@@ -17,29 +17,202 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Resolves different children of Union to a common set of columns.
  */
 object ResolveUnion extends Rule[LogicalPlan] {
-  private def unionTwoSides(
+  /**
+   * This method sorts recursively columns in a struct expression based on column names.
+   */
+  private def sortStructFields(expr: Expression): Expression = {
+    val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) =>
+        val fieldExpr = GetStructField(KnownNotNull(expr), i)
+        if (fieldExpr.dataType.isInstanceOf[StructType]) {
+          (name, sortStructFields(fieldExpr))
+        } else {
+          (name, fieldExpr)
+        }
+    }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))
+
+    val newExpr = CreateNamedStruct(existingExprs)
+    if (expr.nullable) {
+      If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
+    } else {
+      newExpr
+    }
+  }
+
+  /**
+   * Assumes input expressions are field expression of `CreateNamedStruct`. This method
+   * sorts the expressions based on field names.
+   */
+  private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
+    fieldExprs.grouped(2).map { e =>
+      Seq(e.head, e.last)
+    }.toSeq.sortBy { pair =>
+      assert(pair.head.isInstanceOf[Literal])
+      pair.head.eval().asInstanceOf[UTF8String].toString
+    }.flatten
+  }
+
+  /**
+   * This helper method sorts fields in a `WithFields` expression by field name.
+   */
+  private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp {
+    case w: WithFields if w.resolved =>
+      w.evalExpr match {
+        case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct)
+        case CreateNamedStruct(fieldExprs) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          newStruct
+        case other =>
+          throw new AnalysisException(s"`WithFields` has incorrect eval expression: $other")
+      }
+  }
+
+  def simplifyWithFields(expr: Expression): Expression = {
+    expr.transformUp {
+      case WithFields(structExpr, names, values) if names.distinct.length != names.length =>

Review comment:
       When I was working on simplifying the expression tree for the scale issue, I ran into this case. I changed other parts of the code later, so maybe this isn't hit anymore. Let me check.

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
##########
@@ -17,29 +17,202 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Resolves different children of Union to a common set of columns.
  */
 object ResolveUnion extends Rule[LogicalPlan] {
-  private def unionTwoSides(
+  /**
+   * This method sorts recursively columns in a struct expression based on column names.
+   */
+  private def sortStructFields(expr: Expression): Expression = {
+    val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+      case (name, i) =>
+        val fieldExpr = GetStructField(KnownNotNull(expr), i)
+        if (fieldExpr.dataType.isInstanceOf[StructType]) {
+          (name, sortStructFields(fieldExpr))
+        } else {
+          (name, fieldExpr)
+        }
+    }.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))
+
+    val newExpr = CreateNamedStruct(existingExprs)
+    if (expr.nullable) {
+      If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
+    } else {
+      newExpr
+    }
+  }
+
+  /**
+   * Assumes input expressions are field expression of `CreateNamedStruct`. This method
+   * sorts the expressions based on field names.
+   */
+  private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
+    fieldExprs.grouped(2).map { e =>
+      Seq(e.head, e.last)
+    }.toSeq.sortBy { pair =>
+      assert(pair.head.isInstanceOf[Literal])
+      pair.head.eval().asInstanceOf[UTF8String].toString
+    }.flatten
+  }
+
+  /**
+   * This helper method sorts fields in a `WithFields` expression by field name.
+   */
+  private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp {
+    case w: WithFields if w.resolved =>
+      w.evalExpr match {
+        case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct)
+        case CreateNamedStruct(fieldExprs) =>
+          val sorted = sortFieldExprs(fieldExprs)
+          val newStruct = CreateNamedStruct(sorted)
+          newStruct
+        case other =>
+          throw new AnalysisException(s"`WithFields` has incorrect eval expression: $other")
+      }
+  }
+
+  def simplifyWithFields(expr: Expression): Expression = {
+    expr.transformUp {
+      case WithFields(structExpr, names, values) if names.distinct.length != names.length =>

Review comment:
       Ok, it seems not. However, this is still useful for optimizing inefficient `WithFields` expressions. I relied on these rules to simplify the expression tree while debugging.



