This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 87298db43d9 [SPARK-44503][SQL] Project any PARTITION BY expressions 
not already returned from Python UDTF TABLE arguments
87298db43d9 is described below

commit 87298db43d9a33fa3a3986f274442a17aad74dc3
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Wed Aug 9 10:27:07 2023 -0700

    [SPARK-44503][SQL] Project any PARTITION BY expressions not already 
returned from Python UDTF TABLE arguments
    
    ### What changes were proposed in this pull request?
    
    This PR adds a projection when any Python UDTF TABLE argument contains 
PARTITION BY expressions that are not simple attributes that are already 
present in the output of the relation.
    
    For example:
    
    ```
    CREATE TABLE t(d DATE, y INT) USING PARQUET;
    INSERT INTO t VALUES ...
    SELECT * FROM UDTF(TABLE(t) PARTITION BY EXTRACT(YEAR FROM d) ORDER BY y 
ASC);
    ```
    
    This will generate a plan like:
    
    ```
    +- Sort (partition_by_0 ASC, y ASC)
      +- RepartitionByExpression (partition_by_0)
        +- Project (t.d, t.y, EXTRACT(YEAR FROM t.d) AS partition_by_0)
          +- LogicalRelation "t"
    ```
    
    ### Why are the changes needed?
    
    We project the PARTITION BY expressions so that their resulting values 
appear in attributes that the Python UDTF interpreter can simply inspect in 
order to know when the partition boundaries have changed.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    This PR adds unit test coverage.
    
    Closes #42351 from dtenedor/partition-by-execution.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 .../FunctionTableSubqueryArgumentExpression.scala  |  77 +++++++++++--
 .../sql/execution/python/PythonUDTFSuite.scala     | 127 +++++++++++++++++++--
 2 files changed, 184 insertions(+), 20 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
index e7a4888125d..daa0751eedf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
@@ -104,23 +104,80 @@ case class FunctionTableSubqueryArgumentExpression(
     // the query plan.
     var subquery = plan
     if (partitionByExpressions.nonEmpty) {
-      subquery = RepartitionByExpression(
-        partitionExpressions = partitionByExpressions,
-        child = subquery,
-        optNumPartitions = None)
+      // Add a projection to project each of the partitioning expressions that 
is not a simple
+      // attribute that is already present in the plan output. Then add a sort 
operation by the
+      // partition keys (plus any explicit ORDER BY items) since after the 
hash-based shuffle
+      // operation, the rows from several partitions may arrive interleaved. 
In this way, the Python
+      // UDTF evaluator is able to inspect the values of the partitioning 
expressions for adjacent
+      // rows in order to determine when each partition ends and the next one 
begins.
+      subquery = Project(
+        projectList = subquery.output ++ extraProjectedPartitioningExpressions,
+        child = subquery)
+      val partitioningAttributes = partitioningExpressionIndexes.map(i => 
subquery.output(i))
+      subquery = Sort(
+        order = partitioningAttributes.map(e => SortOrder(e, Ascending)) ++ 
orderByExpressions,
+        global = false,
+        child = RepartitionByExpression(
+          partitionExpressions = partitioningAttributes,
+          optNumPartitions = None,
+          child = subquery))
     }
     if (withSinglePartition) {
       subquery = Repartition(
         numPartitions = 1,
         shuffle = true,
         child = subquery)
-    }
-    if (orderByExpressions.nonEmpty) {
-      subquery = Sort(
-        order = orderByExpressions,
-        global = false,
-        child = subquery)
+      if (orderByExpressions.nonEmpty) {
+        subquery = Sort(
+          order = orderByExpressions,
+          global = false,
+          child = subquery)
+      }
     }
     Project(Seq(Alias(CreateStruct(subquery.output), "c")()), subquery)
   }
+
+  /**
+   * These are the indexes of the PARTITION BY expressions within the 
concatenation of the child's
+   * output attributes and the [[extraProjectedPartitioningExpressions]]. We 
send these indexes to
+   * the Python UDTF evaluator so it knows which expressions to compare on 
adjacent rows to know
+   * when the partition has changed.
+   */
+  lazy val partitioningExpressionIndexes: Seq[Int] = {
+    val extraPartitionByExpressionsToIndexes: Map[Expression, Int] =
+      extraProjectedPartitioningExpressions.map(_.child).zipWithIndex.toMap
+    partitionByExpressions.map { e =>
+      subqueryOutputs.get(e).getOrElse {
+        extraPartitionByExpressionsToIndexes.get(e).get + plan.output.length
+      }
+    }
+  }
+
+  private lazy val extraProjectedPartitioningExpressions: Seq[Alias] = {
+    partitionByExpressions.filter { e =>
+      !subqueryOutputs.contains(e)
+    }.zipWithIndex.map { case (expr, index) =>
+      Alias(expr, s"partition_by_$index")()
+    }
+  }
+
+  private lazy val subqueryOutputs: Map[Expression, Int] = 
plan.output.zipWithIndex.toMap
 }
