dtenedor commented on code in PR #42420:
URL: https://github.com/apache/spark/pull/42420#discussion_r1293778050


##########
python/pyspark/worker.py:
##########
@@ -621,22 +663,30 @@ def verify_result(result):
 
             return lambda *a: map(lambda res: (res, arrow_return_type), map(verify_result, f(*a)))
 
-        eval = wrap_arrow_udtf(getattr(udtf, "eval"), return_type)
-
-        if hasattr(udtf, "terminate"):
-            terminate = wrap_arrow_udtf(getattr(udtf, "terminate"), return_type)
-        else:
-            terminate = None
+        udtf_state.eval = wrap_arrow_udtf(getattr(udtf_state.udtf, "eval"), return_type)
+        udtf_state.set_terminate(wrap_arrow_udtf, return_type)
 
         def mapper(_, it):
             try:
                 for a in it:
                    # The eval function yields an iterator. Each element produced by this
                    # iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type).
-                    yield from eval(*[a[o] for o in arg_offsets])
+                    arguments = [a[o] for o in arg_offsets]
+                    changed_partitions = check_partition_boundaries(arguments)
+                    if changed_partitions:
+                        # Call 'terminate' on the UDTF class instance, if applicable.
+                        # Then destroy the UDTF class instance and create a new one.
+                        if udtf_state.terminate is not None:
+                            yield from udtf_state.terminate()
+                        create_udtf_class_instance(return_type)

Review Comment:
   Good catch, fixed!
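
   For context, here is a minimal standalone sketch of the terminate-then-recreate control flow that this hunk implements inside `mapper`. It is illustrative only: the `Udtf` class, the input rows, and the single `key` partition column are made up, and it assumes one PARTITION BY column rather than the projected-expression handling in the real worker code.
   ```python
   class Udtf:
       def __init__(self):
           self.count = 0

       def eval(self, row):
           self.count += 1
           yield (row["key"], self.count)

       def terminate(self):
           yield ("total", self.count)


   def run(rows):
       # Stand-in for 'mapper': compare each row's PARTITION BY value against
       # the previous row's value to detect partition boundaries.
       udtf = Udtf()
       prev_key = None
       for row in rows:
           if prev_key is not None and row["key"] != prev_key:
               # Flush the finished partition, then start a fresh instance.
               yield from udtf.terminate()
               udtf = Udtf()
           prev_key = row["key"]
           yield from udtf.eval(row)
       yield from udtf.terminate()


   print(list(run([{"key": "a"}, {"key": "a"}, {"key": "b"}])))
   # [('a', 1), ('a', 2), ('total', 2), ('b', 1), ('total', 1)]
   ```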



##########
python/pyspark/worker.py:
##########
@@ -564,23 +566,66 @@ def read_udtf(pickleSer, infile, eval_type):
             f"The return type of a UDTF must be a struct type, but got 
{type(return_type)}."
         )
 
+    # This class holds the UDTF class instance and associated state as we evaluate rows.
+    # We keep this state in a class to simplify updating it as needed within nested
+    # function calls, such as when we destroy the UDTF class instance when partition
+    # boundaries change and then create a new one.
+    class UdtfState:
+        def __init__(self):
+            self.udtf = None
+            self.prev_arguments = None
+            self.eval = None
+            self.terminate = None
+
+        def set_terminate(self, wrap_udtf, return_type):
+            if hasattr(self.udtf, "terminate"):
+                self.terminate = wrap_udtf(getattr(self.udtf, "terminate"), return_type)
+            else:
+                self.terminate = None
+
+    udtf_state = UdtfState()
+
     # Instantiate the UDTF class.
-    try:
-        udtf = handler()
-    except Exception as e:
-        raise PySparkRuntimeError(
-            error_class="UDTF_EXEC_ERROR",
-            message_parameters={"method_name": "__init__", "error": str(e)},
-        )
+    def create_udtf_class_instance(return_type):
+        try:
+            if udtf_state.udtf is not None:
+                del udtf_state.udtf
+            udtf_state.udtf = handler()
+            udtf_state.prev_arguments = None
+        except Exception as e:
+            raise PySparkRuntimeError(
+                error_class="UDTF_EXEC_ERROR",
+                message_parameters={"method_name": "__init__", "error": str(e)},
+            )
+
+    create_udtf_class_instance(return_type)
 
     # Validate the UDTF
-    if not hasattr(udtf, "eval"):
+    if not hasattr(udtf_state.udtf, "eval"):
         raise PySparkRuntimeError(
             "Failed to execute the user defined table function because it has 
not "
             "implemented the 'eval' method. Please add the 'eval' method and 
try "
             "the query again."
         )
 
+    # Inspects the values of the projected PARTITION BY expressions, if any. Returns True when
+    # these values change, in which case the caller should invoke the 'terminate' method on the
+    # UDTF class instance and then destroy it and create a new one to implement the desired
+    # partitioning semantics.
+    def check_partition_boundaries(arguments):
+        if num_partition_child_indexes == 0 or udtf_state.prev_arguments is None:
+            udtf_state.prev_arguments = arguments
+            return False
+        cur_table_arg = [arg for arg in arguments if type(arg) is Row][0]
+        prev_table_arg = [arg for arg in udtf_state.prev_arguments if type(arg) is Row][0]
+        cur_partition_values = [cur_table_arg[i] for i in partition_child_indexes]
+        prev_partition_values = [prev_table_arg[i] for i in partition_child_indexes]
+        udtf_state.prev_arguments = arguments

Review Comment:
   Reviewing the code again, you are right. We should set `udtf_state.prev_arguments = None` once at the start, and not do that in `create_udtf_class_instance`. I made this change.
   
   I tried to think of a unit test that would have caught this bug; `test_udtf_with_table_argument_and_partition_by_and_order_by` should have already caught it, but I am not sure why it didn't before. Leaving that test as-is for now.
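
   To make the ordering issue concrete, here is a toy reproduction (names are illustrative; `prev_key` stands in for `udtf_state.prev_arguments`). With single-row partitions, clearing the previous value inside the recreate step hides the very next boundary from the comparison:
   ```python
   def boundaries(keys, reset_on_recreate):
       # Returns, for each row, whether a partition boundary was detected.
       prev_key = None
       changes = []
       for key in keys:
           changed = prev_key is not None and key != prev_key
           prev_key = key
           if changed and reset_on_recreate:
               # Mirrors resetting prev_arguments to None whenever the UDTF
               # instance is recreated: the next row sees no previous value.
               prev_key = None
           changes.append(changed)
       return changes

   print(boundaries(["a", "b", "c"], reset_on_recreate=True))   # [False, True, False]
   print(boundaries(["a", "b", "c"], reset_on_recreate=False))  # [False, True, True]
   ```
   The boundary between the "b" and "c" partitions goes undetected in the first case, which is why `prev_arguments` should only be initialized to `None` once at the start.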



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

