This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 481f9866f5f5 [SPARK-55303][PYTHON][TESTS] Extract GoldenFileTestMixin for type coercion golden file tests
481f9866f5f5 is described below

commit 481f9866f5f5a41ad08e72f2a3820b0620927580
Author: Yicong-Huang <[email protected]>
AuthorDate: Thu Feb 5 11:06:44 2026 +0800

    [SPARK-55303][PYTHON][TESTS] Extract GoldenFileTestMixin for type coercion golden file tests
    
    ### What changes were proposed in this pull request?
    
    Extract common golden file testing utilities into `GoldenFileTestMixin` in 
`python/pyspark/testing/goldenutils.py`, and simplify the four type coercion 
test files to use this mixin.
    
    ### Why are the changes needed?
    
    Reduce duplicated code across four test files and provide a reusable 
framework for future golden file tests.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Regenerated all golden files with `SPARK_GENERATE_GOLDEN_FILES=1` (the regenerated files are unchanged, so none appear in this diff) and verified tests pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54084 from 
Yicong-Huang/SPARK-55303/refactor/extract-golden-file-test-util.
    
    Authored-by: Yicong-Huang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../tests/coercion/test_pandas_udf_input_type.py   |  62 +----
 .../tests/coercion/test_pandas_udf_return_type.py  |  75 +-----
 .../tests/coercion/test_python_udf_input_type.py   |  62 +----
 .../tests/coercion/test_python_udf_return_type.py  |  64 +-----
 python/pyspark/testing/goldenutils.py              | 254 +++++++++++++++++++++
 5 files changed, 297 insertions(+), 220 deletions(-)

diff --git a/python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py 
b/python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py
index cd3880e6c9dd..64377f2df698 100644
--- a/python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py
+++ b/python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py
@@ -18,7 +18,6 @@
 from decimal import Decimal
 import datetime
 import os
-import time
 import unittest
 
 from pyspark.sql.functions import pandas_udf
@@ -51,6 +50,7 @@ from pyspark.testing.utils import (
     numpy_requirement_message,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.goldenutils import GoldenFileTestMixin
 
 if have_numpy:
     import numpy as np
@@ -73,29 +73,7 @@ if have_pandas:
     or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
     pandas_requirement_message or pyarrow_requirement_message or 
numpy_requirement_message,
 )
-class PandasUDFInputTypeTests(ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-
-        # Synchronize default timezone between Python and Java
-        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
-        tz = "America/Los_Angeles"
-        os.environ["TZ"] = tz
-        time.tzset()
-
-        cls.sc.environment["TZ"] = tz
-        cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
-    @classmethod
-    def tearDownClass(cls):
-        del os.environ["TZ"]
-        if cls.tz_prev is not None:
-            os.environ["TZ"] = cls.tz_prev
-        time.tzset()
-
-        super().tearDownClass()
-
+class PandasUDFInputTypeTests(GoldenFileTestMixin, ReusedSQLTestCase):
     @property
     def prefix(self):
         return "golden_pandas_udf_input_type_coercion"
@@ -265,27 +243,20 @@ class PandasUDFInputTypeTests(ReusedSQLTestCase):
         self._compare_or_generate_golden(golden_file, test_name)
 
     def _compare_or_generate_golden(self, golden_file, test_name):
-        testing = os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "?") != "1"
+        generating = self.is_generating_golden()
 
         golden_csv = os.path.join(os.path.dirname(__file__), 
f"{golden_file}.csv")
         golden_md = os.path.join(os.path.dirname(__file__), 
f"{golden_file}.md")
 
         golden = None
-        if testing:
-            golden = pd.read_csv(
-                golden_csv,
-                sep="\t",
-                index_col=0,
-                dtype="str",
-                na_filter=False,
-                engine="python",
-            )
+        if not generating:
+            golden = self.load_golden_csv(golden_csv)
 
         results = []
-        for idx, (test_name, spark_type, data_func) in 
enumerate(self.test_cases):
+        for idx, (case_name, spark_type, data_func) in 
enumerate(self.test_cases):
             input_df = data_func(spark_type).repartition(1)
             input_data = [row["value"] for row in input_df.collect()]
