Github user gatorsmile commented on a diff in the pull request: https://github.com/apache/spark/pull/19201#discussion_r138393814 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala --- @@ -106,91 +106,48 @@ trait QueryPlanConstraints { self: LogicalPlan => * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an * additional constraint of the form `b = 5`. - * - * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)` - * as they are often useless and can lead to a non-converging set of constraints. */ private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { - val constraintClasses = generateEquivalentConstraintClasses(constraints) - + val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints) var inferredConstraints = Set.empty[Expression] - constraints.foreach { + aliasedConstraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq - inferredConstraints ++= candidateConstraints.map(_ transform { - case a: Attribute if a.semanticEquals(l) && - !isRecursiveDeduction(r, constraintClasses) => r - }) - inferredConstraints ++= candidateConstraints.map(_ transform { - case a: Attribute if a.semanticEquals(r) && - !isRecursiveDeduction(l, constraintClasses) => l - }) + val candidateConstraints = aliasedConstraints - eq + inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) + inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) case _ => // No inference } inferredConstraints -- constraints } /** - * Generate a sequence of expression sets from constraints, where each set stores an equivalence - * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following - * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal - * to an selected attribute. + * Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints. + * Thus non-converging inference can be prevented. + * E.g. `a = f(a, b)`, `a = f(b, c) && c = g(a, b)`. + * Also, the size of constraints is reduced without losing any information. + * When the inferred filters are pushed down the operators that generate the alias, + * the alias names used in filters are replaced by the aliased expressions. */ - private def generateEquivalentConstraintClasses( - constraints: Set[Expression]): Seq[Set[Expression]] = { - var constraintClasses = Seq.empty[Set[Expression]] - constraints.foreach { - case eq @ EqualTo(l: Attribute, r: Attribute) => - // Transform [[Alias]] to its child. - val left = aliasMap.getOrElse(l, l) - val right = aliasMap.getOrElse(r, r) - // Get the expression set for an equivalence constraint class. - val leftConstraintClass = getConstraintClass(left, constraintClasses) - val rightConstraintClass = getConstraintClass(right, constraintClasses) - if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) { - // Combine the two sets. - constraintClasses = constraintClasses - .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+ - (leftConstraintClass ++ rightConstraintClass) - } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty - // Update equivalence class of `left` expression. - constraintClasses = constraintClasses - .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right) - } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty - // Update equivalence class of `right` expression. - constraintClasses = constraintClasses - .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left) - } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty - // Create new equivalence constraint class since neither expression presents - // in any classes. - constraintClasses = constraintClasses :+ Set(left, right) - } - case _ => // Skip + private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression]) + : Set[Expression] = { + val attributesInEqualTo = constraints.flatMap { + case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil + case _ => Nil } - - constraintClasses - } - - /** - * Get all expressions equivalent to the selected expression. - */ - private def getConstraintClass( - expr: Expression, - constraintClasses: Seq[Set[Expression]]): Set[Expression] = - constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression]) - - /** - * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it - * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function. - * Here we first get all expressions equal to `attr` and then check whether at least one of them - * is a child of the referenced expression. - */ - private def isRecursiveDeduction( - attr: Attribute, - constraintClasses: Seq[Set[Expression]]): Boolean = { - val expr = aliasMap.getOrElse(attr, attr) - getConstraintClass(expr, constraintClasses).exists { e => - expr.children.exists(_.semanticEquals(e)) + var aliasedConstraints = constraints + attributesInEqualTo.foreach { a => + if (aliasMap.contains(a)) { + val child = aliasMap.get(a).get + aliasedConstraints = replaceConstraints(aliasedConstraints, child, a) + } } + aliasedConstraints } + + private def replaceConstraints( + constraints: Set[Expression], + source: Expression, + destination: Attribute): Set[Expression] = constraints.map(_ transform { --- End diff -- please use four line indents. https://github.com/databricks/scala-style-guide#spacing-and-indentation
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org