This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c953610deaf [SPARK-40599][SQL] Relax multiTransform rule type to allow 
alternatives to be any kinds of Seq
c953610deaf is described below

commit c953610deafda769feb85fbb936591ffc4448f8e
Author: Peter Toth <[email protected]>
AuthorDate: Thu Jan 19 23:54:06 2023 +0800

    [SPARK-40599][SQL] Relax multiTransform rule type to allow alternatives to 
be any kinds of Seq
    
    ### What changes were proposed in this pull request?
    This is a follow-up PR to https://github.com/apache/spark/pull/38034. It 
relaxes `multiTransformDown()`'s `rule` parameter type to accept any kinds of 
`Seq` and makes the `MultiTransform.generateCartesianProduct()` helper public.
    
    ### Why are the changes needed?
    API improvement.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing UTs.
    
    Closes #39652 from peter-toth/SPARK-40599-multitransform-follow-up.
    
    Authored-by: Peter Toth <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../apache/spark/sql/catalyst/trees/TreeNode.scala | 70 +++++++++++++---------
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   | 31 +++++-----
 2 files changed, 57 insertions(+), 44 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index dc64e5e2560..c8df2086a72 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -626,7 +626,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product with Tre
    * @return     the stream of alternatives
    */
   def multiTransformDown(
-      rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+      rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
     multiTransformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
   }
 
@@ -639,10 +639,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product with Tre
    * as a lazy `Stream` to be able to limit the number of alternatives 
generated at the caller side
    * as needed.
    *
-   * The rule should not apply or can return a one element stream of original 
node to indicate that
+   * The purpose of this function to access the returned alternatives by the 
rule only if they are
+   * needed so the rule can return a `Stream` whose elements are also lazily 
calculated.
+   * E.g. `multiTransform*` calls can be nested with the help of
+   * `MultiTransform.generateCartesianProduct()`.
+   *
+   * The rule should not apply or can return a one element `Seq` of original 
node to indicate that
    * the original node without any transformation is a valid alternative.
    *
-   * The rule can return `Stream.empty` to indicate that the original node 
should be pruned. In this
+   * The rule can return `Seq.empty` to indicate that the original node should 
be pruned. In this
    * case `multiTransform()` returns an empty `Stream`.
    *
    * Please consider the following examples of 
`input.multiTransformDown(rule)`:
@@ -652,9 +657,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product with Tre
    *
    * 1.
    * We have a simple rule:
-   *   `a` => `Stream(1, 2)`
-   *   `b` => `Stream(10, 20)`
-   *   `Add(a, b)` => `Stream(11, 12, 21, 22)`
+   *   `a` => `Seq(1, 2)`
+   *   `b` => `Seq(10, 20)`
+   *   `Add(a, b)` => `Seq(11, 12, 21, 22)`
    *
    * The output is:
    *   `Stream(11, 12, 21, 22)`
@@ -662,9 +667,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product with Tre
    * 2.
    * In the previous example if we want to generate alternatives of `a` and 
`b` too then we need to
    * explicitly add the original `Add(a, b)` expression to the rule:
-   *   `a` => `Stream(1, 2)`
-   *   `b` => `Stream(10, 20)`
-   *   `Add(a, b)` => `Stream(11, 12, 21, 22, Add(a, b))`
+   *   `a` => `Seq(1, 2)`
+   *   `b` => `Seq(10, 20)`
+   *   `Add(a, b)` => `Seq(11, 12, 21, 22, Add(a, b))`
    *
    * The output is:
    *   `Stream(11, 12, 21, 22, Add(1, 10), Add(2, 10), Add(1, 20), Add(2, 20))`
