allisonwang-db commented on code in PR #42174:
URL: https://github.com/apache/spark/pull/42174#discussion_r1276818608


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala:
##########
@@ -32,11 +32,22 @@ import org.apache.spark.sql.types.DataType
  * table in the catalog. In the latter case, the relation argument comprises
  * a table subquery that may itself refer to one or more tables in its own
  * FROM clause.
+ *
+ * Each TABLE argument may also optionally include a PARTITION BY clause. If present, these indicate
+ * how to logically split up the input relation such that the table-valued function evaluates
+ * exactly once for each partition, and returns the union of all results. If no partitioning list is
+ * present, this splitting of the input relation is undefined. Furthermore, if the PARTITION BY
+ * clause includes a following ORDER BY clause, Catalyst will sort the rows in each partition such
+ * that the table-valued function receives them one-by-one in the requested order. Otherwise, if no
+ * such ordering is specified, the ordering of rows within each partition is undefined.

Review Comment:
   Maybe we can use Javadoc style here, with `@param` tags, to explain each parameter?
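The semantics described in the doc comment above (one evaluation per partition, optional per-partition ORDER BY, union of the results) can be modeled in a few lines of plain Python. This is a minimal sketch, not Spark's implementation: `call_with_partition_by`, `SumUDTF`, and `FirstUDTF` are hypothetical names, rows are plain dicts, and `SumUDTF` is modeled on the `TestUDTF` class from `test_udtf_call_with_partition_by` in this PR.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional

Row = Dict[str, Any]


class SumUDTF:
    """Toy UDTF modeled on TestUDTF: sums column "x" and emits one row from terminate()."""

    def __init__(self) -> None:
        self._sum = 0

    def eval(self, row: Row) -> None:
        self._sum += row["x"]

    def terminate(self):
        yield (self._sum,)


class FirstUDTF:
    """Order-sensitive toy UDTF: emits the first "x" value it receives."""

    def __init__(self) -> None:
        self._first = None

    def eval(self, row: Row) -> None:
        if self._first is None:
            self._first = row["x"]

    def terminate(self):
        yield (self._first,)


def call_with_partition_by(
    udtf_cls,
    rows: Iterable[Row],
    partition_by: List[str],
    order_by: Optional[List[str]] = None,
) -> List[tuple]:
    """Hypothetical model of TABLE(...) PARTITION BY ... [ORDER BY ...] over in-memory rows."""
    # Split the input relation into logical partitions keyed by the PARTITION BY columns.
    # An empty list collapses to a single partition here, which is only one of the
    # behaviors the doc comment leaves undefined.
    partitions: Dict[tuple, List[Row]] = defaultdict(list)
    for row in rows:
        partitions[tuple(row[c] for c in partition_by)].append(row)

    results: List[tuple] = []
    # Visit partitions in key order only to keep this sketch deterministic.
    for key in sorted(partitions):
        part = partitions[key]
        if order_by:
            # ORDER BY: the function sees this partition's rows one by one in the requested order.
            part = sorted(part, key=lambda r: tuple(r[c] for c in order_by))
        # Exactly one evaluation per partition: a fresh instance, eval() per row, then terminate().
        udtf = udtf_cls()
        for row in part:
            results.extend(udtf.eval(row) or ())
        results.extend(udtf.terminate() or ())
    # The overall result is the union of the per-partition results.
    return results


rows = [{"k": 1, "x": 3}, {"k": 2, "x": 5}, {"k": 1, "x": 1}, {"k": 2, "x": 4}]
print(call_with_partition_by(SumUDTF, rows, ["k"]))                    # [(4,), (9,)]
print(call_with_partition_by(FirstUDTF, rows, ["k"], order_by=["x"]))  # [(1,), (4,)]
```

`FirstUDTF` is there to show why the ORDER BY guarantee matters: without it, the value it returns depends on whatever order the rows happen to arrive in.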



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -1444,6 +1447,83 @@ def terminate(self):
                     assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
                     assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
 
+    def test_udtf_call_with_partition_by(self):
+        class TestUDTF:
+            def __init__(self):
+                self._sum = 0
+
+            def eval(self, row: Row):
+                self._sum += row["x"]
+
+            def terminate(self):
+                yield self._sum,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf_pb", func)
+
+        def actual(query: str) -> str:
+            df = self.spark.sql(query)
+            value = df.collect()[0][0]
+            stripExprIds = re.sub(r'#[\d]+', r'#xx', value)
+            stripPlanIds = re.sub(
+                r'plan_id=[\d]+', r'plan_id=xx', stripExprIds)
+            stripEvalType = re.sub(
+                r'\+- .....EvalPythonUDTF test_udtf_pb.*', r'+- EvalPythonUDTF test_udtf_pb',
+                stripPlanIds)
+            print('Query plan: ' + stripEvalType)
+            return stripEvalType.strip('\n')
+        def expected(input: str) -> str:
+            return textwrap.dedent(input).strip('\n')
+
+        self.assertEqual(

Review Comment:
   This is a great test! But I think it would be better to add this test on the Scala side: `PythonUDTFSuite.scala`.
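If the plan comparison stays in Python, the masking that `actual()` does inline could be factored into a small helper. A minimal sketch, assuming the physical node is named `BatchEvalPythonUDTF` or `ArrowEvalPythonUDTF`, which is what the five-character wildcard in the original regex appears to target; `normalize_plan` is a hypothetical name, and the plan fragment in the usage line is made up rather than copied from real EXPLAIN output.

```python
import re
import textwrap


def normalize_plan(plan: str, udtf_name: str) -> str:
    """Mask run-specific parts of an EXPLAIN string so it can be compared to a fixed plan."""
    # Expression IDs such as "x#123" change from run to run.
    plan = re.sub(r"#\d+", "#xx", plan)
    # So do plan IDs such as "plan_id=4".
    plan = re.sub(r"plan_id=\d+", "plan_id=xx", plan)
    # Assumed node names (Batch/Arrow); collapse them to one name and drop the rest of
    # that line. '.' does not match newlines, so only that line is affected.
    plan = re.sub(
        r"\+- (?:Batch|Arrow)EvalPythonUDTF " + re.escape(udtf_name) + r".*",
        lambda _m: "+- EvalPythonUDTF " + udtf_name,
        plan,
    )
    return plan.strip("\n")


def expected(text: str) -> str:
    """Dedent a triple-quoted expected plan and drop surrounding blank lines."""
    return textwrap.dedent(text).strip("\n")


# Made-up plan fragment, for illustration only.
raw = "\n+- ArrowEvalPythonUDTF test_udtf_pb(x#12), [a#15]\n   +- Project [x#12], plan_id=3\n"
assert normalize_plan(raw, "test_udtf_pb") == expected("""
    +- EvalPythonUDTF test_udtf_pb
       +- Project [x#xx], plan_id=xx
""")
```

Passing the replacement as a lambda keeps `re.sub` from interpreting backslashes in `udtf_name`, and returning the string instead of printing it keeps test output quiet.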



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala:
##########
@@ -61,5 +75,36 @@ case class FunctionTableSubqueryArgumentExpression(
   final override def nodePatternsInternal: Seq[TreePattern] =
     Seq(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)
 
-  lazy val evaluable: LogicalPlan = Project(Seq(Alias(CreateStruct(plan.output), "c")()), plan)
+  def hasRepartitioning: Boolean = withSinglePartition || partitionByExpressions.nonEmpty
+
+  lazy val evaluable: LogicalPlan = {
+    val subquery = if (hasRepartitioning) {
+      // If the TABLE argument includes the WITH SINGLE PARTITION or PARTITION BY or ORDER BY
+      // clause(s), add a corresponding logical operator to represent the repartitioning operation
+      // in the query plan.
+      RepartitionForTableFunctionCall(

Review Comment:
   Just curious, have we considered reusing `Sort` + `RepartitionByExpression`/`Repartition` nodes, instead of having a dedicated logical plan?
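One detail that matters if generic `Sort` + `RepartitionByExpression` nodes were reused: hash repartitioning guarantees that all rows with the same PARTITION BY key land in the same physical partition, but not that a physical partition holds only one key. The consumer would then need the rows sorted by key first and would have to detect key boundaries itself. The pure-Python model below illustrates that under those assumptions; it is not how Spark actually executes Python UDTFs, and every function name in it is hypothetical.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


class SumUDTF:
    """Toy UDTF: sums column "x" and emits one row from terminate()."""

    def __init__(self) -> None:
        self._sum = 0

    def eval(self, row: Row) -> None:
        self._sum += row["x"]

    def terminate(self):
        yield (self._sum,)


def hash_repartition(rows: List[Row], keys: List[str], num_partitions: int) -> List[List[Row]]:
    # Stand-in for RepartitionByExpression: route each row by a hash of its key columns.
    # With no keys, every row hashes to the same bucket, like WITH SINGLE PARTITION.
    buckets: List[List[Row]] = [[] for _ in range(num_partitions)]
    for row in rows:
        buckets[hash(tuple(row[k] for k in keys)) % num_partitions].append(row)
    return buckets


def sort_within_partition(rows: List[Row], keys: List[str], order_by: List[str]) -> List[Row]:
    # Stand-in for a non-global Sort: order by the PARTITION BY key first, then by the
    # ORDER BY columns, so rows sharing a key become contiguous within the bucket.
    return sorted(rows, key=lambda r: tuple(r[k] for k in keys) + tuple(r[c] for c in order_by))


def evaluate_partition(udtf_cls, rows: List[Row], keys: List[str]) -> List[tuple]:
    # A bucket may hold several distinct keys, so terminate the current UDTF instance
    # and start a fresh one whenever the key changes.
    out: List[tuple] = []
    current_key: object = object()  # sentinel that compares unequal to every real key
    udtf = None
    for row in rows:
        key = tuple(row[k] for k in keys)
        if key != current_key:
            if udtf is not None:
                out.extend(udtf.terminate() or ())
            udtf, current_key = udtf_cls(), key
        out.extend(udtf.eval(row) or ())
    if udtf is not None:  # empty buckets never create or terminate an instance
        out.extend(udtf.terminate() or ())
    return out


def run(udtf_cls, rows: List[Row], keys: List[str], order_by: List[str], num_partitions: int) -> List[tuple]:
    results: List[tuple] = []
    for bucket in hash_repartition(rows, keys, num_partitions):
        results.extend(evaluate_partition(udtf_cls, sort_within_partition(bucket, keys, order_by), keys))
    return results


rows = [{"k": 1, "x": 3}, {"k": 2, "x": 5}, {"k": 1, "x": 1}, {"k": 2, "x": 4}]
print(run(SumUDTF, rows, ["k"], [], num_partitions=1))  # [(4,), (9,)]
```

Without the key-change check in `evaluate_partition`, the single-bucket case would feed both keys into one instance and return `[(13,)]` instead of `[(4,), (9,)]`. A dedicated operator can make that per-key contract explicit in the plan; reusing existing nodes pushes it onto the consumer.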



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala:
##########
@@ -2073,6 +2073,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
               _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION), ruleId)  {
               case t: FunctionTableSubqueryArgumentExpression =>
                 val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+                resolvedFunc match {
+                  case Generate(_: PythonUDTF, _, _, _, _, _) =>
+                  case _ if t.hasRepartitioning =>
                    throw QueryCompilationErrors.tableValuedFunctionPartitionByClauseNotSupported(

Review Comment:
   Can we add a test for this case?
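The real test would exercise this branch through SQL against a table-valued function that is not a Python UDTF. As a checklist of the cases it could cover, here is a small Python model of the rule in the hunk above: Python UDTFs pass through, and any other function whose TABLE argument has `hasRepartitioning` set is rejected. `TableArg`, `check_table_arg`, and `PartitionByNotSupportedError` are hypothetical names, not Spark APIs.

```python
from dataclasses import dataclass, field
from typing import List


class PartitionByNotSupportedError(Exception):
    """Hypothetical stand-in for tableValuedFunctionPartitionByClauseNotSupported."""


@dataclass
class TableArg:
    # Hypothetical model of a TABLE argument's optional clauses.
    with_single_partition: bool = False
    partition_by: List[str] = field(default_factory=list)
    order_by: List[str] = field(default_factory=list)

    @property
    def has_repartitioning(self) -> bool:
        # Same condition as hasRepartitioning in the diff.
        return self.with_single_partition or bool(self.partition_by)


def check_table_arg(function_name: str, is_python_udtf: bool, arg: TableArg) -> None:
    # Python UDTFs accept the clauses; other table-valued functions that request
    # repartitioning are rejected.
    if not is_python_udtf and arg.has_repartitioning:
        raise PartitionByNotSupportedError(function_name)


# Cases a test for the new error path could cover:
check_table_arg("my_python_udtf", True, TableArg(partition_by=["k"]))  # allowed
check_table_arg("other_tvf", False, TableArg())                       # allowed: no repartitioning
for bad in (TableArg(partition_by=["k"]), TableArg(with_single_partition=True)):
    try:
        check_table_arg("other_tvf", False, bad)
        raise AssertionError("expected PartitionByNotSupportedError")
    except PartitionByNotSupportedError:
        pass  # rejected, as intended
```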



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
