This is an automated email from the ASF dual-hosted git repository.

zhengruifeng pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 0653f98848d4 [SPARK-57065][PYTHON][TEST] Add with_sql_conf decorator in SQLTestUtils
0653f98848d4 is described below

commit 0653f98848d43921987996761c41c68751183cf2
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed May 27 16:06:06 2026 +0800

    [SPARK-57065][PYTHON][TEST] Add with_sql_conf decorator in SQLTestUtils
    
    ### What changes were proposed in this pull request?
    
    Adds `with_sql_conf({K: V, ...})` as a class decorator in
    `pyspark.testing.sqlutils` that wraps `setUpClass` / `tearDownClass`:
    the given Spark confs are set after the wrapped `setUpClass` runs and
    unset (not restored to any prior value) before the wrapped
    `tearDownClass` runs.
    Migrates three parity files as proof of usage:
    
    - `test_parity_column.py` (2 classes)
    - `test_parity_udf_combinations.py` (1 class — it previously set the
      conf in `setUpClass` but never unset it; with the decorator the conf
      is now also unset in `tearDownClass`)
    - `test_parity_arrow_python_udf.py` (5 classes)
    
    Usage:
    
    ```python
    from pyspark.testing.sqlutils import with_sql_conf
    
    @with_sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "true"})
    class MyParityTests(SomeMixin, ReusedConnectTestCase):
        pass
    ```
    
    ### Why are the changes needed?
    
    Many parity tests override `setUpClass` / `tearDownClass` solely to
    set one or two Spark confs and unset them on teardown. Each such
    override is ~10 lines of identically shaped boilerplate, and the
    hand-written `set` / `unset` pairs easily drift out of sync (this PR
    fixes one class that set a conf but never unset it). Consolidating
    them into a single declarative decorator removes the duplication and
    gives the remaining call sites a common target to migrate to in a
    follow-up.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No — test infrastructure only.
    
    ### How was this patch tested?
    
    - `WithSqlConfTests` in `python/pyspark/sql/tests/test_utils.py`
      exercises the decorator end-to-end against `ReusedSQLTestCase`,
      asserting that both decorator-set keys are visible via
      `self.spark.conf.get(...)`.
    - `WithSqlConfParityTests` in
      `python/pyspark/sql/tests/connect/test_parity_utils.py` is the
      Connect counterpart, exercising the decorator against
      `ReusedConnectTestCase`.
    - All three migrated parity files were run locally against a real
      session. In particular, `ColumnParityTests.test_df_col_resolution_mode`
      and its subclass `ColumnParityTestsWithNonStrictDFColResolution`
      (decorated with a different value for the same key) read the
      decorator-set conf back through `self.spark.conf.get(...)` and pass,
      confirming that the decorator composes through layered subclassing.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code (model: claude-opus-4-7)
    
    Closes #56106 from zhengruifeng/with-class-conf-decorator.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 0f13b94bb8f4f1cdff6ad64d3d27c5995561cab9)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../connect/arrow/test_parity_arrow_python_udf.py  | 68 +++-------------------
 .../sql/tests/connect/test_parity_column.py        | 27 +--------
 .../tests/connect/test_parity_udf_combinations.py  |  7 +--
 .../pyspark/sql/tests/connect/test_parity_utils.py | 13 +++++
 python/pyspark/sql/tests/test_utils.py             | 14 ++++-
 python/pyspark/testing/sqlutils.py                 | 46 +++++++++++++++
 6 files changed, 87 insertions(+), 88 deletions(-)

diff --git 
a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py 
b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
index 0a7a54521686..eef455ff357f 100644
--- a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_python_udf.py
@@ -19,35 +19,16 @@ import unittest
 
 from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests
 from pyspark.sql.tests.arrow.test_arrow_python_udf import 
ArrowPythonUDFTestsMixin
+from pyspark.testing.sqlutils import with_sql_conf
 
 
+@with_sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "true"})
 class ArrowPythonUDFParityTests(UDFParityTests, ArrowPythonUDFTestsMixin):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", 
"true")
-
-    @classmethod
-    def tearDownClass(cls):
-        try:
-            cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
-        finally:
-            super().tearDownClass()
+    pass
 
 
+@with_sql_conf({"spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled":
 "true"})
 class ArrowPythonUDFParityLegacyTestsMixin(ArrowPythonUDFTestsMixin):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        
cls.spark.conf.set("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled",
 "true")
