This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1c3eb7e9e0c7 [SPARK-51578][PYTHON][TESTS] Add a helper function to
fail time outed tests
1c3eb7e9e0c7 is described below
commit 1c3eb7e9e0c723ecd3ca66a676a9ea9cd95b0605
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Mar 24 11:08:36 2025 +0900
[SPARK-51578][PYTHON][TESTS] Add a helper function to fail time outed tests
### What changes were proposed in this pull request?
Add a helper function, `timeout`, that fails a test when it exceeds a time limit.
### Why are the changes needed?
To detect which test causes a timeout.
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #50337 from zhengruifeng/py_test_timeout.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/tests/test_udf.py | 17 ++++++++++++++++-
python/pyspark/testing/utils.py | 21 +++++++++++++++++++++
python/pyspark/tests/test_util.py | 25 ++++++++++++++++++++++++-
3 files changed, 61 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index efb6ff159ee5..9eda458b8508 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -50,7 +50,7 @@ from pyspark.testing.sqlutils import (
test_compiled,
test_not_compiled_message,
)
-from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.testing.utils import assertDataFrameEqual, timeout
class BaseUDFTestsMixin(object):
@@ -1259,6 +1259,21 @@ class BaseUDFTestsMixin(object):
messageParameters={"arg_name": "evalType", "arg_type": "str"},
)
def test_timeout_util_with_udf(self):
    """Check that ``timeout`` interrupts a query whose Python UDF outlasts the limit."""

    @udf
    def f(x):
        # Sleep for 10 seconds, far longer than the 1-second limit applied below.
        time.sleep(10)
        return str(x)

    @timeout(1)
    def timeout_func():
        # One-row query that has to go through the slow UDF.
        self.spark.range(1).select(f("id")).show()

    # causing a py4j.protocol.Py4JNetworkError in pyspark classic
    # causing a TimeoutError in pyspark connect
    # The exception type differs by backend, hence the broad assertRaises(Exception).
    with self.assertRaises(Exception):
        timeout_func()
+
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 780f0f90ac62..f1639d5fdb38 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -23,6 +23,7 @@ import difflib
import functools
from decimal import Decimal
from time import time, sleep
+import signal
from typing import (
Any,
Optional,
@@ -122,6 +123,26 @@ def write_int(i):
return struct.pack("!i", i)
def timeout(seconds):
    """
    Build a decorator that fails the wrapped callable if it runs too long.

    The wrapped call is guarded by ``signal.alarm``: when it has not returned
    within ``seconds`` seconds, the ``SIGALRM`` handler raises
    :class:`TimeoutError` inside the running call.

    Parameters
    ----------
    seconds : int
        Time limit in whole seconds, passed to ``signal.alarm``.

    Returns
    -------
    callable
        A decorator. The decorated function keeps the name and docstring of
        the original, returns whatever the original returns, and propagates
        any exception the original raises.

    Raises
    ------
    TimeoutError
        Raised from the decorated function when the time limit expires. The
        message is ``Function <name> timed out after <seconds> seconds``.

    Notes
    -----
    This relies on ``signal.SIGALRM``, so it is only usable on Unix and only
    from the main thread. Any alarm already pending when the decorated
    function is called is cancelled and is not re-armed afterwards.
    """

    def decorator(func):
        def handler(signum, frame):
            raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds")

        # Keep func's metadata so test runners report the real function name.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Cancel any pending alarm so it cannot fire into our handler.
            signal.alarm(0)
            previous = signal.signal(signal.SIGALRM, handler)
            try:
                signal.alarm(seconds)
                return func(*args, **kwargs)
            finally:
                # Always disarm the alarm, then put back the handler that was
                # installed before the call so it does not leak into later code.
                signal.alarm(0)
                # signal.signal returns None when the old handler was not
                # installed from Python; it cannot be passed back in.
                if previous is not None:
                    signal.signal(signal.SIGALRM, previous)

        return wrapper

    return decorator
+
+
def eventually(
timeout=30.0,
catch_assertions=False,
diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py
index e1079ca7b4e8..d9bda1e56993 100644
--- a/python/pyspark/tests/test_util.py
+++ b/python/pyspark/tests/test_util.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
import os
+import time
import unittest
from unittest.mock import patch
@@ -23,7 +24,7 @@ from py4j.protocol import Py4JJavaError
from pyspark import keyword_only
from pyspark.util import _parse_memory
from pyspark.loose_version import LooseVersion
-from pyspark.testing.utils import PySparkTestCase, eventually
+from pyspark.testing.utils import PySparkTestCase, eventually, timeout
from pyspark.find_spark_home import _find_spark_home
@@ -87,6 +88,28 @@ class UtilTests(PySparkTestCase):
finally:
os.environ["SPARK_HOME"] = origin
def test_timeout_decorator(self):
    """``timeout`` applied with decorator syntax raises ``TimeoutError`` naming the function."""

    @timeout(1)
    def timeout_func():
        # Sleep far longer than the 1-second limit so the alarm fires first.
        time.sleep(10)

    with self.assertRaises(TimeoutError) as e:
        timeout_func()
    # assertIn shows both strings on failure, unlike assertTrue(a in b).
    self.assertIn("Function timeout_func timed out after 1 seconds", str(e.exception))
+
def test_timeout_function(self):
    """``timeout`` applied by an explicit call raises ``TimeoutError`` naming the function."""

    def timeout_func():
        # Sleep far longer than the 1-second limit so the alarm fires first.
        time.sleep(10)

    with self.assertRaises(TimeoutError) as e:
        timeout(1)(timeout_func)()
    # assertIn shows both strings on failure, unlike assertTrue(a in b).
    self.assertIn("Function timeout_func timed out after 1 seconds", str(e.exception))
+
def test_timeout_lambda(self):
    """``timeout`` wrapping a lambda reports the name ``<lambda>`` in the error message."""
    with self.assertRaises(TimeoutError) as e:
        timeout(1)(lambda: time.sleep(10))()
    # assertIn shows both strings on failure, unlike assertTrue(a in b).
    self.assertIn("Function <lambda> timed out after 1 seconds", str(e.exception))
+
@eventually(timeout=180, catch_assertions=True)
def test_eventually_decorator(self):
import random
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]