@@ -683,25 +688,25 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product with Tre
   def multiTransformDownWithPruning(
       cond: TreePatternBits => Boolean,
       ruleId: RuleId = UnknownRuleId
-    )(rule: PartialFunction[BaseType, Stream[BaseType]]): Stream[BaseType] = {
+    )(rule: PartialFunction[BaseType, Seq[BaseType]]): Stream[BaseType] = {
     if (!cond.apply(this) || isRuleIneffective(ruleId)) {
       return Stream(this)
     }
 
-    // We could return `Stream(this)` if the `rule` doesn't apply and handle 
both
+    // We could return `Seq(this)` if the `rule` doesn't apply and handle both
     // - the doesn't apply
-    // - and the rule returns a one element `Stream(originalNode)`
-    // cases together. But, unfortunately it doesn't seem like there is a way 
to match on a one
-    // element stream without eagerly computing the tail head. So this 
contradicts with the purpose
-    // of only taking the necessary elements from the alternatives. I.e. the
-    // "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
+    // - and the rule returns a one element `Seq(originalNode)`
+    // cases together. The returned `Seq` can be a `Stream` and unfortunately 
it doesn't seem like
+    // there is a way to match on a one element stream without eagerly 
computing the tail's head.
+    // This contradicts with the purpose of only taking the necessary elements 
from the
+    // alternatives. I.e. the "multiTransformDown is lazy" test case in 
`TreeNodeSuite` would fail.
     // Please note that this behaviour has a downside as well that we can only 
mark the rule on the
     // original node ineffective if the rule didn't match.
     var ruleApplied = true
     val afterRules = CurrentOrigin.withOrigin(origin) {
       rule.applyOrElse(this, (_: BaseType) => {
         ruleApplied = false
-        Stream.empty
+        Seq.empty
       })
     }
 
@@ -716,7 +721,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product with Tre
       }
     } else {
       // If the rule was applied then use the returned alternatives
-      afterRules.map { afterRule =>
+      afterRules.toStream.map { afterRule =>
         if (this fastEquals afterRule) {
           this
         } else {
@@ -728,7 +733,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product with Tre
 
     afterRulesStream.flatMap { afterRule =>
       if (afterRule.containsChild.nonEmpty) {
-        generateChildrenSeq(
+        MultiTransform.generateCartesianProduct(
             afterRule.children.map(_.multiTransformDownWithPruning(cond, 
ruleId)(rule)))
           .map(afterRule.withNewChildren)
       } else {
@@ -737,15 +742,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product with Tre
     }
   }
 
-  private def generateChildrenSeq[T](childrenStreams: Seq[Stream[T]]): 
Stream[Seq[T]] = {
-    childrenStreams.foldRight(Stream(Seq.empty[T]))((childrenStream, 
childrenSeqStream) =>
-      for {
-        childrenSeq <- childrenSeqStream
-        child <- childrenStream
-      } yield child +: childrenSeq
-    )
-  }
-
   /**
    * Returns a copy of this node where `f` has been applied to all the nodes 
in `children`.
    */
@@ -1368,3 +1364,21 @@ trait QuaternaryLike[T <: TreeNode[T]] { self: 
TreeNode[T] =>
 
   protected def withNewChildrenInternal(newFirst: T, newSecond: T, newThird: 
T, newFourth: T): T
 }
+
+object MultiTransform {
+
+  /**
+   * Returns the stream of `Seq` elements by generating the cartesian product 
of sequences.
+   *
+   * @param elementSeqs a list of sequences to build the cartesian product from
+   * @return            the stream of generated `Seq` elements
+   */
+  def generateCartesianProduct[T](elementSeqs: Seq[Seq[T]]): Stream[Seq[T]] = {
+    elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, elementTails) =>
+      for {
+        elementTail <- elementTails
+        element <- elements
+      } yield element +: elementTail
+    )
+  }
+}
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index ac28917675e..e4adf59b392 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -987,10 +987,10 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   test("multiTransformDown generates all alternatives") {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), 
Literal("d")))
     val transformed = e.multiTransformDown {
-      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
-      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
       case Add(StringLiteral("c"), StringLiteral("d"), _) =>
-        Stream(Literal(100), Literal(200), Literal(300))
+        Seq(Literal(100), Literal(200), Literal(300))
     }
     val expected = for {
       cd <- Seq(Literal(100), Literal(200), Literal(300))
@@ -1003,7 +1003,7 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   test("multiTransformDown is lazy") {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), 
Literal("d")))
     val transformed = e.multiTransformDown {
-      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
       case StringLiteral("b") => newErrorAfterStream(Literal(10))
       case Add(StringLiteral("c"), StringLiteral("d"), _) => 
newErrorAfterStream(Literal(100))
     }