-
-    @classmethod
-    def tearDownClass(cls):
-        try:
-            
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled")
-        finally:
-            super().tearDownClass()
-
     @unittest.skip("Duplicate test as it is already tested in 
ArrowPythonUDFLegacyTests.")
     def test_udf_binary_type(self):
         super().test_udf_binary_type(self)
@@ -57,21 +38,8 @@ class 
ArrowPythonUDFParityLegacyTestsMixin(ArrowPythonUDFTestsMixin):
         super().test_udf_binary_type_in_nested_structures(self)
 
 
+@with_sql_conf({"spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled":
 "false"})
 class ArrowPythonUDFParityNonLegacyTestsMixin(ArrowPythonUDFTestsMixin):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls.spark.conf.set(
-            "spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled", 
"false"
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        try:
-            
cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled")
-        finally:
-            super().tearDownClass()
-
     @unittest.skip("Duplicate test as it is already tested in 
ArrowPythonUDFNonLegacyTests.")
     def test_udf_binary_type(self):
         super().test_udf_binary_type(self)
@@ -81,32 +49,14 @@ class 
ArrowPythonUDFParityNonLegacyTestsMixin(ArrowPythonUDFTestsMixin):
         super().test_udf_binary_type_in_nested_structures(self)
 
 
+@with_sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "true"})
 class ArrowPythonUDFParityLegacyTests(UDFParityTests, 
ArrowPythonUDFParityLegacyTestsMixin):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", 
"true")
-
-    @classmethod
-    def tearDownClass(cls):
-        try:
-            cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
-        finally:
-            super().tearDownClass()
+    pass
 
 
+@with_sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "true"})
 class ArrowPythonUDFParityNonLegacyTests(UDFParityTests, 
ArrowPythonUDFParityNonLegacyTestsMixin):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", 
"true")
-
-    @classmethod
-    def tearDownClass(cls):
-        try:
-            cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
-        finally:
-            super().tearDownClass()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_parity_column.py 
b/python/pyspark/sql/tests/connect/test_parity_column.py
index 3903bb57a375..a3f922127c10 100644
--- a/python/pyspark/sql/tests/connect/test_parity_column.py
+++ b/python/pyspark/sql/tests/connect/test_parity_column.py
@@ -19,21 +19,11 @@ import unittest
 
 from pyspark.sql.tests.test_column import ColumnTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.sqlutils import with_sql_conf
 
 
+@with_sql_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": "true"})
 class ColumnParityTests(ColumnTestsMixin, ReusedConnectTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        
cls.spark.conf.set("spark.sql.analyzer.strictDataFrameColumnResolution", "true")
-
-    @classmethod
-    def tearDownClass(cls):
-        try:
-            
cls.spark.conf.unset("spark.sql.analyzer.strictDataFrameColumnResolution")
-        finally:
-            super().tearDownClass()
-
     @unittest.skip("Requires JVM access.")
     def test_validate_column_types(self):
         super().test_validate_column_types()
@@ -45,23 +35,12 @@ class ColumnParityTests(ColumnTestsMixin, 
ReusedConnectTestCase):
         )
 
 
