ueshin commented on code in PR #42420:
URL: https://github.com/apache/spark/pull/42420#discussion_r1291828740


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala:
##########
@@ -2117,11 +2120,25 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                   tableArgs.size)
               }
              val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+              // Propagate the column indexes for TABLE arguments to the PythonUDTF instance.
+              val tvfWithTableColumnIndexes: LogicalPlan = tvf match {
+                case g @ Generate(p: PythonUDTF, _, _, _, _, _) =>
+                  functionTableSubqueryArgs.headOption.map { tableArg =>
+                    val indexes = PythonUDTFPartitionColumnIndexes(
+                      numPartitionChildIndexes = tableArg.partitioningExpressionIndexes.length,

Review Comment:
   Do we need this? It seems easy enough to get via `partitionChildIndexes.length`.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala:
##########
@@ -2117,11 +2120,25 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                   tableArgs.size)
               }
              val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+              // Propagate the column indexes for TABLE arguments to the PythonUDTF instance.
+              val tvfWithTableColumnIndexes: LogicalPlan = tvf match {
+                case g @ Generate(p: PythonUDTF, _, _, _, _, _) =>

Review Comment:
   Do we also need to take care of `UnresolvedPolymorphicPythonUDTF`?



##########
python/pyspark/worker.py:
##########
@@ -564,23 +566,66 @@ def read_udtf(pickleSer, infile, eval_type):
             f"The return type of a UDTF must be a struct type, but got 
{type(return_type)}."
         )
 
+    # This class holds the UDTF class instance and associated state as we evaluate rows.
+    # We keep this state in a class to simplify updating it as needed within nested
+    # function calls, such as when we destroy the UDTF class instance when partition
+    # boundaries change and then create a new one.
+    class UdtfState:
+        def __init__(self):
+            self.udtf = None
+            self.prev_arguments = None
+            self.eval = None
+            self.terminate = None
+
+        def set_terminate(self, wrap_udtf, return_type):
+            if hasattr(self.udtf, "terminate"):
+                self.terminate = wrap_udtf(getattr(self.udtf, "terminate"), return_type)
+            else:
+                self.terminate = None
+
+    udtf_state = UdtfState()
+
     # Instantiate the UDTF class.
-    try:
-        udtf = handler()
-    except Exception as e:
-        raise PySparkRuntimeError(
-            error_class="UDTF_EXEC_ERROR",
-            message_parameters={"method_name": "__init__", "error": str(e)},
-        )
+    def create_udtf_class_instance(return_type):
+        try:
+            if udtf_state.udtf is not None:
+                del udtf_state.udtf
+            udtf_state.udtf = handler()
+            udtf_state.prev_arguments = None
+        except Exception as e:
+            raise PySparkRuntimeError(
+                error_class="UDTF_EXEC_ERROR",
+                message_parameters={"method_name": "__init__", "error": str(e)},
+            )
+
+    create_udtf_class_instance(return_type)
 
     # Validate the UDTF
-    if not hasattr(udtf, "eval"):
+    if not hasattr(udtf_state.udtf, "eval"):
         raise PySparkRuntimeError(
             "Failed to execute the user defined table function because it has 
not "
             "implemented the 'eval' method. Please add the 'eval' method and 
try "
             "the query again."
         )
 
+    # Inspects the values of the projected PARTITION BY expressions, if any. Returns true when
+    # these values change, in which case the caller should invoke the 'terminate' method on the
+    # UDTF class instance and then destroy it and create a new one to implement the desired
+    # partitioning semantics.
+    def check_partition_boundaries(arguments):
+        if num_partition_child_indexes == 0 or udtf_state.prev_arguments is None:
+            udtf_state.prev_arguments = arguments
+            return False
+        cur_table_arg = [arg for arg in arguments if type(arg) is Row][0]
+        prev_table_arg = [arg for arg in udtf_state.prev_arguments if type(arg) is Row][0]
+        cur_partition_values = [cur_table_arg[i] for i in partition_child_indexes]
+        prev_partition_values = [prev_table_arg[i] for i in partition_child_indexes]
+        udtf_state.prev_arguments = arguments

Review Comment:
   Should this be set before `eval` is called?
   
   https://github.com/apache/spark/blob/58783339e51e7c6198586b32ecd4af3a971b7097/python/pyspark/worker.py#L689
   https://github.com/apache/spark/blob/58783339e51e7c6198586b32ecd4af3a971b7097/python/pyspark/worker.py#L767
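   
   For illustration only, a minimal sketch of the reordering being suggested (the `udtf_state`, `partition_child_indexes`, and `Row` names are taken from the diff above; the standalone signature and local variable names are hypothetical):
   
   ```python
   from pyspark.sql import Row
   
   def check_partition_boundaries(udtf_state, arguments, partition_child_indexes):
       # Record the incoming arguments up front, so that by the time `eval`
       # runs, prev_arguments already reflects the current row.
       prev_arguments = udtf_state.prev_arguments
       udtf_state.prev_arguments = arguments  # set before `eval` is invoked
       if not partition_child_indexes or prev_arguments is None:
           return False
       # Compare the PARTITION BY column values of the current and previous rows.
       cur_table_arg = [arg for arg in arguments if type(arg) is Row][0]
       prev_table_arg = [arg for arg in prev_arguments if type(arg) is Row][0]
       cur_values = [cur_table_arg[i] for i in partition_child_indexes]
       prev_values = [prev_table_arg[i] for i in partition_child_indexes]
       return cur_values != prev_values
   ```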



##########
python/pyspark/worker.py:
##########
@@ -564,23 +566,66 @@ def read_udtf(pickleSer, infile, eval_type):
             f"The return type of a UDTF must be a struct type, but got 
{type(return_type)}."
         )
 
+    # This class holds the UDTF class instance and associated state as we evaluate rows.
+    # We keep this state in a class to simplify updating it as needed within nested
+    # function calls, such as when we destroy the UDTF class instance when partition
+    # boundaries change and then create a new one.
+    class UdtfState:
+        def __init__(self):
+            self.udtf = None
+            self.prev_arguments = None
+            self.eval = None
+            self.terminate = None
+
+        def set_terminate(self, wrap_udtf, return_type):
+            if hasattr(self.udtf, "terminate"):
+                self.terminate = wrap_udtf(getattr(self.udtf, "terminate"), return_type)
+            else:
+                self.terminate = None
+
+    udtf_state = UdtfState()
+
     # Instantiate the UDTF class.
-    try:
-        udtf = handler()
-    except Exception as e:
-        raise PySparkRuntimeError(
-            error_class="UDTF_EXEC_ERROR",
-            message_parameters={"method_name": "__init__", "error": str(e)},
-        )
+    def create_udtf_class_instance(return_type):

Review Comment:
   Seems like `return_type` is not used in this function?
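   
   If it really is unused, a minimal sketch of the simplification (the body is the same as in the diff above with the parameter dropped; `handler` and `udtf_state` come from the surrounding `read_udtf` scope):
   
   ```python
   from pyspark.errors import PySparkRuntimeError
   
   # Sketch: the helper only touches `udtf_state` and `handler()`, so the
   # unused `return_type` parameter can be removed from the signature and
   # from the call site.
   def create_udtf_class_instance():
       try:
           if udtf_state.udtf is not None:
               del udtf_state.udtf
           udtf_state.udtf = handler()
           udtf_state.prev_arguments = None
       except Exception as e:
           raise PySparkRuntimeError(
               error_class="UDTF_EXEC_ERROR",
               message_parameters={"method_name": "__init__", "error": str(e)},
           )
   
   create_udtf_class_instance()
   ```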



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

