This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1c3eb7e9e0c7 [SPARK-51578][PYTHON][TESTS] Add a helper function to fail time outed tests
1c3eb7e9e0c7 is described below

commit 1c3eb7e9e0c723ecd3ca66a676a9ea9cd95b0605
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Mar 24 11:08:36 2025 +0900

    [SPARK-51578][PYTHON][TESTS] Add a helper function to fail time outed tests
    
    ### What changes were proposed in this pull request?
    Add a helper function to fail time outed tests
    
    ### Why are the changes needed?
    to detect which test causes the timeout
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    added tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #50337 from zhengruifeng/py_test_timeout.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/tests/test_udf.py | 17 ++++++++++++++++-
 python/pyspark/testing/utils.py      | 21 +++++++++++++++++++++
 python/pyspark/tests/test_util.py    | 25 ++++++++++++++++++++++++-
 3 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index efb6ff159ee5..9eda458b8508 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -50,7 +50,7 @@ from pyspark.testing.sqlutils import (
     test_compiled,
     test_not_compiled_message,
 )
-from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.testing.utils import assertDataFrameEqual, timeout
 
 
 class BaseUDFTestsMixin(object):
@@ -1259,6 +1259,21 @@ class BaseUDFTestsMixin(object):
             messageParameters={"arg_name": "evalType", "arg_type": "str"},
         )
 
+    def test_timeout_util_with_udf(self):
+        @udf
+        def f(x):
+            time.sleep(10)
+            return str(x)
+
+        @timeout(1)
+        def timeout_func():
+            self.spark.range(1).select(f("id")).show()
+
+        # causing a py4j.protocol.Py4JNetworkError in pyspark classic
+        # causing a TimeoutError in pyspark connect
+        with self.assertRaises(Exception):
+            timeout_func()
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 780f0f90ac62..f1639d5fdb38 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -23,6 +23,7 @@ import difflib
 import functools
 from decimal import Decimal
 from time import time, sleep
+import signal
 from typing import (
     Any,
     Optional,
@@ -122,6 +123,26 @@ def write_int(i):
     return struct.pack("!i", i)
 
 
+def timeout(seconds):
+    def decorator(func):
+        def handler(signum, frame):
+            raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds")
+
+        def wrapper(*args, **kwargs):
+            signal.alarm(0)
+            signal.signal(signal.SIGALRM, handler)
+            signal.alarm(seconds)
+            try:
+                result = func(*args, **kwargs)
+            finally:
+                signal.alarm(0)
+            return result
+
+        return wrapper
+
+    return decorator
+
+
 def eventually(
     timeout=30.0,
     catch_assertions=False,
diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py
index e1079ca7b4e8..d9bda1e56993 100644
--- a/python/pyspark/tests/test_util.py
+++ b/python/pyspark/tests/test_util.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 import os
+import time
 import unittest
 from unittest.mock import patch
 
@@ -23,7 +24,7 @@ from py4j.protocol import Py4JJavaError
 from pyspark import keyword_only
 from pyspark.util import _parse_memory
 from pyspark.loose_version import LooseVersion
-from pyspark.testing.utils import PySparkTestCase, eventually
+from pyspark.testing.utils import PySparkTestCase, eventually, timeout
 from pyspark.find_spark_home import _find_spark_home
 
 
@@ -87,6 +88,28 @@ class UtilTests(PySparkTestCase):
         finally:
             os.environ["SPARK_HOME"] = origin
 
+    def test_timeout_decorator(self):
+        @timeout(1)
+        def timeout_func():
+            time.sleep(10)
+
+        with self.assertRaises(TimeoutError) as e:
+            timeout_func()
+        self.assertTrue("Function timeout_func timed out after 1 seconds" in str(e.exception))
+
+    def test_timeout_function(self):
+        def timeout_func():
+            time.sleep(10)
+
+        with self.assertRaises(TimeoutError) as e:
+            timeout(1)(timeout_func)()
+        self.assertTrue("Function timeout_func timed out after 1 seconds" in str(e.exception))
+
+    def test_timeout_lambda(self):
+        with self.assertRaises(TimeoutError) as e:
+            timeout(1)(lambda: time.sleep(10))()
+        self.assertTrue("Function <lambda> timed out after 1 seconds" in str(e.exception))
+
     @eventually(timeout=180, catch_assertions=True)
     def test_eventually_decorator(self):
         import random


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to