ulysses-you commented on code in PR #38034:
URL: https://github.com/apache/spark/pull/38034#discussion_r1071674624


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala:
##########
@@ -618,6 +618,212 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product with Tre
     }
   }
 
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively 
applied to the tree.
+   *
+   * Users should not expect a specific directionality. If a specific 
directionality is needed,
+   * multiTransformDownWithPruning or multiTransformUpWithPruning should be 
used.
+   *
+   * @param rule a function used to generate transformed alternatives for a 
node
+   * @return     the stream of alternatives
+   */
+  def multiTransformDown(
+      rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
+  }
+
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively 
applied to the tree.
+   *
+   * Users should not expect a specific directionality. If a specific 
directionality is needed,
+   * multiTransformDownWithPruning or multiTransformUpWithPruning should be 
used.
+   *
+   * @param rule a function used to generate transformed alternatives for a 
node and the
+   *             `autoContinue` flag
+   * @return the stream of alternatives
+   */
+  def multiTransformDownWithContinuation(
+      rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): 
Stream[BaseType] = {
+    multiTransformDownWithContinuationAndPruning(AlwaysProcess.fn, 
UnknownRuleId)(rule)
+  }
+
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively 
applied to the tree.
+   *
+   * Users should not expect a specific directionality. If a specific 
directionality is needed,
+   * multiTransformDownWithPruning or multiTransformUpWithPruning should be 
used.
+   *
+   * @param rule   a function used to generate transformed alternatives for a 
node
+   * @param cond   a Lambda expression to prune tree traversals. If 
`cond.apply` returns false
+   *               on a TreeNode T, skips processing T and its subtree; 
otherwise, processes
+   *               T and its subtree recursively.
+   * @param ruleId is a unique Id for `rule` to prune unnecessary tree 
traversals. When it is
+   *               UnknownRuleId, no pruning happens. Otherwise, if `rule` 
(with id `ruleId`)
+   *               has been marked as in effective on a TreeNode T, skips 
processing T and its
+   *               subtree. Do not pass it if the rule is not purely 
functional and reads a
+   *               varying initial state for different invocations.
+   * @return       the stream of alternatives
+   */
+  def multiTransformDownWithPruning(
+      cond: TreePatternBits => Boolean,
+      ruleId: RuleId = UnknownRuleId
+    )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    multiTransformDownWithContinuationAndPruning(cond, ruleId)(rule.andThen(_ 
-> false))
+  }
+
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively 
applied to it and all
+   * of its children (pre-order).
+   *
+   * As it is very easy to generate enormous number of alternatives when the 
input tree is huge or
+   * when the rule returns large number of alternatives, this function returns 
the alternatives as a
+   * lazy `Stream` to be able to limit the number of alternatives generated at 
the caller side as
+   * needed.
+   *
+   * The rule should not apply to indicate that the original node without any 
transformation is a
+   * valid alternative.
+   *
+   * The rule can return `Stream.empty` to indicate that the original node 
should be pruned. In this
+   * case `multiTransform` returns an empty `Stream`.
+   *
+   * Please consider the following examples of `input.multiTransform(rule)`:
+   *
+   * We have an input expression:
+   *    `Add(a, b)`
+   *
+   * 1.
+   * We have a simple rule:
+   *   `a` => `Stream(1, 2)`
+   *   `b` => `Stream(10, 20)`
+   *   `Add(a, b)` => `Stream(11, 12, 21, 22)`
+   *
+   * The output is:
+   *   `Stream(11, 12, 21, 22)`
+   *
+   * 2.
+   * In the previous example if we want to generate alternatives of `a` and 
`b` too then we need to
+   * explicitly add the original `Add(a, b)` expression to the rule:
+   *   `a` => `Stream(1, 2)`
+   *   `b` => `Stream(10, 20)`
+   *   `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
+   *
+   * The output is:
+   *   `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
+   *
+   * 3.
+   * It is not always easy to determine if we will do any child expression 
mapping but we can enable
+   * the `autoContinue` flag to get the same result:
+   *   `a` => `(Stream(1, 2), false)`
+   *   `b` => `(Stream(10, 20), false)`
+   *   `Add(a, b)` => `(Stream(11, 12, 21, 22), true)` (Note the `true` flag 
and the missing
+   *                                                    `Add(a, b)`)
+   * The output is the same as in 2.:
+   *   `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
+   *
+   * This feature makes the usage of `multiTransform` easier as a non-leaf 
transforming rule doesn't
+   * need to take into account that it can transform a descendant node of the 
non-leaf node as well
+   * and so it doesn't need return the non-leaf node itself in the list of 
alternatives to not stop
+   * generating alternatives.
+   *
+   * @param rule   a function used to generate transformed alternatives for a 
node and the
+   *               `autoContinue` flag
+   * @param cond   a Lambda expression to prune tree traversals. If 
`cond.apply` returns false
+   *               on a TreeNode T, skips processing T and its subtree; 
otherwise, processes
+   *               T and its subtree recursively.
+   * @param ruleId is a unique Id for `rule` to prune unnecessary tree 
traversals. When it is
+   *               UnknownRuleId, no pruning happens. Otherwise, if `rule` 
(with id `ruleId`)
+   *               has been marked as in effective on a TreeNode T, skips 
processing T and its
+   *               subtree. Do not pass it if the rule is not purely 
functional and reads a
+   *               varying initial state for different invocations.
+   * @return       the stream of alternatives
+   */
+  def multiTransformDownWithContinuationAndPruning(
+      cond: TreePatternBits => Boolean,
+      ruleId: RuleId = UnknownRuleId
+    )(rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): 
Stream[BaseType] = {
+    multiTransformDownHelper(cond, ruleId)(rule).map(_._1)
+  }
+
+  private def multiTransformDownHelper(
+      cond: TreePatternBits => Boolean,
+      ruleId: RuleId = UnknownRuleId
+    )(rule: PartialFunction[BaseType, (Stream[BaseType], Boolean)]): 
Stream[(BaseType, Boolean)] = {
+    if (!cond.apply(this) || isRuleIneffective(ruleId)) {
+      return Stream(this -> false)
+    }
+
+    var ruleApplied = true
+    val (afterRules, autoContinue) = CurrentOrigin.withOrigin(origin) {
+      rule.applyOrElse(this, (_: BaseType) => {
+        ruleApplied = false
+        Stream.empty -> false
+      })
+    }
+    // A stream of a tuple that contains:
+    // - a node that is either the transformed alternative of the current node 
or the current node,
+    // - a boolean flag if the node was actually transformed,
+    // - a boolean flag if a node's children needs to be transformed to add 
the node to the valid
+    // alternatives
+    val afterRulesStream = if (afterRules.isEmpty) {
+      if (ruleApplied) {
+        // If the rule returned with empty alternatives then prune
+        Stream.empty
+      } else {
+        // If the rule was not applied then keep the original node
+        Stream((this, false, false))
+      }
+    } else {
+        // If the rule was applied then use the returned alternatives
+        // The alternatives can include the current node and we need to keep 
track of that
+        var foundEqual = false
+        afterRules.map { afterRule =>
+          (if (this fastEquals afterRule) {
+            foundEqual = true
+            this
+          } else {
+            afterRule.copyTagsFrom(this)
+            afterRule
+          }, true, false)
+        }.append(
+          // If autoContinue is enabled and the current node is not a leaf 
node and the alternatives
+          // returned by the rule doesn't contain the current node then we 
need to add the current
+          // node to the stream, but require any of its child nodes to be 
transformed to keep it as
+          // a valid alternative
+          if (autoContinue && containsChild.nonEmpty && !foundEqual) {
+            Stream((this, false, true))
+          } else {
+            Stream.empty
+          }
+        )
+    }
+
+    def generateChildrenSeq(children: Seq[BaseType]): Stream[(Seq[BaseType], 
Boolean)] = {
+      children.foldRight(Stream((Seq.empty[BaseType], false)))((child, 
childrenSeqStream) =>
+        for {
+          (childrenSeq, childrenSeqChanged) <- childrenSeqStream
+          (newChild, childChanged) <- child.multiTransformDownHelper(cond, 
ruleId)(rule)
+        } yield (newChild +: childrenSeq) -> (childChanged || 
childrenSeqChanged)
+      )
+    }
+
+    afterRulesStream.flatMap { case (afterRule, transformed, 
childrenTransformRequired) =>
+      if (afterRule.containsChild.nonEmpty) {
+        generateChildrenSeq(afterRule.children).collect {
+          case (newChildren, childrenTransformed)
+            if !childrenTransformRequired || childrenTransformed =>
+            afterRule.withNewChildren(newChildren) -> (transformed || 
childrenTransformed)
+        }
+      } else {
+        Stream(afterRule -> transformed)
+      }.map { rewritten_plan =>
+        if (this eq rewritten_plan) {
+          markRuleAsIneffective(ruleId)

Review Comment:
   If my understanding is correct, the only way to mark this rule as 
ineffective is for it to return a stream with a single alternative that is 
`eq` to `this`? If so, why do we use `map` here?
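
   A minimal sketch of the alternative this question hints at: check the 
whole result once instead of testing every element inside `map`. `Node` and 
`isUnchanged` are hypothetical stand-ins for illustration only, not Spark 
code, and the sketch assumes "ineffective" means exactly one alternative that 
is the same instance as the original node.

```scala
object IneffectiveCheckSketch {
  // Hypothetical stand-in for a TreeNode; not Spark's class.
  final class Node(val name: String)

  // True only when `alternatives` holds exactly one element and that element
  // is the same instance (reference equality) as `original`.
  // Assumption: this is the condition under which a rule would be marked
  // ineffective for `original`.
  def isUnchanged(original: Node, alternatives: Stream[Node]): Boolean =
    alternatives.nonEmpty &&
      alternatives.tail.isEmpty && // forces evaluation of at most a second element
      (alternatives.head eq original)
}
```

   Note that checking `tail.isEmpty` evaluates up to one more alternative 
than the caller may have asked for, so a check like this gives up a little 
of the stream's laziness.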



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

