This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 72d3266  [SPARK-35144][SQL] Migrate to transformWithPruning for object rules
72d3266 is described below

commit 72d32662d4744440e286a639783fed8dcf6c3948
Author: Yingyi Bu <yingyi...@databricks.com>
AuthorDate: Fri May 7 18:36:28 2021 +0800

    [SPARK-35144][SQL] Migrate to transformWithPruning for object rules

    ### What changes were proposed in this pull request?

    Added the following TreePattern enums:
    - APPEND_COLUMNS
    - DESERIALIZE_TO_OBJECT
    - LAMBDA_VARIABLE
    - MAP_OBJECTS
    - SERIALIZE_FROM_OBJECT
    - PROJECT
    - TYPED_FILTER

    Added tree traversal pruning to the following rules dealing with objects:
    - EliminateSerialization
    - CombineTypedFilters
    - EliminateMapObjects
    - ObjectSerializerPruning

    ### Why are the changes needed?

    Reduce the number of tree traversals and hence improve query compilation latency.

    ### How was this patch tested?

    Existing tests.

    Closes #32451 from sigmod/object.

    Authored-by: Yingyi Bu <yingyi...@databricks.com>
    Signed-off-by: Gengliang Wang <ltn...@gmail.com>
---
 .../spark/sql/catalyst/expressions/objects/objects.scala  |  6 +++++-
 .../org/apache/spark/sql/catalyst/optimizer/objects.scala | 15 ++++++++++-----
 .../catalyst/plans/logical/basicLogicalOperators.scala    |  2 ++
 .../apache/spark/sql/catalyst/plans/logical/object.scala  |  8 ++++++++
 .../spark/sql/catalyst/rules/RuleIdCollection.scala       |  5 +++++
 .../apache/spark/sql/catalyst/trees/TreePatterns.scala    |  7 +++++++
 6 files changed, 37 insertions(+), 6 deletions(-)
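The mechanism behind the migration, in brief: each tree node declares the patterns it contributes via `nodePatterns`, every node caches the union of its own and its descendants' patterns, and the pruning variants of `transform` skip any subtree whose cached bits cannot satisfy the rule's precondition. The following is a minimal standalone sketch of that idea, not the Catalyst implementation; every name in it is a toy stand-in:

```scala
// A toy model of pattern-bit pruning. Illustrative only: Pattern, Node,
// Leaf, and this TypedFilter are stand-ins, not the Catalyst classes.
sealed trait Pattern
case object TYPED_FILTER extends Pattern

abstract class Node {
  def children: Seq[Node]
  // Patterns contributed by this node alone, like `nodePatterns` in the diff.
  def nodePatterns: Seq[Pattern]
  // Union of this node's patterns and its descendants', computed once.
  lazy val patternBits: Set[Pattern] =
    nodePatterns.toSet ++ children.flatMap(_.patternBits)
  def containsPattern(p: Pattern): Boolean = patternBits.contains(p)
  def withChildren(newChildren: Seq[Node]): Node

  // Like `transform`, but returns immediately when the cached bits show
  // that no node in this subtree can satisfy `cond`.
  def transformWithPruning(cond: Node => Boolean)(
      rule: PartialFunction[Node, Node]): Node =
    if (!cond(this)) this
    else {
      val mapped = withChildren(children.map(_.transformWithPruning(cond)(rule)))
      rule.applyOrElse(mapped, (n: Node) => n)
    }
}

case class Leaf(name: String) extends Node {
  def children: Seq[Node] = Nil
  def nodePatterns: Seq[Pattern] = Nil
  def withChildren(newChildren: Seq[Node]): Node = this
}

case class TypedFilter(child: Node) extends Node {
  def children: Seq[Node] = Seq(child)
  def nodePatterns: Seq[Pattern] = Seq(TYPED_FILTER)
  def withChildren(newChildren: Seq[Node]): Node = copy(child = newChildren.head)
}

object PruningDemo extends App {
  // Two stacked filters collapse into one; a plan with no TYPED_FILTER bit
  // anywhere would be returned untouched without visiting a single node.
  val combined = TypedFilter(TypedFilter(Leaf("scan")))
    .transformWithPruning(_.containsPattern(TYPED_FILTER)) {
      case TypedFilter(TypedFilter(grandchild)) => TypedFilter(grandchild)
    }
  println(combined) // TypedFilter(Leaf(scan))
}
```

Because the bits are cached bottom-up, the check at each node is a constant-time test rather than a scan, which is where the compilation-latency win comes from.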
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 469c895..40378a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TernaryLike
-import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -669,6 +669,8 @@ case class LambdaVariable(
   private val accessor: (InternalRow, Int) => Any =
     InternalRow.getAccessor(dataType, nullable)
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(LAMBDA_VARIABLE)
+
   // Interpreted execution of `LambdaVariable` always get the 0-index element from input row.
   override def eval(input: InternalRow): Any = {
     assert(input.numFields == 1,
@@ -781,6 +783,8 @@ case class MapObjects private(
   override def second: Expression = lambdaFunction
   override def third: Expression = inputData
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(MAP_OBJECTS)
+
   // The data with UserDefinedType are actually stored with the data type of its sqlType.
   // When we want to apply MapObjects on it, we have to use it.
   lazy private val inputDataType = inputData.dataType match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 97712a0..52544ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType, UserDefinedType}
 
 /*
@@ -35,7 +36,8 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType, Use
  * representation of data item. For example back to back map operations.
  */
 object EliminateSerialization extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(DESERIALIZE_TO_OBJECT, APPEND_COLUMNS, TYPED_FILTER), ruleId) {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
         if d.outputObjAttr.dataType == s.inputObjAttr.dataType =>
       // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
@@ -72,7 +74,8 @@ object EliminateSerialization extends Rule[LogicalPlan] {
  * merging the filter functions into one conjunctive function.
  */
 object CombineTypedFilters extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(TYPED_FILTER), ruleId) {
     case t1 @ TypedFilter(_, _, _, _, t2 @ TypedFilter(_, _, _, _, child))
         if t1.deserializer.dataType == t2.deserializer.dataType =>
       TypedFilter(
@@ -108,7 +111,8 @@ object CombineTypedFilters extends Rule[LogicalPlan] {
  *    2. no custom collection class specified representation of data item.
  */
 object EliminateMapObjects extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+    _.containsAllPatterns(MAP_OBJECTS, LAMBDA_VARIABLE), ruleId) {
     case MapObjects(_, LambdaVariable(_, _, false, _), inputData, None) => inputData
   }
 }
@@ -207,7 +211,8 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAllPatterns(SERIALIZE_FROM_OBJECT, PROJECT), ruleId) {
     case p @ Project(_, s: SerializeFromObject) =>
       // Prunes individual serializer if it is not used at all by above projection.
       val usedRefs = p.references
@@ -252,7 +257,7 @@ object ReassignLambdaVariableID extends Rule[LogicalPlan] {
     var hasNegativeIds = false
     var hasPositiveIds = false
 
-    plan.transformAllExpressions {
+    plan.transformAllExpressionsWithPruning(_.containsPattern(LAMBDA_VARIABLE), ruleId) {
       case lr: LambdaVariable if lr.id == 0 =>
         throw new IllegalStateException("LambdaVariable should never has 0 as its ID.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index f5e92fb..1bd1666 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -70,6 +70,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
   override def maxRows: Option[Long] = child.maxRows
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(PROJECT)
+
   override lazy val resolved: Boolean = {
     val hasSpecialExpressions = projectList.exists ( _.collect {
         case agg: AggregateExpression => agg
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 6d61a86..1f7eb67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
 import org.apache.spark.sql.types._
@@ -80,6 +81,7 @@ case class DeserializeToObject(
     deserializer: Expression,
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectProducer {
+  final override val nodePatterns: Seq[TreePattern] = Seq(DESERIALIZE_TO_OBJECT)
   override protected def withNewChildInternal(newChild: LogicalPlan): DeserializeToObject =
     copy(child = newChild)
 }
@@ -94,6 +96,8 @@ case class SerializeFromObject(
 
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(SERIALIZE_FROM_OBJECT)
+
   override protected def withNewChildInternal(newChild: LogicalPlan): SerializeFromObject =
     copy(child = newChild)
 }
@@ -256,6 +260,8 @@ case class TypedFilter(
 
   override def output: Seq[Attribute] = child.output
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(TYPED_FILTER)
+
   def withObjectProducerChild(obj: LogicalPlan): Filter = {
     assert(obj.output.length == 1)
     Filter(typedCondition(obj.output.head), obj)
@@ -354,6 +360,8 @@ case class AppendColumns(
 
   override def output: Seq[Attribute] = child.output ++ newColumns
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(APPEND_COLUMNS)
+
   def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
 
   override protected def withNewChildInternal(newChild: LogicalPlan): AppendColumns =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 3a1fa6a..62f09d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -89,13 +89,17 @@ object RuleIdCollection {
     // Catalyst Optimizer rules
     "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
     "org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
+    "org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters" ::
     "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
     "org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
     "org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
+    "org.apache.spark.sql.catalyst.optimizer.EliminateMapObjects" ::
     "org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
+    "org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
     "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
     "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
     "org.apache.spark.sql.catalyst.optimizer.NullPropagation" ::
+    "org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning" ::
     "org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
     "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
     "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
@@ -105,6 +109,7 @@ object RuleIdCollection {
     "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
     "org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" ::
     "org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
+    "org.apache.spark.sql.catalyst.optimizer.ReassignLambdaVariableID" ::
     "org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions" ::
     "org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator" ::
     "org.apache.spark.sql.catalyst.optimizer.ReorderJoin" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index b44847c..fb384cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -25,6 +25,7 @@ object TreePattern extends Enumeration {
   // Expression patterns (alphabetically ordered)
   val AND_OR: Value = Value(0)
   val ATTRIBUTE_REFERENCE: Value = Value
+  val APPEND_COLUMNS: Value = Value
   val BINARY_ARITHMETIC: Value = Value
   val BINARY_COMPARISON: Value = Value
   val CASE_WHEN: Value = Value
@@ -32,6 +33,7 @@
   val CONCAT: Value = Value
   val COUNT: Value = Value
   val CREATE_NAMED_STRUCT: Value = Value
+  val DESERIALIZE_TO_OBJECT: Value = Value
   val DYNAMIC_PRUNING_SUBQUERY: Value = Value
   val EXISTS_SUBQUERY = Value
   val EXPRESSION_WITH_RANDOM_SEED: Value = Value
@@ -41,12 +43,15 @@
   val IN_SUBQUERY: Value = Value
   val INSET: Value = Value
   val JSON_TO_STRUCT: Value = Value
+  val LAMBDA_VARIABLE: Value = Value
   val LIKE_FAMLIY: Value = Value
   val LIST_SUBQUERY: Value = Value
   val LITERAL: Value = Value
+  val MAP_OBJECTS: Value = Value
   val NOT: Value = Value
   val NULL_CHECK: Value = Value
   val NULL_LITERAL: Value = Value
+  val SERIALIZE_FROM_OBJECT: Value = Value
   val OUTER_REFERENCE: Value = Value
   val PLAN_EXPRESSION: Value = Value
   val SCALAR_SUBQUERY: Value = Value
@@ -66,5 +71,7 @@ object TreePattern extends Enumeration {
   val LOCAL_RELATION: Value = Value
   val NATURAL_LIKE_JOIN: Value = Value
   val OUTER_JOIN: Value = Value
+  val PROJECT: Value = Value
+  val TYPED_FILTER: Value = Value
   val WINDOW: Value = Value
 }
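Seen end to end, the migration recipe for a single rule is mechanical: declare `nodePatterns` on the operators and expressions the rule matches, register the rule's class name in RuleIdCollection, and swap `transform` for the pruning variant. Below is a hedged sketch of that shape, using only the calls visible in the diff above; `MyCombineAdjacentTypedFilters` and its rewrite are hypothetical and not part of this commit:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TypedFilter}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern._

// Hypothetical rule, for illustration only. Assumes its class name has been
// registered in RuleIdCollection (as this commit does for each migrated
// rule), so that `ruleId` can be resolved.
object MyCombineAdjacentTypedFilters extends Rule[LogicalPlan] {
  // Before this migration a rule would write `plan transform { ... }` and
  // visit every node on every invocation. With pruning, the partial
  // function is only attempted in subtrees whose pattern bits can match.
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
    _.containsPattern(TYPED_FILTER), ruleId) {
    // Keep only the inner of two filters that share the same function
    // object (a deliberately trivial rewrite; real rules go here).
    case TypedFilter(f1, _, _, _, inner @ TypedFilter(f2, _, _, _, _))
        if f1 eq f2 => inner
  }
}
```

The `ruleId` argument resolves through RuleIdCollection, which is why this commit also adds each migrated rule to that list; a rule not registered there should omit the argument, on the assumption that the parameter carries an unknown-rule default.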
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org