-            result = [test_name, spark_type.simpleString(), str(input_data)]
+            result = [case_name, self.repr_type(spark_type), str(input_data)]
 
             try:
 
@@ -319,15 +290,15 @@ class PandasUDFInputTypeTests(ReusedSQLTestCase):
                 result.append(f"✗ {str(e)}")
 
             # Clean up exception message to remove newlines and extra 
whitespace
-            result = [r.replace("\n", " ").replace("\r", " ").replace("\t", " 
") for r in result]
+            result = [self.clean_result(r) for r in result]
 
             error_msg = None
-            if testing and result != list(golden.iloc[idx]):
+            if not generating and result != list(golden.iloc[idx]):
                 error_msg = f"line mismatch: expects {list(golden.iloc[idx])} 
but got {result}"
 
             results.append((result, error_msg))
 
-        if testing:
+        if not generating:
             errs = []
             for _, err in results:
                 if err is not None:
@@ -340,18 +311,7 @@ class PandasUDFInputTypeTests(ReusedSQLTestCase):
                 columns=["Test Case", "Spark Type", "Spark Value", "Python 
Type", "Python Value"],
             )
 
-            # generating the CSV file as the golden file
-            new_golden.to_csv(golden_csv, sep="\t", header=True, index=True)
-
-            try:
-                # generating the GitHub flavored Markdown file
-                # package tabulate is required
-                new_golden.to_markdown(golden_md, index=True, 
tablefmt="github")
-            except Exception as e:
-                print(
-                    f"{test_name} return type coercion: "
-                    f"fail to write the markdown file due to {e}!"
-                )
+            self.save_golden(new_golden, golden_csv, golden_md)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py 
b/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py
index f0c81adb1010..71fd7b14daa4 100644
--- a/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py
+++ b/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py
@@ -19,7 +19,6 @@ import concurrent.futures
 from decimal import Decimal
 import itertools
 import os
-import time
 import unittest
 
 from pyspark.sql.functions import pandas_udf
@@ -51,6 +50,7 @@ from pyspark.testing.utils import (
     numpy_requirement_message,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.goldenutils import GoldenFileTestMixin
 
 if have_numpy:
     import numpy as np
@@ -73,36 +73,14 @@ if have_pandas:
     or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
     pandas_requirement_message or pyarrow_requirement_message or 
numpy_requirement_message,
 )
-class PandasUDFReturnTypeTests(ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-
-        # Synchronize default timezone between Python and Java
-        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
-        tz = "America/Los_Angeles"
-        os.environ["TZ"] = tz
-        time.tzset()
-
-        cls.sc.environment["TZ"] = tz
-        cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
-    @classmethod
-    def tearDownClass(cls):
-        del os.environ["TZ"]
-        if cls.tz_prev is not None:
-            os.environ["TZ"] = cls.tz_prev
-        time.tzset()
-
-        super().tearDownClass()
-
+class PandasUDFReturnTypeTests(GoldenFileTestMixin, ReusedSQLTestCase):
     @property
     def prefix(self):
         return "golden_pandas_udf_return_type_coercion"
 
     @property
     def test_data(self):
-        data = [
+        return [
             [None, None],
             [True, False],
             list("ab"),
@@ -131,7 +109,6 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
             pd.Categorical(["A", "B"]),
             pd.DataFrame({"_1": [1, 2]}),
         ]
-        return data
 
     @property
     def test_types(self):
@@ -153,19 +130,9 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
             StructType([StructField("_1", IntegerType())]),
         ]
 
-    def repr_type(self, spark_type):
-        return spark_type.simpleString()
-
     def repr_value(self, value):
-        v_str = value.to_json() if isinstance(value, pd.DataFrame) else 
str(value)
-        v_str = v_str.replace(chr(10), " ")
-        v_str = v_str[:32]
-        if isinstance(value, np.ndarray):
-            return f"{v_str}@ndarray[{value.dtype.name}]"
-        elif isinstance(value, pd.DataFrame):
-            simple_schema = ", ".join([f"{t} {d.name}" for t, d in 
value.dtypes.items()])
-            return f"{v_str}@Dataframe[{simple_schema}]"
-        return f"{v_str}@{type(value).__name__}"
+        # Use extended pandas value representation
+        return self.repr_pandas_value(value)
 
     def test_str_repr(self):
         self.assertEqual(
@@ -189,21 +156,14 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
         self._compare_or_generate_golden(golden_file, test_name)
 
     def _compare_or_generate_golden(self, golden_file, test_name):
-        testing = os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "?") != "1"
+        generating = self.is_generating_golden()
 
         golden_csv = os.path.join(os.path.dirname(__file__), 
f"{golden_file}.csv")
         golden_md = os.path.join(os.path.dirname(__file__), 
f"{golden_file}.md")
 
         golden = None
-        if testing:
-            golden = pd.read_csv(
-                golden_csv,
-                sep="\t",
-                index_col=0,
-                dtype="str",
-                na_filter=False,
-                engine="python",
-            )
+        if not generating:
+            golden = self.load_golden_csv(golden_csv)
 
         def work(arg):
             spark_type, value = arg
@@ -231,10 +191,10 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
                 result = "X"
 
             # Clean up exception message to remove newlines and extra 
whitespace
-            result = result.replace("\n", " ").replace("\r", " 
").replace("\t", " ")
+            result = self.clean_result(result)
 
             err = None
-            if testing:
+            if not generating:
                 expected = golden.loc[str_t, str_v]
                 if expected != result:
                     err = f"{str_v} => {spark_type} expects {expected} but got 
{result}"
@@ -250,7 +210,7 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
                 )
             )
 
-        if testing:
+        if not generating:
             errs = []
             for _, _, _, err in results:
                 if err is not None:
@@ -270,18 +230,7 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
             for str_t, str_v, res, _ in results:
                 new_golden.loc[str_t, str_v] = res
 
-            # generating the CSV file as the golden file
-            new_golden.to_csv(golden_csv, sep="\t", header=True, index=True)
-
-            try:
-                # generating the GitHub flavored Markdown file
-                # package tabulate is required
-                new_golden.to_markdown(golden_md, index=True, 
tablefmt="github")
-            except Exception as e:
-                print(
-                    f"{test_name} return type coercion: "
-                    f"fail to write the markdown file due to {e}!"
-                )
+            self.save_golden(new_golden, golden_csv, golden_md)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/coercion/test_python_udf_input_type.py 
b/python/pyspark/sql/tests/coercion/test_python_udf_input_type.py
index 6897647d6bea..f0afb84b4361 100644
--- a/python/pyspark/sql/tests/coercion/test_python_udf_input_type.py
+++ b/python/pyspark/sql/tests/coercion/test_python_udf_input_type.py
@@ -18,7 +18,6 @@
 from decimal import Decimal
 import datetime
 import os
-import time
 import unittest
 
 from pyspark.sql.functions import udf
@@ -51,6 +50,7 @@ from pyspark.testing.utils import (
     numpy_requirement_message,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.goldenutils import GoldenFileTestMixin
 
 if have_numpy:
     import numpy as np
@@ -73,29 +73,7 @@ if have_pandas:
     or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
     pandas_requirement_message or pyarrow_requirement_message or 
numpy_requirement_message,
 )
-class UDFInputTypeTests(ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-
-        # Synchronize default timezone between Python and Java
-        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
-        tz = "America/Los_Angeles"
-        os.environ["TZ"] = tz
-        time.tzset()
-
-        cls.sc.environment["TZ"] = tz
-        cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
-    @classmethod
-    def tearDownClass(cls):
-        del os.environ["TZ"]
-        if cls.tz_prev is not None:
-            os.environ["TZ"] = cls.tz_prev
-        time.tzset()
-
-        super().tearDownClass()
-
+class UDFInputTypeTests(GoldenFileTestMixin, ReusedSQLTestCase):
     @property
     def prefix(self):
         return "golden_python_udf_input_type_coercion"
@@ -289,27 +267,20 @@ class UDFInputTypeTests(ReusedSQLTestCase):
             self._compare_or_generate_golden(golden_file, test_name)
 
     def _compare_or_generate_golden(self, golden_file, test_name):
-        testing = os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "?") != "1"
+        generating = self.is_generating_golden()
 
         golden_csv = os.path.join(os.path.dirname(__file__), 
f"{golden_file}.csv")
         golden_md = os.path.join(os.path.dirname(__file__), 
f"{golden_file}.md")
 
         golden = None
-        if testing:
-            golden = pd.read_csv(
-                golden_csv,
-                sep="\t",
-                index_col=0,
-                dtype="str",
-                na_filter=False,
-                engine="python",
-            )
+        if not generating:
+            golden = self.load_golden_csv(golden_csv)
 
         results = []
-        for idx, (test_name, spark_type, data_func) in 
enumerate(self.test_cases):
+        for idx, (case_name, spark_type, data_func) in 
enumerate(self.test_cases):
             input_df = data_func(spark_type).repartition(1)
             input_data = [row["value"] for row in input_df.collect()]
-            result = [test_name, spark_type.simpleString(), str(input_data)]
+            result = [case_name, self.repr_type(spark_type), str(input_data)]
 
             try:
 
@@ -350,15 +321,15 @@ class UDFInputTypeTests(ReusedSQLTestCase):
                 result.append(f"✗ {str(e)}")
 
             # Clean up exception message to remove newlines and extra 
whitespace
-            result = [r.replace("\n", " ").replace("\r", " ").replace("\t", " 
") for r in result]
+            result = [self.clean_result(r) for r in result]
 
             error_msg = None
-            if testing and result != list(golden.iloc[idx]):
+            if not generating and result != list(golden.iloc[idx]):
                 error_msg = f"line mismatch: expects {list(golden.iloc[idx])} 
but got {result}"
 
             results.append((result, error_msg))
 
-        if testing:
+        if not generating:
             errs = []
             for _, err in results:
                 if err is not None:
@@ -371,18 +342,7 @@ class UDFInputTypeTests(ReusedSQLTestCase):
                 columns=["Test Case", "Spark Type", "Spark Value", "Python 
Type", "Python Value"],
             )
 
-            # generating the CSV file as the golden file
-            new_golden.to_csv(golden_csv, sep="\t", header=True, index=True)
-
-            try:
-                # generating the GitHub flavored Markdown file
-                # package tabulate is required
-                new_golden.to_markdown(golden_md, index=True, 
tablefmt="github")
-            except Exception as e:
-                print(
-                    f"{test_name} return type coercion: "
-                    f"fail to write the markdown file due to {e}!"
-                )
+            self.save_golden(new_golden, golden_csv, golden_md)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/coercion/test_python_udf_return_type.py 
b/python/pyspark/sql/tests/coercion/test_python_udf_return_type.py
index 08ba2f809f50..e3b9939fa51f 100644
--- a/python/pyspark/sql/tests/coercion/test_python_udf_return_type.py
+++ b/python/pyspark/sql/tests/coercion/test_python_udf_return_type.py
@@ -22,7 +22,6 @@ from decimal import Decimal
 import itertools
 import os
 import re
-import time
 import unittest
 
 from pyspark.sql import Row
@@ -55,6 +54,7 @@ from pyspark.testing.utils import (
     numpy_requirement_message,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.goldenutils import GoldenFileTestMixin
 
 if have_numpy:
     import numpy as np
@@ -77,29 +77,7 @@ if have_pandas:
     or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
     pandas_requirement_message or pyarrow_requirement_message or 
numpy_requirement_message,
 )
-class UDFReturnTypeTests(ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-
-        # Synchronize default timezone between Python and Java
-        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
-        tz = "America/Los_Angeles"
-        os.environ["TZ"] = tz
-        time.tzset()
-
-        cls.sc.environment["TZ"] = tz
-        cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
-    @classmethod
-    def tearDownClass(cls):
-        del os.environ["TZ"]
-        if cls.tz_prev is not None:
-            os.environ["TZ"] = cls.tz_prev
-        time.tzset()
-
-        super().tearDownClass()
-
+class UDFReturnTypeTests(GoldenFileTestMixin, ReusedSQLTestCase):
     @property
     def prefix(self):
         return "golden_python_udf_return_type_coercion"
@@ -144,12 +122,6 @@ class UDFReturnTypeTests(ReusedSQLTestCase):
             StructType([StructField("_1", IntegerType())]),
         ]
 
-    def repr_type(self, spark_type):
-        return spark_type.simpleString()
-
-    def repr_value(self, value):
-        return f"{str(value)}@{type(value).__name__}"
-
     def test_str_repr(self):
         self.assertEqual(
             len(self.test_types),
@@ -196,21 +168,14 @@ class UDFReturnTypeTests(ReusedSQLTestCase):
             self._compare_or_generate_golden(golden_file, test_name)
 
     def _compare_or_generate_golden(self, golden_file, test_name):
-        testing = os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "?") != "1"
+        generating = self.is_generating_golden()
 
         golden_csv = os.path.join(os.path.dirname(__file__), 
f"{golden_file}.csv")
         golden_md = os.path.join(os.path.dirname(__file__), 
f"{golden_file}.md")
 
         golden = None
-        if testing:
-            golden = pd.read_csv(
-                golden_csv,
-                sep="\t",
-                index_col=0,
-                dtype="str",
-                na_filter=False,
-                engine="python",
-            )
+        if not generating:
+            golden = self.load_golden_csv(golden_csv)
 
         def work(arg):
             spark_type, value = arg
@@ -228,10 +193,10 @@ class UDFReturnTypeTests(ReusedSQLTestCase):
                 result = "X"
 
             # Clean up exception message to remove newlines and extra 
whitespace
-            result = result.replace("\n", " ").replace("\r", " 
").replace("\t", " ")
+            result = self.clean_result(result)
 
             err = None
-            if testing:
+            if not generating:
                 expected = golden.loc[str_t, str_v]
                 if expected != result:
                     err = f"{str_v} => {spark_type} expects {expected} but got 
{result}"
@@ -247,7 +212,7 @@ class UDFReturnTypeTests(ReusedSQLTestCase):
                 )
             )
 
-        if testing:
+        if not generating:
             errs = []
             for _, _, _, err in results:
                 if err is not None:
@@ -267,18 +232,7 @@ class UDFReturnTypeTests(ReusedSQLTestCase):
             for str_t, str_v, res, _ in results:
                 new_golden.loc[str_t, str_v] = res
 
-            # generating the CSV file as the golden file
-            new_golden.to_csv(golden_csv, sep="\t", header=True, index=True)
-
-            try:
-                # generating the GitHub flavored Markdown file
-                # package tabulate is required
-                new_golden.to_markdown(golden_md, index=True, 
tablefmt="github")
-            except Exception as e:
-                print(
-                    f"{test_name} return type coercion: "
-                    f"fail to write the markdown file due to {e}!"
-                )
+            self.save_golden(new_golden, golden_csv, golden_md)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/testing/goldenutils.py b/python/pyspark/testing/goldenutils.py
new file mode 100644
index 000000000000..ecb253689e97
--- /dev/null
+++ b/python/pyspark/testing/goldenutils.py
@@ -0,0 +1,282 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, Optional
+import os
+import time
+
+# pandas and numpy are optional. Test modules that use this mixin import them
+# conditionally and skip themselves when they are missing, so importing this
+# module must not fail without them.
+try:
+    import pandas as pd
+
+    have_pandas = True
+except ImportError:
+    have_pandas = False
+
+try:
+    import numpy as np
+
+    have_numpy = True
+except ImportError:
+    have_numpy = False
+
+
+class GoldenFileTestMixin:
+    """
+    Mixin class providing utilities for golden file based testing.
+
+    Golden files are CSV files that store expected test results. This mixin provides:
+    - Timezone setup/teardown for deterministic results
+    - Golden file read/write with SPARK_GENERATE_GOLDEN_FILES env var support
+    - Result string cleaning utilities
+
+    To regenerate golden files, set SPARK_GENERATE_GOLDEN_FILES=1 before running tests.
+
+    Usage:
+        class MyTest(GoldenFileTestMixin, ReusedSQLTestCase):
+            def test_something(self):
+                # Use helper methods from mixin
+                if self.is_generating_golden():
+                    self.save_golden(df, golden_csv, golden_md)
+                else:
+                    golden = self.load_golden_csv(golden_csv)
+                    # compare results with golden
+    """
+
+    # Previous TZ env var value saved by setup_timezone(); None if TZ was not set.
+    _tz_prev: Optional[str] = None
+
+    def __init_subclass__(cls, **kwargs: Any) -> None:
+        """
+        Verify correct inheritance order at class definition time.
+
+        ``setUpClass`` calls ``super().setUpClass()`` before ``setup_timezone()``, which
+        reads ``cls.sc`` and ``cls.spark``, so the mixin must precede the class that
+        provides them (e.g. ReusedSQLTestCase) in the MRO.
+
+        Correct:   class MyTest(GoldenFileTestMixin, ReusedSQLTestCase)
+        Incorrect: class MyTest(ReusedSQLTestCase, GoldenFileTestMixin)
+
+        Classes deriving from the mixin (``cls`` itself, or an intermediate test class
+        such as a correctly ordered base test suite) may define ``setUpClass``; they are
+        expected to chain to the mixin through ``super().setUpClass()``.
+
+        Raises
+        ------
+        TypeError
+            If a class that does not derive from GoldenFileTestMixin defines
+            ``setUpClass`` and precedes GoldenFileTestMixin in the MRO.
+        """
+        super().__init_subclass__(**kwargs)
+        for base in cls.__mro__:
+            if base is GoldenFileTestMixin:
+                break
+            # Classes deriving from the mixin are part of the correctly ordered chain.
+            # Skipping them avoids falsely rejecting a subclass of a valid golden test
+            # class whose intermediate base overrides setUpClass.
+            if issubclass(base, GoldenFileTestMixin):
+                continue
+            if "setUpClass" in base.__dict__:
+                raise TypeError(
+                    f"{cls.__name__} has incorrect inheritance order. "
+                    f"GoldenFileTestMixin must be listed BEFORE {base.__name__}. "
+                    f"Use: class {cls.__name__}(GoldenFileTestMixin, {base.__name__}, ...)"
+                )
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        """Setup test class with timezone configuration."""
+        super().setUpClass()
+        cls.setup_timezone()
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        """Teardown test class and restore timezone."""
+        cls.teardown_timezone()
+        super().tearDownClass()
+
+    @classmethod
+    def setup_timezone(cls, tz: str = "America/Los_Angeles") -> None:
+        """
+        Setup timezone for deterministic test results.
+        Synchronizes timezone between Python and Java.
+        """
+        cls._tz_prev = os.environ.get("TZ", None)
+        os.environ["TZ"] = tz
+        time.tzset()
+
+        cls.sc.environment["TZ"] = tz
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+    @classmethod
+    def teardown_timezone(cls) -> None:
+        """Restore original timezone."""
+        if "TZ" in os.environ:
+            del os.environ["TZ"]
+        if cls._tz_prev is not None:
+            os.environ["TZ"] = cls._tz_prev
+        time.tzset()
+
+    @staticmethod
+    def is_generating_golden() -> bool:
+        """Check if we are generating golden files (vs testing against them)."""
+        return os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "0") == "1"
+
+    @staticmethod
+    def load_golden_csv(golden_csv: str, use_index: bool = True) -> "pd.DataFrame":
+        """
+        Load golden file from CSV.
+
+        Parameters
+        ----------
+        golden_csv : str
+            Path to the golden CSV file.
+        use_index : bool
+            If True, use first column as index.
+            If False, don't use index.
+
+        Returns
+        -------
+        pd.DataFrame
+            The loaded golden data with string dtype.
+        """
+        return pd.read_csv(
+            golden_csv,
+            sep="\t",
+            index_col=0 if use_index else None,
+            dtype="str",
+            na_filter=False,
+            engine="python",
+        )
+
+    @staticmethod
+    def save_golden(df: "pd.DataFrame", golden_csv: str, golden_md: Optional[str] = None) -> None:
+        """
+        Save DataFrame as golden file (CSV and optionally Markdown).
+
+        Parameters
+        ----------
+        df : pd.DataFrame
+            The DataFrame to save.
+        golden_csv : str
+            Path to save the CSV file.
+        golden_md : str, optional
+            Path to save the Markdown file. Requires tabulate package.
+        """
+        df.to_csv(golden_csv, sep="\t", header=True, index=True)
+
+        if golden_md is not None:
+            try:
+                df.to_markdown(golden_md, index=True, tablefmt="github")
+            except Exception as e:
+                import warnings
+
+                warnings.warn(
+                    f"Failed to write markdown file {golden_md}: {e}. "
+                    "Install 'tabulate' package to generate markdown files."
+                )
+
+    @staticmethod
+    def repr_type(t: Any) -> str:
+        """
+        Convert a type to string representation.
+
+        Handles different type representations:
+        - Python type (any class): uses __name__ (e.g., "int", "str", "list")
+        - Spark DataType instance: uses simpleString() (e.g., "int", "string", "array<int>")
+        - Other: uses str()
+
+        Parameters
+        ----------
+        t : Any
+            The type to represent. Can be Spark DataType or Python type.
+
+        Returns
+        -------
+        str
+            String representation of the type.
+        """
+        # Classes are checked first: a Spark DataType *class* also has a ``simpleString``
+        # attribute, but it is an instance method that cannot be called on the class.
+        if isinstance(t, type):
+            return t.__name__
+        # Spark DataType instance (has simpleString method)
+        elif hasattr(t, "simpleString"):
+            return t.simpleString()
+        else:
+            return str(t)
+
+    @classmethod
+    def repr_value(cls, value: Any, max_len: int = 32) -> str:
+        """
+        Convert Python value to string representation for golden file.
+
+        Default format: "value_str@type_name"
+        Subclasses can override this method for custom representations.
+
+        Parameters
+        ----------
+        value : Any
+            The Python value to represent.
+        max_len : int, default 32
+            Maximum length for the value string portion.
+
+        Returns
+        -------
+        str
+            String representation in format "value@type".
+        """
+        v_str = str(value)[:max_len]
+        return f"{v_str}@{type(value).__name__}"
+
+    @classmethod
+    def repr_pandas_value(cls, value: Any, max_len: int = 32) -> str:
+        """
+        Convert Python/Pandas value to string representation for golden file.
+
+        Extended version that handles pandas DataFrame and numpy ndarray specially.
+
+        Parameters
+        ----------
+        value : Any
+            The Python value to represent.
+        max_len : int, default 32
+            Maximum length for the value string portion.
+
+        Returns
+        -------
+        str
+            String representation in format "value@type[dtype]".
+        """
+        # Guard on have_pandas so plain Python values are still handled without pandas.
+        is_dataframe = have_pandas and isinstance(value, pd.DataFrame)
+        v_str = value.to_json() if is_dataframe else str(value)
+        v_str = v_str.replace("\n", " ")[:max_len]
+
+        if have_numpy and isinstance(value, np.ndarray):
+            return f"{v_str}@ndarray[{value.dtype.name}]"
+        elif is_dataframe:
+            simple_schema = ", ".join([f"{t} {d.name}" for t, d in value.dtypes.items()])
+            return f"{v_str}@Dataframe[{simple_schema}]"
+        return f"{v_str}@{type(value).__name__}"
+
+    @staticmethod
+    def clean_result(result: str) -> str:
+        """Clean result string by removing newlines and extra whitespace."""
+        return result.replace("\n", " ").replace("\r", " ").replace("\t", " ")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to