+
+object FunctionTableSubqueryArgumentExpression {
+  /**
+   * Returns a sequence of zero-based integer indexes identifying the values 
of a Python UDTF's
+   * 'eval' method's *args list that correspond to partitioning columns of the 
input TABLE argument.
+   */
+  def partitionChildIndexes(udtfArguments: Seq[Expression]): Seq[Int] = {
+    udtfArguments.zipWithIndex.flatMap { case (expr, index) =>
+      expr match {
+        case f: FunctionTableSubqueryArgumentExpression =>
+          f.partitioningExpressionIndexes.map(_ + index)
+        case _ =>
+          Seq()
+      }
+    }
+  }
+}
+
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
index 8f1bf172bbd..43f61a7c61e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, 
QueryTest, Row}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, 
LogicalPlan, Repartition, RepartitionByExpression, Sort, SubqueryAlias}
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, 
FunctionTableSubqueryArgumentExpression, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, 
LogicalPlan, OneRowRelation, Project, Repartition, RepartitionByExpression, 
Sort, SubqueryAlias}
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
@@ -112,7 +113,9 @@ class PythonUDTFSuite extends QueryTest with 
SharedSparkSession {
   test("SPARK-44503: Specify PARTITION BY and ORDER BY for TABLE arguments") {
     // Positive tests
     assume(shouldTestPythonUDFs)
-    def failure(plan: LogicalPlan): Unit = fail(s"Unexpected plan: $plan")
+    def failure(plan: LogicalPlan): Unit = {
+      fail(s"Unexpected plan: $plan")
+    }
     sql(
       """
         |SELECT * FROM testUDTF(
@@ -120,8 +123,12 @@ class PythonUDTFSuite extends QueryTest with 
SharedSparkSession {
         |  PARTITION BY X)
         |""".stripMargin).queryExecution.analyzed
       .collectFirst { case r: RepartitionByExpression => r }.get match {
-      case RepartitionByExpression(_, SubqueryAlias(_, _: LocalRelation), _, 
_) =>
-      case other => failure(other)
+      case RepartitionByExpression(
+        _, Project(
+          _, SubqueryAlias(
+            _, _: LocalRelation)), _, _) =>
+      case other =>
+        failure(other)
     }
     sql(
       """
@@ -130,8 +137,11 @@ class PythonUDTFSuite extends QueryTest with 
SharedSparkSession {
         |  WITH SINGLE PARTITION)
         |""".stripMargin).queryExecution.analyzed
       .collectFirst { case r: Repartition => r }.get match {
-      case Repartition(1, true, SubqueryAlias(_, _: LocalRelation)) =>
-      case other => failure(other)
+      case Repartition(
+        1, true, SubqueryAlias(
+          _, _: LocalRelation)) =>
+      case other =>
+        failure(other)
     }
     sql(
       """
@@ -140,8 +150,13 @@ class PythonUDTFSuite extends QueryTest with 
SharedSparkSession {
         |  PARTITION BY SUBSTR(X, 2) ORDER BY (X, Y))
         |""".stripMargin).queryExecution.analyzed
       .collectFirst { case r: Sort => r }.get match {
-      case Sort(_, false, RepartitionByExpression(_, SubqueryAlias(_, _: 
LocalRelation), _, _)) =>
-      case other => failure(other)
+      case Sort(
+        _, false, RepartitionByExpression(
+          _, Project(
+            _, SubqueryAlias(
+              _, _: LocalRelation)), _, _)) =>
+      case other =>
+        failure(other)
     }
     sql(
       """
@@ -150,8 +165,12 @@ class PythonUDTFSuite extends QueryTest with 
SharedSparkSession {
         |  WITH SINGLE PARTITION ORDER BY (X, Y))
         |""".stripMargin).queryExecution.analyzed
       .collectFirst { case r: Sort => r }.get match {
-      case Sort(_, false, Repartition(1, true, SubqueryAlias(_, _: 
LocalRelation))) =>
-      case other => failure(other)
+      case Sort(
+        _, false, Repartition(
+          1, true, SubqueryAlias(
+            _, _: LocalRelation))) =>
+      case other =>
+        failure(other)
     }
     // Negative tests
     withTable("t") {
@@ -172,4 +191,92 @@ class PythonUDTFSuite extends QueryTest with 
SharedSparkSession {
           stop = 30))
     }
   }
