This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d6b1fd4a7cd [SPARK-45401][PYTHON] Add a new method `cleanup` in the 
UDTF interface
d6b1fd4a7cd is described below

commit d6b1fd4a7cdaf1c6bfd0fff9cb7cece5111cd025
Author: allisonwang-db <[email protected]>
AuthorDate: Sun Oct 8 13:53:10 2023 -0700

    [SPARK-45401][PYTHON] Add a new method `cleanup` in the UDTF interface
    
    ### What changes were proposed in this pull request?
    
    This PR adds a new API `cleanup` to the current Python UDTF interface.
    
    ### Why are the changes needed?
    
    Currently, the `terminate` method of a UDTF is always executed, regardless
of whether the `eval` method calls succeed. This is problematic because users
might have logic in `terminate` that should only run after all `eval` calls
have succeeded.
    
    But what if users wish to perform cleanup actions during UDTF execution,
such as closing connections? One option is for users to embed try...except
logic within the eval call:
    
    ```
    def eval(self, row: Any):
      try:
        run_code()
      except Exception:
        clean_up()
        raise
    ```
    
    However, running this try-except block for every eval call can be 
expensive, potentially affecting the performance of UDTFs.
    
    To address this, this PR introduces a new method, `cleanup`, in the UDTF
interface that is called regardless of whether `eval` and `terminate` succeed.
The execution logic now looks like:
    
    ```
    try:
      for row in rows:
        eval(row)
      terminate()
    finally:
      cleanup()
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    ### How was this patch tested?
    
    Unit tests (`test_udtf_cleanup_with_exception_in_eval` and `test_udtf_cleanup_with_exception_in_terminate`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43225 from allisonwang-db/spark-45401-udtf-cleanup.
    
    Authored-by: allisonwang-db <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/docs/source/user_guide/sql/python_udtf.rst | 22 ++++++++-
 python/pyspark/sql/tests/test_udtf.py             | 56 +++++++++++++++++++++++
 python/pyspark/worker.py                          | 16 ++++++-
 3 files changed, 90 insertions(+), 4 deletions(-)

diff --git a/python/docs/source/user_guide/sql/python_udtf.rst 
b/python/docs/source/user_guide/sql/python_udtf.rst
index 0e583915c58..74d8eb88986 100644
--- a/python/docs/source/user_guide/sql/python_udtf.rst
+++ b/python/docs/source/user_guide/sql/python_udtf.rst
@@ -108,16 +108,19 @@ To implement a Python UDTF, you first need to define a 
class implementing the me
 
         def terminate(self) -> Iterator[Any]:
             """
-            Called when the UDTF has processed all input rows.
+            Called when the UDTF has successfully processed all input rows.
 
             This method is optional to implement and is useful for performing 
any
-            cleanup or finalization operations after the UDTF has finished 
processing
+            finalization operations after the UDTF has finished processing
             all rows. It can also be used to yield additional rows if needed.
             Table functions that consume all rows in the entire input partition
             and then compute and return the entire output table can do so from
             this method as well (please be mindful of memory usage when doing
             this).
 
+            If any exceptions occur during input row processing, this method
+            won't be called.
+
             Yields
             ------
             tuple
@@ -131,6 +134,21 @@ To implement a Python UDTF, you first need to define a 
class implementing the me
             """
             ...
 
+        def cleanup(self) -> None:
+            """
+            Invoked after the UDTF completes processing input rows.
+
+            This method is optional to implement and is useful for final 
cleanup
+            regardless of whether the UDTF processed all input rows 
successfully
+            or was aborted due to exceptions.
+
+            Examples
+            --------
+            >>> def cleanup(self) -> None:
+            ...     self.conn.close()
+            """
+            ...
+
 
 The return type of the UDTF defines the schema of the table it outputs. 
 It must be either a ``StructType``, for example ``StructType().add("c1", 
StringType())``
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index a1d82056c50..9c821f4bde9 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -406,6 +406,62 @@ class BaseUDTFTestsMixin:
             ],
         )
 
+    def test_udtf_cleanup_with_exception_in_eval(self):
+        with tempfile.TemporaryDirectory() as d:
+            path = os.path.join(d, "file.txt")
+
+            @udtf(returnType="x: int")
+            class TestUDTF:
+                def __init__(self):
+                    self.path = path
+
+                def eval(self, x: int):
+                    raise Exception("eval error")
+
+                def terminate(self):
+                    with open(self.path, "a") as f:
+                        f.write("terminate")
+
+                def cleanup(self):
+                    with open(self.path, "a") as f:
+                        f.write("cleanup")
+
+            with self.assertRaisesRegex(PythonException, "eval error"):
+                TestUDTF(lit(1)).show()
+
+            with open(path, "r") as f:
+                data = f.read()
+
+            # Only cleanup method should be called.
+            self.assertEqual(data, "cleanup")
+
+    def test_udtf_cleanup_with_exception_in_terminate(self):
+        with tempfile.TemporaryDirectory() as d:
+            path = os.path.join(d, "file.txt")
+
+            @udtf(returnType="x: int")
+            class TestUDTF:
+                def __init__(self):
+                    self.path = path
+
+                def eval(self, x: int):
+                    yield (x,)
+
+                def terminate(self):
+                    raise Exception("terminate error")
+
+                def cleanup(self):
+                    with open(self.path, "a") as f:
+                        f.write("cleanup")
+
+            with self.assertRaisesRegex(PythonException, "terminate error"):
+                TestUDTF(lit(1)).show()
+
+            with open(path, "r") as f:
+                data = f.read()
+
+            self.assertEqual(data, "cleanup")
+
     def test_init_with_exception(self):
         @udtf(returnType="x: int")
         class TestUDTF:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 4cffb02a64a..a073942adb6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -757,6 +757,10 @@ def read_udtf(pickleSer, infile, eval_type):
                 return self._udtf.terminate()
             return iter(())
 
+        def cleanup(self) -> None:
+            if hasattr(self._udtf, "cleanup"):
+                self._udtf.cleanup()
+
         def _check_partition_boundaries(self, arguments: list) -> bool:
             result = False
             if len(self._prev_arguments) > 0:
@@ -894,15 +898,19 @@ def read_udtf(pickleSer, infile, eval_type):
         else:
             terminate = None
 
+        cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else 
None
+
         def mapper(_, it):
             try:
                 for a in it:
                     # The eval function yields an iterator. Each element 
produced by this
                     # iterator is a tuple in the form of (pandas.DataFrame, 
arrow_return_type).
                     yield from eval(*[a[o] for o in args_kwargs_offsets])
-            finally:
                 if terminate is not None:
                     yield from terminate()
+            finally:
+                if cleanup is not None:
+                    cleanup()
 
         return mapper, None, ser, ser
 
@@ -977,14 +985,18 @@ def read_udtf(pickleSer, infile, eval_type):
         else:
             terminate = None
 
+        cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else 
None
+
         # Return an iterator of iterators.
         def mapper(_, it):
             try:
                 for a in it:
                     yield eval(*[a[o] for o in args_kwargs_offsets])
-            finally:
                 if terminate is not None:
                     yield terminate()
+            finally:
+                if cleanup is not None:
+                    cleanup()
 
         return mapper, None, ser, ser
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to