peter-toth commented on code in PR #38034:
URL: https://github.com/apache/spark/pull/38034#discussion_r1070234296
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala:
##########
@@ -618,6 +618,165 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
}
}
+ /**
+ * Returns alternative copies of this node where `rule` has been recursively applied to the tree.
+ *
+ * Users should not expect a specific directionality. If a specific directionality is needed,
+ * multiTransformDown or multiTransformUp should be used.
+ *
+ * @param rule a function used to generate transformed alternatives for a node
+ * @return the stream of alternatives
+ */
+ def multiTransform(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
+ multiTransformDown(rule)
+ }
+
+ /**
+ * Returns alternative copies of this node where `rule` has been recursively applied to the tree.
+ *
+ * Users should not expect a specific directionality. If a specific directionality is needed,
+ * multiTransformDownWithPruning or multiTransformUpWithPruning should be used.
+ *
+ * @param rule a function used to generate transformed alternatives for a node
+ * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false
+ * on a TreeNode T, skips processing T and its subtree; otherwise, processes
+ * T and its subtree recursively.
+ * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is
+ * UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`)
+ * has been marked as in effective on a TreeNode T, skips processing T and its
+ * subtree. Do not pass it if the rule is not purely functional and reads a
+ * varying initial state for different invocations.
+ * @return the stream of alternatives
+ */
+ def multiTransformWithPruning(
+ cond: TreePatternBits => Boolean,
+ ruleId: RuleId = UnknownRuleId
+ )(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
+ multiTransformDownWithPruning(cond, ruleId)(rule).map(_._1)
+ }
+
+ /**
+ * Returns alternative copies of this node where `rule` has been recursively applied to it and all
+ * of its children (pre-order).
+ *
+ * @param rule the function used to generate transformed alternatives for a node
+ * @return the stream of alternatives
+ */
+ def multiTransformDown(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
+ multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule).map(_._1)
+ }
+
+ /**
+ * Returns alternative copies of this node where `rule` has been recursively applied to it and all
+ * of its children (pre-order).
+ *
+ * As it is very easy to generate enormous number of alternatives when the input tree is huge or
+ * when the rule returns large number of alternatives, this function returns the alternatives as a
+ * lazy `Stream` to be able to limit the number of alternatives generated at the caller side as
+ * needed.
+ *
+ * To indicate that the original node without any transformation is a valid alternative the rule
+ * can either:
+ * - not apply or
+ * - return an empty `Seq` or
+ * - a `Seq` that contains a node that is equal to the original node.
+ *
+ * Please note that this function always consider the original node as a valid alternative (even
+ * if the original node is not included in the returned `Seq`) if the rule can transform any of
+ * the descendants of the node. E.g. consider a simple expression:
+ * `Add(a, b)`
+ * and a rule that returns:
+ * `Seq(1, 2)` for `a` and
+ * `Seq(10, 20)` for `b` and
+ * `Seq(11, 12, 21, 22)` for `Add(a, b)` (note that the original `Add(a, b)` is not returned)
+ * then the result of `multiTransform` is:
+ * `Seq(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`.
+ * This feature makes the usage of `multiTransform` easier as a non-leaf transforming rule doesn't
+ * need to take into account that it can transform a descendant node of the non-leaf node as well
+ * and so it doesn't need return the non-leaf node itself in the list of alternatives to not stop
+ * generating alternatives.
+ *
+ * @param rule a function used to generate transformed alternatives for a node
+ * @param cond a Lambda expression to prune tree traversals. If `cond.apply` returns false
+ * on a TreeNode T, skips processing T and its subtree; otherwise, processes
+ * T and its subtree recursively.
+ * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is
+ * UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`)
+ * has been marked as in effective on a TreeNode T, skips processing T and its
+ * subtree. Do not pass it if the rule is not purely functional and reads a
+ * varying initial state for different invocations.
+ * @return the stream of alternatives with a flag if any transformation was done
Review Comment:
This flag is internal; `multiTransform` and `multiTransformWithPruning` don't
even return it. `multiTransformDownWithPruning` should probably hide it as
well by delegating to a private helper.
The flag is used internally when a non-leaf node is transformed to some
alternatives but those alternatives don't contain the original node itself.
E.g. please see a bit above, where `Add(a, b) -> Seq(11, 12, 21, 22)` and the
`Seq` doesn't contain `Add(a, b)` itself.
In this case we still need to consider the original `Add(a, b)`, traverse
down, and consider alternatives that might transform `a` or `b`.
This is done by adding `Add(a, b)` with a `childrenTransformRequired=true`
flag to the `afterRulesStream` stream.
The returned flag you are asking about (we can call it
`childrenTransformed`) is kind of the pair of the previous flag.
Once the transformation of the children is done, that flag is "propagated up",
and based on both flags we can decide whether we need to add `Add(a, b)` (with
its children transformed) to the valid alternatives
(`!childrenTransformRequired || childrenTransformed`). We do this in this
collect:
https://github.com/apache/spark/blob/cc73d00f3e608076feebc65737e53e6454845bd4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L762-L768
Please note that there is another flag, `transformed`, in the
`afterRulesStream`. It is pretty similar to the negated
`childrenTransformRequired`, but the two are not always the negation of each
other, so that nodes where the rule didn't apply can be handled.
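
To make the interplay of the three flags easier to follow, here is a minimal,
self-contained sketch on a toy tree. It is not the code in the PR: `Expr`,
`Leaf`, `Add`, `withFlag` and `childAlternatives` are hypothetical names, it
uses Scala 2.13's `LazyList` where `TreeNode` uses `Stream`, there is no
pruning by `cond` or `ruleId`, and descending into the children of
rule-produced alternatives is an assumption of the sketch.

```scala
object MultiTransformSketch {
  // Toy tree, hypothetical stand-in for Catalyst's TreeNode.
  sealed trait Expr
  final case class Leaf(name: String) extends Expr
  final case class Add(left: Expr, right: Expr) extends Expr

  type Rule = PartialFunction[Expr, Seq[Expr]]

  // Public entry point: the internal flag is dropped before returning.
  def multiTransformDown(node: Expr)(rule: Rule): LazyList[Expr] =
    withFlag(node, rule).map(_._1)

  // Private helper: each alternative is paired with a flag that tells
  // whether anything in that subtree was transformed.
  private def withFlag(node: Expr, rule: Rule): LazyList[(Expr, Boolean)] = {
    val fromRule: Seq[Expr] = rule.applyOrElse(node, (_: Expr) => Seq.empty[Expr])

    // Triples of (candidate, transformed, childrenTransformRequired).
    val afterRules: LazyList[(Expr, Boolean, Boolean)] =
      if (fromRule.isEmpty) {
        // Rule did not apply (or returned nothing): keep the node, nothing required.
        LazyList((node, false, false))
      } else {
        val alts = fromRule.to(LazyList).map(alt => (alt, alt != node, false))
        // Original node missing from the alternatives: still traverse it, but
        // keep it only if one of its children gets transformed.
        if (fromRule.contains(node)) alts else alts ++ LazyList((node, false, true))
      }

    afterRules.flatMap { case (candidate, transformed, required) =>
      childAlternatives(candidate, rule).collect {
        case (alt, childrenTransformed) if !required || childrenTransformed =>
          (alt, transformed || childrenTransformed)
      }
    }
  }

  // Cartesian product of the children's alternatives; the left child varies fastest.
  private def childAlternatives(node: Expr, rule: Rule): LazyList[(Expr, Boolean)] =
    node match {
      case _: Leaf => LazyList((node, false))
      case Add(left, right) =>
        val lefts = withFlag(left, rule)
        for {
          r <- withFlag(right, rule)
          l <- lefts
        } yield (Add(l._1, r._1), l._2 || r._2)
    }

  // The rule from the Scaladoc example above.
  val exampleRule: Rule = {
    case Leaf("a") => Seq(Leaf("1"), Leaf("2"))
    case Leaf("b") => Seq(Leaf("10"), Leaf("20"))
    case Add(Leaf("a"), Leaf("b")) =>
      Seq(Leaf("11"), Leaf("12"), Leaf("21"), Leaf("22"))
  }

  def demo: List[Expr] =
    multiTransformDown(Add(Leaf("a"), Leaf("b")))(exampleRule).toList

  def main(args: Array[String]): Unit = demo.foreach(println)
}
```

With this rule, `demo` yields the four leaves `11, 12, 21, 22` followed by
`Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20)`. The untransformed
`Add(a, b)` is dropped because its `childrenTransformRequired` flag is set
and it only survives in combinations where a child was actually transformed.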
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]