+@with_sql_conf({"spark.sql.analyzer.strictDataFrameColumnResolution": "false"})
 class ColumnParityTestsWithNonStrictDFColResolution(ColumnParityTests):
     """Re-run the Column parity tests with
     `spark.sql.analyzer.strictDataFrameColumnResolution=false` to exercise the
     name-based fallback path for tagged UnresolvedAttributes."""
 
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        
cls.spark.conf.set("spark.sql.analyzer.strictDataFrameColumnResolution", 
"false")
-
-    @classmethod
-    def tearDownClass(cls):
-        try:
-            
cls.spark.conf.unset("spark.sql.analyzer.strictDataFrameColumnResolution")
-        finally:
-            super().tearDownClass()
-
     def test_df_col_resolution_mode(self):
         self.assertEqual(
             
self.spark.conf.get("spark.sql.analyzer.strictDataFrameColumnResolution"),
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py 
b/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
index b1e8d473201d..ae8ebb6368f9 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
@@ -18,13 +18,12 @@
 
 from pyspark.sql.tests.test_udf_combinations import UDFCombinationsTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.sqlutils import with_sql_conf
 
 
+@with_sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": "false"})
 class UDFCombinationsParityTests(UDFCombinationsTestsMixin, 
ReusedConnectTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedConnectTestCase.setUpClass()
-        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", 
"false")
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_parity_utils.py 
b/python/pyspark/sql/tests/connect/test_parity_utils.py
index 521b6082cf22..ed03243b8119 100644
--- a/python/pyspark/sql/tests/connect/test_parity_utils.py
+++ b/python/pyspark/sql/tests/connect/test_parity_utils.py
@@ -16,6 +16,7 @@
 #
 
 from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.sqlutils import with_sql_conf
 from pyspark.sql.tests.test_utils import UtilsTestsMixin
 
 
@@ -23,6 +24,18 @@ class UtilsParityTests(UtilsTestsMixin, 
ReusedConnectTestCase):
     pass
 
 
+@with_sql_conf(
+    {
+        "spark.sql.test.with_sql_conf.key1": "v1",
+        "spark.sql.test.with_sql_conf.key2": "v2",
+    }
+)
+class WithSqlConfParityTests(ReusedConnectTestCase):
+    def test_confs_applied(self):
+        
self.assertEqual(self.spark.conf.get("spark.sql.test.with_sql_conf.key1"), "v1")
+        
self.assertEqual(self.spark.conf.get("spark.sql.test.with_sql_conf.key2"), "v2")
+
+
 if __name__ == "__main__":
     from pyspark.testing import main
 
diff --git a/python/pyspark/sql/tests/test_utils.py 
b/python/pyspark/sql/tests/test_utils.py
index 3454a5f8b66c..6e5a0ee69287 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -37,7 +37,7 @@ from pyspark.testing.utils import (
     have_pandas,
     have_pyarrow,
 )
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.sqlutils import ReusedSQLTestCase, with_sql_conf
 from pyspark.sql import Row
 import pyspark.sql.functions as F
 from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
@@ -1876,6 +1876,18 @@ class UtilsTests(UtilsTestsMixin, ReusedSQLTestCase):
     pass
 
 
+@with_sql_conf(
+    {
+        "spark.sql.test.with_sql_conf.key1": "v1",
+        "spark.sql.test.with_sql_conf.key2": "v2",
+    }
+)
+class WithSqlConfTests(ReusedSQLTestCase):
+    def test_confs_applied(self):
+        
self.assertEqual(self.spark.conf.get("spark.sql.test.with_sql_conf.key1"), "v1")
+        
self.assertEqual(self.spark.conf.get("spark.sql.test.with_sql_conf.key2"), "v2")
+
+
 if __name__ == "__main__":
     from pyspark.testing import main
 
diff --git a/python/pyspark/testing/sqlutils.py 
b/python/pyspark/testing/sqlutils.py
index e9d99dd3d2d6..f91cb9f85e0e 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -170,6 +170,52 @@ except Exception as e:
 test_compiled = not test_not_compiled_message
 
 
+def with_sql_conf(pairs):
+    """
+    Class decorator that sets the given Spark confs in ``setUpClass`` and 
unsets them in
+    ``tearDownClass``, around the calls to the inherited 
``setUpClass``/``tearDownClass``.
+
+    The decorated class is expected to expose ``cls.spark`` (a 
``SparkSession``) after the
+    inherited ``setUpClass`` runs — typically by extending 
``ReusedSQLTestCase`` or
+    ``ReusedConnectTestCase``.
+
+    If the decorated class defines its own ``setUpClass``/``tearDownClass``, 
those are
+    called instead of the inherited ones; the conf set/unset happens after 
setUpClass
+    and before tearDownClass respectively.
+    """
+    assert isinstance(pairs, dict), "pairs should be a dictionary."
+
+    def decorator(cls):
+        own_setup = cls.__dict__.get("setUpClass")
+        own_teardown = cls.__dict__.get("tearDownClass")
+
+        @classmethod
+        def setUpClass(klass):
+            if own_setup is not None:
+                own_setup.__func__(klass)
+            else:
+                super(cls, klass).setUpClass()
+            for key, value in pairs.items():
+                klass.spark.conf.set(key, value)
+
+        @classmethod
+        def tearDownClass(klass):
+            try:
+                for key in pairs:
+                    klass.spark.conf.unset(key)
+            finally:
+                if own_teardown is not None:
+                    own_teardown.__func__(klass)
+                else:
+                    super(cls, klass).tearDownClass()
+
+        cls.setUpClass = setUpClass
+        cls.tearDownClass = tearDownClass
+        return cls
+
+    return decorator
+
+
 class SQLTestUtils:
     """
     This util assumes the instance of this to have 'spark' attribute, having a 
spark session.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to