This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d8d604bc07b [SPARK-40599][SQL] Add multiTransform methods to TreeNode 
to generate alternatives
d8d604bc07b is described below

commit d8d604bc07bc3b8c98f73c4b10f93cb4eb7113be
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Tue Jan 17 20:58:37 2023 +0800

    [SPARK-40599][SQL] Add multiTransform methods to TreeNode to generate 
alternatives
    
    ### What changes were proposed in this pull request?
    This PR introduces the `TreeNode.multiTransformDown()` and `TreeNode.multiTransformDownWithPruning()` methods to be able to recursively transform a `TreeNode` (and so a tree) into multiple alternatives. These functions are particularly useful if we want to transform an expression with a projection in which subexpressions can be aliased with multiple different attributes.
    
    E.g. if we have a partitioning expression `HashPartitioning(a + b)` and we 
have a `Project` node that aliases `a` as `a1` and `a2` and `b` as `b1` and 
`b2` we can easily generate a stream of alternative transformations of the 
original partitioning:
    ```
    // This is a simplified example, some arguments are missing to make it concise
    val partitioning = HashPartitioning(Add(a, b))
    val aliases: Map[Expression, Seq[Attribute]] = ... // collect the alias map from project
    val s = partitioning.multiTransformDown {
      case e: Expression if aliases.contains(e.canonicalized) => aliases(e.canonicalized).toStream
    }
    s // Stream(HashPartitioning(Add(a1, b1)), HashPartitioning(Add(a2, b1)),
      //        HashPartitioning(Add(a1, b2)), HashPartitioning(Add(a2, b2)))
    ```
    
    The result of `multiTransform` is a lazy stream to be able to limit the 