@@ -1017,8 +1017,8 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
     }
 
     val transformed2 = e.multiTransformDown {
-      case StringLiteral("a") => Stream(Literal(1), Literal(2), Literal(3))
-      case StringLiteral("b") => Stream(Literal(10), Literal(20), Literal(30))
+      case StringLiteral("a") => Seq(Literal(1), Literal(2), Literal(3))
+      case StringLiteral("b") => Seq(Literal(10), Literal(20), Literal(30))
       case Add(StringLiteral("c"), StringLiteral("d"), _) => 
newErrorAfterStream(Literal(100))
     }
     val expected2 = for {
@@ -1035,10 +1035,9 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper 
{
   test("multiTransformDown rule return this") {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), 
Literal("d")))
     val transformed = e.multiTransformDown {
-      case s @ StringLiteral("a") => Stream(Literal(1), Literal(2), s)
-      case s @ StringLiteral("b") => Stream(Literal(10), Literal(20), s)
-      case a @ Add(StringLiteral("c"), StringLiteral("d"), _) =>
-        Stream(Literal(100), Literal(200), a)
+      case s @ StringLiteral("a") => Seq(Literal(1), Literal(2), s)
+      case s @ StringLiteral("b") => Seq(Literal(10), Literal(20), s)
+      case a @ Add(StringLiteral("c"), StringLiteral("d"), _) => 
Seq(Literal(100), Literal(200), a)
     }
     val expected = for {
       cd <- Seq(Literal(100), Literal(200), Add(Literal("c"), Literal("d")))
@@ -1053,10 +1052,10 @@ class TreeNodeSuite extends SparkFunSuite with 
SQLHelper {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), 
Literal("d")))
     val transformed = e.multiTransformDown {
       case a @ Add(StringLiteral("a"), StringLiteral("b"), _) =>
-        Stream(Literal(11), Literal(12), Literal(21), Literal(22), a)
-      case StringLiteral("a") => Stream(Literal(1), Literal(2))
-      case StringLiteral("b") => Stream(Literal(10), Literal(20))
-      case Add(StringLiteral("c"), StringLiteral("d"), _) => 
Stream(Literal(100), Literal(200))
+        Seq(Literal(11), Literal(12), Literal(21), Literal(22), a)
+      case StringLiteral("a") => Seq(Literal(1), Literal(2))
+      case StringLiteral("b") => Seq(Literal(10), Literal(20))
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq(Literal(100), 
Literal(200))
     }
     val expected = for {
       cd <- Seq(Literal(100), Literal(200))
@@ -1072,12 +1071,12 @@ class TreeNodeSuite extends SparkFunSuite with 
SQLHelper {
   test("multiTransformDown can prune") {
     val e = Add(Add(Literal("a"), Literal("b")), Add(Literal("c"), 
Literal("d")))
     val transformed = e.multiTransformDown {
-      case StringLiteral("a") => Stream.empty
+      case StringLiteral("a") => Seq.empty
     }
     assert(transformed.isEmpty)
 
     val transformed2 = e.multiTransformDown {
-      case Add(StringLiteral("c"), StringLiteral("d"), _) => Stream.empty
+      case Add(StringLiteral("c"), StringLiteral("d"), _) => Seq.empty
     }
     assert(transformed2.isEmpty)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to