+
+  test("SPARK-44503: Compute partition child indexes for various UDTF argument 
lists") {
+    // Each test below calls 
FunctionTableSubqueryArgumentExpression.partitionChildIndexes with a
+    // list of expressions and then checks the PARTITION BY child expression 
indexes that come out.
+    val projectList = Seq(
+      Alias(Literal(42), "a")(),
+      Alias(Literal(43), "b")())
+    val projectTwoValues = Project(
+      projectList = projectList,
+      child = OneRowRelation())
+    // There are no UDTF TABLE arguments, so there are no PARTITION BY child 
expression indexes.
+    val partitionChildIndexes = 
FunctionTableSubqueryArgumentExpression.partitionChildIndexes(_)
+    assert(partitionChildIndexes(Seq(
+      Literal(41))) ==
+      Seq.empty[Int])
+    assert(partitionChildIndexes(Seq(
+      Literal(41),
+      Literal("abc"))) ==
+      Seq.empty[Int])
+    // The UDTF TABLE argument has no PARTITION BY expressions, so there are 
no PARTITION BY child
+    // expression indexes.
+    assert(partitionChildIndexes(Seq(
+      FunctionTableSubqueryArgumentExpression(
+        plan = projectTwoValues))) ==
+      Seq.empty[Int])
+    // The UDTF TABLE argument has two PARTITION BY expressions which are 
equal to the output
+    // attributes from the provided relation, in order. Therefore the 
PARTITION BY child expression
+    // indexes are 0 and 1.
+    assert(partitionChildIndexes(Seq(
+      FunctionTableSubqueryArgumentExpression(
+        plan = projectTwoValues,
+        partitionByExpressions = projectTwoValues.output))) ==
+      Seq(0, 1))
+    // The UDTF TABLE argument has one PARTITION BY expression which is equal 
to the first output
+    // attribute from the provided relation. Therefore the PARTITION BY child 
expression index is 0.
+    assert(partitionChildIndexes(Seq(
+      FunctionTableSubqueryArgumentExpression(
+        plan = projectTwoValues,
+        partitionByExpressions = Seq(projectList.head.toAttribute)))) ==
+      Seq(0))
+    // The UDTF TABLE argument has one PARTITION BY expression which is equal 
to the second output
+    // attribute from the provided relation. Therefore the PARTITION BY child 
expression index is 1.
+    assert(partitionChildIndexes(Seq(
+      FunctionTableSubqueryArgumentExpression(
+        plan = projectTwoValues,
+        partitionByExpressions = Seq(projectList.last.toAttribute)))) ==
+      Seq(1))
+    // The UDTF has one scalar argument, then one TABLE argument, then another 
scalar argument. The
+    // TABLE argument has two PARTITION BY expressions which are equal to the 
output attributes from
+    // the provided relation, in order. Therefore the PARTITION BY child 
expression indexes are 1
+    // and 2, because they begin at an offset of 1 from the zero-based start 
of the list of values
+    // provided to the UDTF 'eval' method.
+    assert(partitionChildIndexes(Seq(
+      Literal(41),
+      FunctionTableSubqueryArgumentExpression(
+        plan = projectTwoValues,
+        partitionByExpressions = projectTwoValues.output),
+      Literal("abc"))) ==
+      Seq(1, 2))
+    // Same as above, but the PARTITION BY expressions are new expressions 
which must be projected
+    // after all the attributes from the relation provided to the UDTF TABLE 
argument. Therefore the
+    // PARTITION BY child indexes are 3 and 4 because they begin at an offset 
of 3 from the
+    // zero-based start of the list of values provided to the UDTF 'eval' 
method.
+    assert(partitionChildIndexes(Seq(
+      Literal(41),
+      FunctionTableSubqueryArgumentExpression(
+        plan = projectTwoValues,
+        partitionByExpressions = Seq(Literal(42), Literal(43))),
+      Literal("abc"))) ==
+      Seq(3, 4))
+    // Same as above, but the PARTITION BY list comprises just one addition 
expression.
+    assert(partitionChildIndexes(Seq(
+      Literal(41),
+      FunctionTableSubqueryArgumentExpression(
+        plan = projectTwoValues,
+        partitionByExpressions = Seq(Add(projectList.head.toAttribute, 
Literal(1)))),
+      Literal("abc"))) ==
+      Seq(3))
+    // Same as above, but the PARTITION BY list comprises one literal value 
and one addition
+    // expression.
+    assert(partitionChildIndexes(Seq(
+      Literal(41),
+      FunctionTableSubqueryArgumentExpression(
+        plan = projectTwoValues,
+        partitionByExpressions = Seq(Literal(42), 
Add(projectList.head.toAttribute, Literal(1)))),
+      Literal("abc"))) ==
+      Seq(3, 4))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to