number of alternatives generated at the caller side as needed.
    
    ### Why are the changes needed?
    `TreeNode.multiTransform()` is a useful helper method.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    New UTs are added.
    
    Closes #38034 from peter-toth/SPARK-40599-multitransform.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/catalyst/trees/TreeNode.scala | 128 +++++++++++++++++++++
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   | 104 +++++++++++++++++
 2 files changed, 232 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 9510aa4d9e7..dc64e5e2560 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -618,6 +618,134 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
     }
   }
 
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively applied to it and all
+   * of its children (pre-order).
+   *
+   * This is `multiTransformDownWithPruning` called with `AlwaysProcess.fn` as `cond` and with the
+   * default `UnknownRuleId`; see that method for how `rule` should produce and prune alternatives.
+   *
+   * @param rule a function used to generate alternatives for a node
+   * @return     the stream of alternatives
+   */
+  def multiTransformDown(
+      rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] =
+    multiTransformDownWithPruning(AlwaysProcess.fn)(rule)
+
+  /**
+   * Returns alternative copies of this node where `rule` has been recursively applied to it and all
+   * of its children (pre-order).
+   *
+   * As it is very easy to generate an enormous number of alternatives when the input tree is huge
+   * or when the rule returns many alternatives for many nodes, this function returns the
+   * alternatives as a lazy `Stream` to be able to limit the number of alternatives generated
+   * at the caller side as needed.
+   *
+   * To keep the original node (with its descendants still transformed recursively) as an
+   * alternative, the rule should either not apply to it or return it as one of the alternatives.
+   *
+   * The rule can return `Stream.empty` to indicate that the original node should be pruned. In
+   * this case every alternative that would contain the node is dropped, so if no alternative can
+   * avoid the node then this method returns an empty `Stream`.
+   *
+   * Please consider the following examples of `input.multiTransformDown(rule)`:
+   *
+   * We have an input expression:
+   *    `Add(a, b)`
+   *
+   * 1.
+   * We have a simple rule:
+   *   `a` => `Stream(1, 2)`
+   *   `b` => `Stream(10, 20)`
+   *   `Add(a, b)` => `Stream(11, 12, 21, 22)`
+   *
+   * The output is:
+   *   `Stream(11, 12, 21, 22)`
+   *
+   * 2.
+   * In the previous example if we want to generate alternatives of `a` and `b` too then we need to
+   * explicitly add the original `Add(a, b)` expression to the rule:
+   *   `a` => `Stream(1, 2)`
+   *   `b` => `Stream(10, 20)`
+   *   `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
+   *
+   * The output is:
+   *   `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
+   *
+   * @param rule   a function used to generate alternatives for a node
+   * @param cond   a Lambda expression to prune tree traversals. If `cond.apply` returns false
+   *               on a TreeNode T, skips processing T and its subtree; otherwise, processes
+   *               T and its subtree recursively.
+   * @param ruleId is a unique Id for `rule` to prune unnecessary tree traversals. When it is
+   *               UnknownRuleId, no pruning happens. Otherwise, if `rule` (with id `ruleId`)
+   *               has been marked as ineffective on a TreeNode T, skips processing T and its
+   *               subtree. This method marks T as such only if `rule` didn't apply to T and is
+   *               ineffective on, or pruned by `cond` from, every child of T. Do not pass it if
+   *               the rule is not purely functional and reads a varying initial state for
+   *               different invocations.
+   * @return       the stream of alternatives
+   */
+  def multiTransformDownWithPruning(
+      cond: TreePatternBits => Boolean,
+      ruleId: RuleId = UnknownRuleId
+    )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    if (!cond.apply(this) || isRuleIneffective(ruleId)) {
+      return Stream(this)
+    }
+
+    // We could return `Stream(this)` if the `rule` doesn't apply and handle both
+    // - the rule doesn't apply
+    // - and the rule returns a one element `Stream(originalNode)`
+    // cases together. But, unfortunately it doesn't seem like there is a way to match on a one
+    // element stream without eagerly computing the tail's head. This contradicts the purpose of
+    // only taking the necessary elements from the alternatives. I.e. the
+    // "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
+    // Please note that this behaviour has a downside as well: we can only mark the rule
+    // ineffective on the original node if the rule didn't match it.
+    var ruleApplied = true
+    val afterRules = CurrentOrigin.withOrigin(origin) {
+      rule.applyOrElse(this, (_: BaseType) => {
+        ruleApplied = false
+        Stream.empty
+      })
+    }
+
+    if (!ruleApplied) {
+      // If the rule was not applied then keep the original node; only its descendants can
+      // produce alternatives.
+      if (containsChild.isEmpty) {
+        markRuleAsIneffective(ruleId)
+        Stream(this)
+      } else {
+        // The children are traversed here (their alternatives stay lazy), which also settles
+        // whether each of them got marked as ineffective.
+        val childrenStreams = children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule))
+        // A marked node is skipped together with its whole subtree (see `ruleId` above), so mark
+        // this node only if the rule is ineffective on, or pruned by `cond` from, every child.
+        // Marking it just because the rule didn't match this node would make later traversals
+        // with the same `ruleId` (including revisits of this very node from other alternatives of
+        // its ancestors) skip descendants that the rule does transform.
+        if (children.forall(c => !cond.apply(c) || c.isRuleIneffective(ruleId))) {
+          markRuleAsIneffective(ruleId)
+        }
+        generateChildrenSeq(childrenStreams).map(withNewChildren)
+      }
+    } else if (afterRules.isEmpty) {
+      // If the rule returned with empty alternatives then prune
+      Stream.empty
+    } else {
+      // If the rule was applied then use the returned alternatives
+      val afterRulesStream = afterRules.map { afterRule =>
+        if (this fastEquals afterRule) {
+          this
+        } else {
+          afterRule.copyTagsFrom(this)
+          afterRule
+        }
+      }
+      afterRulesStream.flatMap { afterRule =>
+        if (afterRule.containsChild.nonEmpty) {
+          generateChildrenSeq(
+            afterRule.children.map(_.multiTransformDownWithPruning(cond, ruleId)(rule)))
+            .map(afterRule.withNewChildren)
+        } else {
+          Stream(afterRule)
+        }
+      }
+    }
+  }
+
+  /**
+   * Returns the lazy cartesian product of `childrenStreams`: each element picks one alternative
+   * from every stream, keeping the order of the streams. The alternatives of the first stream
+   * vary the fastest. For no streams the result is a single empty `Seq`, and if any of the
+   * streams is empty then so is the result.
+   */
+  private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): Stream[Seq[T]] = {
+    val withoutChildren: Stream[Seq[T]] = Stream(Seq.empty[T])
+    childrenStreams.foldRight(withoutChildren) { (headAlternatives, tailCombinations) =>
+      tailCombinations.flatMap { tail =>
+        headAlternatives.map(head => head +: tail)
+      }
+    }
+  }
+
   /**
    * Returns a copy of this node where `f` has been applied to all the nodes 
in `children`.
    */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 286d3dddae6..ac28917675e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -977,4 +977,108 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
       assert(origin.context.summary.isEmpty)
     }
   }
