Repository: spark Updated Branches: refs/heads/master f41c0a93f -> 18b75d465
[SPARK-22719][SQL] Refactor ConstantPropagation ## What changes were proposed in this pull request? The current time complexity of ConstantPropagation is O(n^2), which can be slow when the query is complex. Refactor the implementation to run in O(n) time, and add some pruning to avoid traversing the whole `Condition`. ## How was this patch tested? Unit tests, plus a simple benchmark test in ConstantPropagationSuite: ``` val condition = (1 to 500).map{_ => Rand(0) === Rand(0)}.reduce(And) val query = testRelation .select(columnA) .where(condition) val start = System.currentTimeMillis() (1 to 40).foreach { _ => Optimize.execute(query.analyze) } val end = System.currentTimeMillis() println(end - start) ``` Run time before changes: 18989 ms (474 ms per loop) Run time after changes: 1275 ms (32 ms per loop) Author: Wang Gengliang <ltn...@gmail.com> Closes #19912 from gengliangwang/ConstantPropagation. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/18b75d46 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/18b75d46 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/18b75d46 Branch: refs/heads/master Commit: 18b75d465b7563de926c5690094086a72a75c09f Parents: f41c0a9 Author: Wang Gengliang <ltn...@gmail.com> Authored: Thu Dec 7 10:24:49 2017 -0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Thu Dec 7 10:24:49 2017 -0800 ---------------------------------------------------------------------- .../sql/catalyst/optimizer/expressions.scala | 106 +++++++++++++------ 1 file changed, 73 insertions(+), 33 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/18b75d46/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 785e815..6305b6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -64,49 +64,89 @@ object ConstantFolding extends Rule[LogicalPlan] { * }}} * * Approach used: - * - Start from AND operator as the root - * - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they - * don't have a `NOT` or `OR` operator in them * - Populate a mapping of attribute => constant value by looking at all the equals predicates * - Using this mapping, replace occurrence of the attributes with the corresponding constant values * in the AND node. */ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { - private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find { - case _: Not | _: Or => true - case _ => false - }.isDefined - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f: Filter => f transformExpressionsUp { - case and: And => - val conjunctivePredicates = - splitConjunctivePredicates(and) - .filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe]) - .filterNot(expr => containsNonConjunctionPredicates(expr)) - - val equalityPredicates = conjunctivePredicates.collect { - case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e) - case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e) - case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e) - case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e) - } + case f: Filter => + val (newCondition, _) = traverse(f.condition, replaceChildren = true) + if (newCondition.isDefined) { + 
f.copy(condition = newCondition.get) + } else { + f + } + } - val constantsMap = AttributeMap(equalityPredicates.map(_._1)) - val predicates = equalityPredicates.map(_._2).toSet + type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] - def replaceConstants(expression: Expression) = expression transform { - case a: AttributeReference => - constantsMap.get(a) match { - case Some(literal) => literal - case None => a - } + /** + * Traverse a condition as a tree and replace attributes with constant values. + * - On matching [[And]], recursively traverse each children and get propagated mappings. + * If the current node is not child of another [[And]], replace all occurrences of the + * attributes with the corresponding constant values. + * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping + * of attribute => constant. + * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping. + * - Otherwise, stop traversal and propagate empty mapping. + * @param condition condition to be traversed + * @param replaceChildren whether to replace attributes with constant values in children + * @return A tuple including: + * 1. Option[Expression]: optional changed condition after traversal + * 2. 
EqualityPredicates: propagated mapping of attribute => constant + */ + private def traverse(condition: Expression, replaceChildren: Boolean) + : (Option[Expression], EqualityPredicates) = + condition match { + case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e))) + case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e))) + case e @ EqualNullSafe(left: AttributeReference, right: Literal) => + (None, Seq(((left, right), e))) + case e @ EqualNullSafe(left: Literal, right: AttributeReference) => + (None, Seq(((right, left), e))) + case a: And => + val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false) + val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false) + val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight + val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) { + Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates), + replaceConstants(newRight.getOrElse(a.right), equalityPredicates))) + } else { + if (newLeft.isDefined || newRight.isDefined) { + Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right))) + } else { + None + } } - - and transform { - case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e) - case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e) + (newSelf, equalityPredicates) + case o: Or => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newLeft, _) = traverse(o.left, replaceChildren = true) + val (newRight, _) = traverse(o.right, replaceChildren = true) + val newSelf = if (newLeft.isDefined || newRight.isDefined) { + Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right)))) + } else { + None } + (newSelf, Seq.empty) + case n: Not => + // Ignore the EqualityPredicates from children since they are only propagated through And. 
+ val (newChild, _) = traverse(n.child, replaceChildren = true) + (newChild.map(Not), Seq.empty) + case _ => (None, Seq.empty) + } + + private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) + : Expression = { + val constantsMap = AttributeMap(equalityPredicates.map(_._1)) + val predicates = equalityPredicates.map(_._2).toSet + def replaceConstants0(expression: Expression) = expression transform { + case a: AttributeReference => constantsMap.getOrElse(a, a) + } + condition transform { + case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) + case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org