This is an automated email from the ASF dual-hosted git repository.
zhengruifeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0f13b94bb8f4 [SPARK-57065][PYTHON][TEST] Add with_sql_conf decorator
in SQLTestUtils
0f13b94bb8f4 is described below
commit 0f13b94bb8f4f1cdff6ad64d3d27c5995561cab9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed May 27 16:06:06 2026 +0800
[SPARK-57065][PYTHON][TEST] Add with_sql_conf decorator in SQLTestUtils
### What changes were proposed in this pull request?
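Adds `with_sql_conf({K: V, ...})`, a class decorator in
`pyspark.testing.sqlutils` that wraps `setUpClass` / `tearDownClass`:
the given Spark confs are set right after the inherited `setUpClass`
runs and unset right before the inherited `tearDownClass` runs. Note
that the confs are unset on teardown, not restored to any earlier value.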
Migrates three parity files as proof of usage:
- `test_parity_column.py` (2 classes)
- `test_parity_udf_combinations.py` (1 class — its conf was previously
set in `setUpClass` but never unset on teardown; the decorator now unsets it)
- `test_parity_arrow_python_udf.py` (5 classes)
Usage:
```python
from pyspark.testing.sqlutils import with_sql_conf

@with_sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "true"})
class MyParityTests(SomeMixin, ReusedConnectTestCase):
    pass
```
### Why are the changes needed?
Many parity tests override `setUpClass` / `tearDownClass` solely to
set one or two Spark confs and unset them on teardown. Each such
override is ~10 lines of the same boilerplate, and the hand-written
`set` / `unset` pairs can easily drift out of sync (one such
asymmetric pair is fixed by this PR). Consolidating them into a single
declarative decorator removes the duplication and provides a common
place to migrate the remaining call sites in a follow-up.
### Does this PR introduce _any_ user-facing change?
No — test infrastructure only.
### How was this patch tested?
- `WithSqlConfTests` in `python/pyspark/sql/tests/test_utils.py`
exercises the decorator end-to-end against `ReusedSQLTestCase`,
asserting that both keys set by the decorator are visible via
`self.spark.conf.get(...)`.
- `WithSqlConfParityTests` in
`python/pyspark/sql/tests/connect/test_parity_utils.py` is the
Connect counterpart, exercising the decorator against
`ReusedConnectTestCase`.
- All three migrated parity files were run locally against a real
session. In particular, `ColumnParityTests.test_df_col_resolution_mode`
and its `ColumnParityTestsWithNonStrictDFColResolution` subclass read
the decorator-set conf back through `self.spark.conf.get(...)` and pass,
confirming that the decorator composes correctly with layered subclassing.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (model: claude-opus-4-7)
Closes #56106 from zhengruifeng/with-class-conf-decorator.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../connect/arrow/test_parity_arrow_python_udf.py | 68 +++-------------------
.../sql/tests/connect/test_parity_column.py | 27 +--------
.../tests/connect/test_parity_udf_combinations.py | 7 +--
.../pyspark/sql/tests/connect/test_parity_utils.py | 13 +++++
python/pyspark/sql/tests/test_utils.py | 14 ++++-
python/pyspark/testing/sqlutils.py | 46 +++++++++++++++
6 files changed, 87 insertions(+), 88 deletions(-)
diff --git
a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
index 0a7a54521686..eef455ff357f 100644
--- a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
@@ -19,35 +19,16 @@ import unittest
from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests
from pyspark.sql.tests.arrow.test_arrow_python_udf import
ArrowPythonUDFTestsMixin
+from pyspark.testing.sqlutils import with_sql_conf
+@with_sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "true"})
class ArrowPythonUDFParityTests(UDFParityTests, ArrowPythonUDFTestsMixin):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled",
"true")
-
- @classmethod
- def tearDownClass(cls):
- try:
- cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
- finally:
- super().tearDownClass()
+ pass
+@with_sql_conf({"spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled":
"true"})
class ArrowPythonUDFParityLegacyTestsMixin(ArrowPythonUDFTestsMixin):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
cls.spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled",
"true")
-
- @classmethod
- def tearDownClass(cls):
- try:
-
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled")
- finally:
- super().tearDownClass()
-
@unittest.skip("Duplicate test as it is already tested in
ArrowPythonUDFLegacyTests.")
def test_udf_binary_type(self):
super().test_udf_binary_type(self)
@@ -57,21 +38,8 @@ class
ArrowPythonUDFParityLegacyTestsMixin(ArrowPythonUDFTestsMixin):
super().test_udf_binary_type_in_nested_structures(self)
+@with_sql_conf({"spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled":
"false"})
class ArrowPythonUDFParityNonLegacyTestsMixin(ArrowPythonUDFTestsMixin):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- cls.spark.conf.set(
- "spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled",
"false"
- )
-
- @classmethod
- def tearDownClass(cls):
- try:
-
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled")
- finally:
- super().tearDownClass()
-
@unittest.skip("Duplicate test as it is already tested in
ArrowPythonUDFNonLegacyTests.")
def test_udf_binary_type(self):
super().test_udf_binary_type(self)
@@ -81,32 +49,14 @@ class
ArrowPythonUDFParityNonLegacyTestsMixin(ArrowPythonUDFTestsMixin):
super().test_udf_binary_type_in_nested_structures(self)
+@with_sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "true"})
class ArrowPythonUDFParityLegacyTests(UDFParityTests,
ArrowPythonUDFParityLegacyTestsMixin):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled",
"true")
-
- @classmethod
- def tearDownClass(cls):
- try:
- cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
- finally:
- super().tearDownClass()
+ pass
+@with_sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "true"})
class ArrowPythonUDFParityNonLegacyTests(UDFParityTests,
ArrowPythonUDFParityNonLegacyTestsMixin):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled",
"true")
-
- @classmethod
- def tearDownClass(cls):
- try:
- cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
- finally:
- super().tearDownClass()
+ pass
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_parity_column.py
b/python/pyspark/sql/tests/connect/test_parity_column.py
index 3903bb57a375..a3f922127c10 100644
--- a/python/pyspark/sql/tests/connect/test_parity_column.py
+++ b/python/pyspark/sql/tests/connect/test_parity_column.py
@@ -19,21 +19,11 @@ import unittest
from pyspark.sql.tests.test_column import ColumnTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.sqlutils import with_sql_conf
+@with_sql_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": "true"})
class ColumnParityTests(ColumnTestsMixin, ReusedConnectTestCase):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
cls.spark.conf.set("spark.sql.analyzer.strictDataFrameColumnResolution", "true")
-
- @classmethod
- def tearDownClass(cls):
- try:
-
cls.spark.conf.unset("spark.sql.analyzer.strictDataFrameColumnResolution")
- finally:
- super().tearDownClass()
-
@unittest.skip("Requires JVM access.")
def test_validate_column_types(self):
super().test_validate_column_types()
@@ -45,23 +35,12 @@ class ColumnParityTests(ColumnTestsMixin,
ReusedConnectTestCase):
)
+@with_sql_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": "false"})
class ColumnParityTestsWithNonStrictDFColResolution(ColumnParityTests):
"""Re-run the Column parity tests with
`spark.sql.analyzer.strictDataFrameColumnResolution=false` to exercise the
name-based fallback path for tagged UnresolvedAttributes."""
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
cls.spark.conf.set("spark.sql.analyzer.strictDataFrameColumnResolution",
"false")
-
- @classmethod
- def tearDownClass(cls):
- try:
-
cls.spark.conf.unset("spark.sql.analyzer.strictDataFrameColumnResolution")
- finally:
- super().tearDownClass()
-
def test_df_col_resolution_mode(self):
self.assertEqual(
self.spark.conf.get("spark.sql.analyzer.strictDataFrameColumnResolution"),
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
b/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
index b1e8d473201d..ae8ebb6368f9 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
@@ -18,13 +18,12 @@
from pyspark.sql.tests.test_udf_combinations import UDFCombinationsTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.sqlutils import with_sql_conf
+@with_sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "false"})
class UDFCombinationsParityTests(UDFCombinationsTestsMixin,
ReusedConnectTestCase):
- @classmethod
- def setUpClass(cls):
- ReusedConnectTestCase.setUpClass()
- cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled",
"false")
+ pass
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_parity_utils.py
b/python/pyspark/sql/tests/connect/test_parity_utils.py
index 521b6082cf22..ed03243b8119 100644
--- a/python/pyspark/sql/tests/connect/test_parity_utils.py
+++ b/python/pyspark/sql/tests/connect/test_parity_utils.py
@@ -16,6 +16,7 @@
#
from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.sqlutils import with_sql_conf
from pyspark.sql.tests.test_utils import UtilsTestsMixin
@@ -23,6 +24,18 @@ class UtilsParityTests(UtilsTestsMixin,
ReusedConnectTestCase):
pass
+@with_sql_conf(
+ {
+ "spark.sql.test.with_sql_conf.key1": "v1",
+ "spark.sql.test.with_sql_conf.key2": "v2",
+ }
+)
+class WithSqlConfParityTests(ReusedConnectTestCase):
+ def test_confs_applied(self):
+
self.assertEqual(self.spark.conf.get("spark.sql.test.with_sql_conf.key1"), "v1")
+
self.assertEqual(self.spark.conf.get("spark.sql.test.with_sql_conf.key2"), "v2")
+
+
if __name__ == "__main__":
from pyspark.testing import main
diff --git a/python/pyspark/sql/tests/test_utils.py
b/python/pyspark/sql/tests/test_utils.py
index 3454a5f8b66c..6e5a0ee69287 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -37,7 +37,7 @@ from pyspark.testing.utils import (
have_pandas,
have_pyarrow,
)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.sqlutils import ReusedSQLTestCase, with_sql_conf
from pyspark.sql import Row
import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
@@ -1876,6 +1876,18 @@ class UtilsTests(UtilsTestsMixin, ReusedSQLTestCase):
pass
+@with_sql_conf(
+ {
+ "spark.sql.test.with_sql_conf.key1": "v1",
+ "spark.sql.test.with_sql_conf.key2": "v2",
+ }
+)
+class WithSqlConfTests(ReusedSQLTestCase):
+ def test_confs_applied(self):
+
self.assertEqual(self.spark.conf.get("spark.sql.test.with_sql_conf.key1"), "v1")
+
self.assertEqual(self.spark.conf.get("spark.sql.test.with_sql_conf.key2"), "v2")
+
+
if __name__ == "__main__":
from pyspark.testing import main
diff --git a/python/pyspark/testing/sqlutils.py
b/python/pyspark/testing/sqlutils.py
index e9d99dd3d2d6..f91cb9f85e0e 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -170,6 +170,52 @@ except Exception as e:
test_compiled = not test_not_compiled_message
+def with_sql_conf(pairs):
+ """
+ Class decorator that sets the given Spark confs in ``setUpClass`` and
unsets them in
+ ``tearDownClass``, around the calls to the inherited
``setUpClass``/``tearDownClass``.
+
+ The decorated class is expected to expose ``cls.spark`` (a
``SparkSession``) after the
+ inherited ``setUpClass`` runs — typically by extending
``ReusedSQLTestCase`` or
+ ``ReusedConnectTestCase``.
+
+ If the decorated class defines its own ``setUpClass``/``tearDownClass``,
those are
+ called instead of the inherited ones; the conf set/unset happens after
setUpClass
+ and before tearDownClass respectively.
+ """
+ assert isinstance(pairs, dict), "pairs should be a dictionary."
+
+ def decorator(cls):
+ own_setup = cls.__dict__.get("setUpClass")
+ own_teardown = cls.__dict__.get("tearDownClass")
+
+ @classmethod
+ def setUpClass(klass):
+ if own_setup is not None:
+ own_setup.__func__(klass)
+ else:
+ super(cls, klass).setUpClass()
+ for key, value in pairs.items():
+ klass.spark.conf.set(key, value)
+
+ @classmethod
+ def tearDownClass(klass):
+ try:
+ for key in pairs:
+ klass.spark.conf.unset(key)
+ finally:
+ if own_teardown is not None:
+ own_teardown.__func__(klass)
+ else:
+ super(cls, klass).tearDownClass()
+
+ cls.setUpClass = setUpClass
+ cls.tearDownClass = tearDownClass
+ return cls
+
+ return decorator
+
+
class SQLTestUtils:
"""
This util assumes the instance of this to have 'spark' attribute, having a
spark session.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]