+
+  /**
+   * Returns a stream of `es` whose remainder, if it is ever forced, throws a
+   * `NoSuchElementException`. Used to verify that no more alternatives are requested than needed.
+   */
+  private def newErrorAfterStream(es: Expression*): Stream[Expression] = {
+    def failWhenForced: Stream[Expression] =
+      throw new NoSuchElementException("Stream should not return more elements")
+    es.toStream.append(failWhenForced)
+  }
+
+  test("multiTransformDown generates all alternatives") {
+    def literals(values: Int*): Seq[Literal] = values.map(Literal(_))
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => literals(1, 2, 3).toStream
+      case StringLiteral("b") => literals(10, 20, 30).toStream
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => literals(100, 200, 300).toStream
+    }
+    // The alternatives of the first (leftmost) child vary the fastest
+    val expected = literals(100, 200, 300).flatMap { cd =>
+      literals(10, 20, 30).flatMap(b => literals(1, 2, 3).map(a => Add(Add(a, b), cd)))
+    }
+    assert(transformed === expected)
+  }
+
+  test("multiTransformDown is lazy") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => newErrorAfterStream(Literal(10))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
+    }
+    val expected = for {
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, Literal(10)), Literal(100))
+    // We don't access alternatives for `b` after 10 and for `c` after 100
+    assert(transformed.take(3) === expected)
+    intercept[NoSuchElementException] {
+      transformed.take(3 + 1).toList
+    }
+
+    val transformed2 = e.multiTransformDown {
+      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => newErrorAfterStream(Literal(100))
+    }
+    val expected2 = for {
+      b <- Seq(Literal(10), Literal(20), Literal(30))
+      a <- Seq(Literal(1), Literal(2), Literal(3))
+    } yield Add(Add(a, b), Literal(100))
+    // We don't access alternatives for `c` after 100
+    assert(transformed2.take(3 * 3) === expected2)
+    // The next alternative needs `c` after 100. This must force `transformed2`: forcing
+    // `transformed` again would pass trivially and wouldn't test the laziness of `transformed2`.
+    intercept[NoSuchElementException] {
+      transformed2.take(3 * 3 + 1).toList
+    }
+  }
+
+  test("multiTransformDown rule return this") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case original @ StringLiteral("a") => Stream(Literal(1), Literal(2), original)
+      case original @ StringLiteral("b") => Stream(Literal(10), Literal(20), original)
+      case original @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
+        Stream(Literal(100), Literal(200), original)
+    }
+    // The original nodes returned by the rule show up as the last alternatives of their positions
+    val aAlternatives = Seq(Literal(1), Literal(2), Literal("a"))
+    val bAlternatives = Seq(Literal(10), Literal(20), Literal("b"))
+    val cdAlternatives = Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
+    val expected = cdAlternatives.flatMap { cd =>
+      bAlternatives.flatMap(b => aAlternatives.map(a => Add(Add(a, b), cd)))
+    }
+    assert(transformed == expected)
+  }
+
+  test("multiTransformDown doesn't stop generating alternatives of descendants when non-leaf is " +
+    "transformed and itself is in the alternatives") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    val transformed = e.multiTransformDown {
+      case ab @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
+        Stream(Literal(11), Literal(12), Literal(21), Literal(22), ab)
+      case StringLiteral("a") => Stream(Literal(1), Literal(2))
+      case StringLiteral("b") => Stream(Literal(10), Literal(20))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream(Literal(100), Literal(200))
+    }
+    // Returning `Add(a, b)` itself means that the alternatives of `a` and `b` are generated too
+    val abFromChildren = Seq(Literal(10), Literal(20)).flatMap { b =>
+      Seq(Literal(1), Literal(2)).map(a => Add(a, b))
+    }
+    val abAlternatives = Seq(Literal(11), Literal(12), Literal(21), Literal(22)) ++ abFromChildren
+    val expected = Seq(Literal(100), Literal(200)).flatMap { cd =>
+      abAlternatives.map(ab => Add(ab, cd))
+    }
+    assert(transformed == expected)
+  }
+
+  test("multiTransformDown can prune") {
+    val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), Literal("d")))
+    // Pruning either a leaf or a non-leaf node drops every alternative of `e`, as none of them can
+    // avoid the pruned node
+    val pruneLeaf: PartialFunction[Expression, Stream[Expression]] = {
+      case StringLiteral("a") => Stream.empty
+    }
+    val pruneNonLeaf: PartialFunction[Expression, Stream[Expression]] = {
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
+    }
+    Seq(pruneLeaf, pruneNonLeaf).foreach { rule =>
+      assert(e.multiTransformDown(rule).isEmpty)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to