This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 72d3266  [SPARK-35144][SQL] Migrate to transformWithPruning for object rules
72d3266 is described below

commit 72d32662d4744440e286a639783fed8dcf6c3948
Author: Yingyi Bu <yingyi...@databricks.com>
AuthorDate: Fri May 7 18:36:28 2021 +0800

    [SPARK-35144][SQL] Migrate to transformWithPruning for object rules
    
    ### What changes were proposed in this pull request?
    
    Added the following TreePattern enums:
    - APPEND_COLUMNS
    - DESERIALIZE_TO_OBJECT
    - LAMBDA_VARIABLE
    - MAP_OBJECTS
    - SERIALIZE_FROM_OBJECT
    - PROJECT
    - TYPED_FILTER
    
    Added tree traversal pruning to the following rules dealing with objects:
    - EliminateSerialization
    - CombineTypedFilters
    - EliminateMapObjects
    - ObjectSerializerPruning
    
    ### Why are the changes needed?
    
    Reduce the number of tree traversals and hence improve query compilation latency.
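    
    The pruning relies on each plan or expression node declaring the patterns it
    introduces via `nodePatterns` (see the diff below); `transformWithPruning(cond, ruleId)`
    then skips whole subtrees on which `cond` cannot hold, instead of visiting every node.
    Below is a minimal, self-contained sketch of that idea; `Pattern`, `SimpleNode`,
    `subtreePatterns` and this `transformWithPruning` are illustrative stand-ins, not
    Catalyst's actual `TreeNode` implementation.
    
    ```scala
    // Minimal sketch of pattern-based pruning. SimpleNode caches which patterns
    // occur anywhere in its subtree, and transformWithPruning only descends into
    // subtrees whose cached pattern set can satisfy the condition.
    object PruningSketch {
      sealed trait Pattern
      case object TYPED_FILTER extends Pattern
      case object PROJECT extends Pattern
    
      case class SimpleNode(
          name: String,
          patterns: Set[Pattern],
          children: Seq[SimpleNode]) {
    
        // Patterns appearing anywhere in this subtree, computed once and cached.
        lazy val subtreePatterns: Set[Pattern] =
          children.foldLeft(patterns)(_ ++ _.subtreePatterns)
    
        def containsPattern(p: Pattern): Boolean = subtreePatterns.contains(p)
    
        // Descend only where `cond` holds; other subtrees are returned untouched
        // without being visited at all.
        def transformWithPruning(cond: SimpleNode => Boolean)(
            rule: PartialFunction[SimpleNode, SimpleNode]): SimpleNode = {
          if (!cond(this)) {
            this
          } else {
            val newNode = copy(children = children.map(_.transformWithPruning(cond)(rule)))
            rule.applyOrElse(newNode, (n: SimpleNode) => n)
          }
        }
      }
    
      def main(args: Array[String]): Unit = {
        val plan = SimpleNode("Project", Set(PROJECT), Seq(
          SimpleNode("TypedFilter", Set(TYPED_FILTER), Seq(
            SimpleNode("LocalRelation", Set.empty, Nil)))))
    
        // A rule that only rewrites TypedFilter nodes: any subtree whose pattern
        // set lacks TYPED_FILTER is skipped without a traversal.
        val rewritten = plan.transformWithPruning(_.containsPattern(TYPED_FILTER)) {
          case n @ SimpleNode("TypedFilter", _, _) => n.copy(name = "CombinedTypedFilter")
        }
        println(rewritten)
      }
    }
    ```
    
    In the real rules the condition is built from the patterns listed above, e.g.
    `_.containsPattern(TYPED_FILTER)` for CombineTypedFilters and
    `_.containsAllPatterns(SERIALIZE_FROM_OBJECT, PROJECT)` for ObjectSerializerPruning,
    as shown in the diff below.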
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #32451 from sigmod/object.
    
    Authored-by: Yingyi Bu <yingyi...@databricks.com>
    Signed-off-by: Gengliang Wang <ltn...@gmail.com>
---
 .../spark/sql/catalyst/expressions/objects/objects.scala  |  6 +++++-
 .../org/apache/spark/sql/catalyst/optimizer/objects.scala | 15 ++++++++++-----
 .../catalyst/plans/logical/basicLogicalOperators.scala    |  2 ++
 .../apache/spark/sql/catalyst/plans/logical/object.scala  |  8 ++++++++
 .../spark/sql/catalyst/rules/RuleIdCollection.scala       |  5 +++++
 .../apache/spark/sql/catalyst/trees/TreePatterns.scala    |  7 +++++++
 6 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 469c895..40378a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TernaryLike
-import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -669,6 +669,8 @@ case class LambdaVariable(
 
   private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable)
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(LAMBDA_VARIABLE)
+
   // Interpreted execution of `LambdaVariable` always get the 0-index element from input row.
   override def eval(input: InternalRow): Any = {
     assert(input.numFields == 1,
@@ -781,6 +783,8 @@ case class MapObjects private(
   override def second: Expression = lambdaFunction
   override def third: Expression = inputData
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(MAP_OBJECTS)
+
   // The data with UserDefinedType are actually stored with the data type of its sqlType.
   // When we want to apply MapObjects on it, we have to use it.
   lazy private val inputDataType = inputData.dataType match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 97712a0..52544ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType, UserDefinedType}
 
 /*
@@ -35,7 +36,8 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType, Use
  * representation of data item.  For example back to back map operations.
  */
 object EliminateSerialization extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(DESERIALIZE_TO_OBJECT, APPEND_COLUMNS, TYPED_FILTER), ruleId) {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
       if d.outputObjAttr.dataType == s.inputObjAttr.dataType =>
       // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
@@ -72,7 +74,8 @@ object EliminateSerialization extends Rule[LogicalPlan] {
  * merging the filter functions into one conjunctive function.
  */
 object CombineTypedFilters extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(TYPED_FILTER), ruleId) {
     case t1 @ TypedFilter(_, _, _, _, t2 @ TypedFilter(_, _, _, _, child))
         if t1.deserializer.dataType == t2.deserializer.dataType =>
       TypedFilter(
@@ -108,7 +111,8 @@ object CombineTypedFilters extends Rule[LogicalPlan] {
  *   2. no custom collection class specified representation of data item.
  */
 object EliminateMapObjects extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+    _.containsAllPatterns(MAP_OBJECTS, LAMBDA_VARIABLE), ruleId) {
      case MapObjects(_, LambdaVariable(_, _, false, _), inputData, None) => inputData
   }
 }
@@ -207,7 +211,8 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAllPatterns(SERIALIZE_FROM_OBJECT, PROJECT), ruleId) {
     case p @ Project(_, s: SerializeFromObject) =>
      // Prunes individual serializer if it is not used at all by above projection.
       val usedRefs = p.references
@@ -252,7 +257,7 @@ object ReassignLambdaVariableID extends Rule[LogicalPlan] {
     var hasNegativeIds = false
     var hasPositiveIds = false
 
-    plan.transformAllExpressions {
+    plan.transformAllExpressionsWithPruning(_.containsPattern(LAMBDA_VARIABLE), ruleId) {
       case lr: LambdaVariable if lr.id == 0 =>
        throw new IllegalStateException("LambdaVariable should never has 0 as its ID.")
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index f5e92fb..1bd1666 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -70,6 +70,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
   override def maxRows: Option[Long] = child.maxRows
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(PROJECT)
+
   override lazy val resolved: Boolean = {
     val hasSpecialExpressions = projectList.exists ( _.collect {
         case agg: AggregateExpression => agg
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 6d61a86..1f7eb67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
 import org.apache.spark.sql.types._
@@ -80,6 +81,7 @@ case class DeserializeToObject(
     deserializer: Expression,
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectProducer {
+  final override val nodePatterns: Seq[TreePattern] = Seq(DESERIALIZE_TO_OBJECT)
   override protected def withNewChildInternal(newChild: LogicalPlan): DeserializeToObject =
     copy(child = newChild)
 }
@@ -94,6 +96,8 @@ case class SerializeFromObject(
 
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(SERIALIZE_FROM_OBJECT)
+
   override protected def withNewChildInternal(newChild: LogicalPlan): SerializeFromObject =
     copy(child = newChild)
 }
@@ -256,6 +260,8 @@ case class TypedFilter(
 
   override def output: Seq[Attribute] = child.output
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(TYPED_FILTER)
+
   def withObjectProducerChild(obj: LogicalPlan): Filter = {
     assert(obj.output.length == 1)
     Filter(typedCondition(obj.output.head), obj)
@@ -354,6 +360,8 @@ case class AppendColumns(
 
   override def output: Seq[Attribute] = child.output ++ newColumns
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(APPEND_COLUMNS)
+
   def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
 
   override protected def withNewChildInternal(newChild: LogicalPlan): AppendColumns =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 3a1fa6a..62f09d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -89,13 +89,17 @@ object RuleIdCollection {
       // Catalyst Optimizer rules
       "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
+      "org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters" ::
       "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
       "org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
       "org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
+      "org.apache.spark.sql.catalyst.optimizer.EliminateMapObjects" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
+      "org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
       "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
       "org.apache.spark.sql.catalyst.optimizer.NullPropagation" ::
+      "org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
       "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
@@ -105,6 +109,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" 
::
       "org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" ::
       
"org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
+      "org.apache.spark.sql.catalyst.optimizer.ReassignLambdaVariableID" ::
       "org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions" ::
       "org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator" ::
       "org.apache.spark.sql.catalyst.optimizer.ReorderJoin" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index b44847c..fb384cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -25,6 +25,7 @@ object TreePattern extends Enumeration  {
   // Expression patterns (alphabetically ordered)
   val AND_OR: Value = Value(0)
   val ATTRIBUTE_REFERENCE: Value = Value
+  val APPEND_COLUMNS: Value = Value
   val BINARY_ARITHMETIC: Value = Value
   val BINARY_COMPARISON: Value = Value
   val CASE_WHEN: Value = Value
@@ -32,6 +33,7 @@ object TreePattern extends Enumeration  {
   val CONCAT: Value = Value
   val COUNT: Value = Value
   val CREATE_NAMED_STRUCT: Value = Value
+  val DESERIALIZE_TO_OBJECT: Value = Value
   val DYNAMIC_PRUNING_SUBQUERY: Value = Value
   val EXISTS_SUBQUERY = Value
   val EXPRESSION_WITH_RANDOM_SEED: Value = Value
@@ -41,12 +43,15 @@ object TreePattern extends Enumeration  {
   val IN_SUBQUERY: Value = Value
   val INSET: Value = Value
   val JSON_TO_STRUCT: Value = Value
+  val LAMBDA_VARIABLE: Value = Value
   val LIKE_FAMLIY: Value = Value
   val LIST_SUBQUERY: Value = Value
   val LITERAL: Value = Value
+  val MAP_OBJECTS: Value = Value
   val NOT: Value = Value
   val NULL_CHECK: Value = Value
   val NULL_LITERAL: Value = Value
+  val SERIALIZE_FROM_OBJECT: Value = Value
   val OUTER_REFERENCE: Value = Value
   val PLAN_EXPRESSION: Value = Value
   val SCALAR_SUBQUERY: Value = Value
@@ -66,5 +71,7 @@ object TreePattern extends Enumeration  {
   val LOCAL_RELATION: Value = Value
   val NATURAL_LIKE_JOIN: Value = Value
   val OUTER_JOIN: Value = Value
+  val PROJECT: Value = Value
+  val TYPED_FILTER: Value = Value
   val WINDOW: Value = Value
 }
