This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9ce067c7d532 [SPARK-55235][PYTHON][TESTS] Refactor tests for pandas udf input type coercion
9ce067c7d532 is described below
commit 9ce067c7d532a6766d622e55164baa942a065250
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jan 28 14:10:35 2026 +0900
[SPARK-55235][PYTHON][TESTS] Refactor tests for pandas udf input type coercion
### What changes were proposed in this pull request?
Refactor tests for pandas udf input type coercion
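For context, the behavior under test is simply which pandas dtype each Spark column arrives as inside a `pandas_udf`. A minimal, self-contained sketch of that idea (assuming an active SparkSession; it mirrors the `type_pandas_udf` helper in the diff below, and the names here are illustrative only):

```python
import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType

spark = SparkSession.builder.getOrCreate()

@pandas_udf(StringType())
def dtype_of(s: pd.Series) -> pd.Series:
    # Report the pandas dtype that Spark hands to the UDF for this batch.
    return pd.Series([str(s.dtype)] * len(s))

df = spark.range(3).selectExpr("CAST(id AS TINYINT) AS value")
df.select(dtype_of("value").alias("python_type")).show()
# Per the golden table below, a non-null tinyint column is expected to arrive as 'int8'.
```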
### Why are the changes needed?
To use the new pandas-based framework and to output CSV/Markdown golden files.
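Concretely, the new framework writes the collected rows to a tab-separated CSV golden file (plus a best-effort GitHub-flavored Markdown copy when the optional `tabulate` package is installed), and in normal test runs it compares freshly generated rows against that CSV as strings. A rough sketch of that round trip with illustrative helper names (the real logic lives in `_compare_or_generate_golden` in the diff below):

```python
import pandas as pd

COLUMNS = ["Test Case", "Spark Type", "Spark Value", "Python Type", "Python Value"]

def write_golden(rows, csv_path, md_path):
    # Persist the collected rows as the tab-separated CSV golden file.
    df = pd.DataFrame(rows, columns=COLUMNS)
    df.to_csv(csv_path, sep="\t", header=True, index=True)
    try:
        # The Markdown copy needs the optional 'tabulate' package.
        df.to_markdown(md_path, index=True, tablefmt="github")
    except Exception as e:
        print(f"failed to write the markdown file: {e}")

def check_against_golden(rows, csv_path):
    # Everything is compared as strings, mirroring dtype="str" in the test itself.
    golden = pd.read_csv(csv_path, sep="\t", index_col=0, dtype="str", na_filter=False)
    errs = [
        f"line mismatch: expects {list(golden.iloc[i])} but got {row}"
        for i, row in enumerate(rows)
        if row != list(golden.iloc[i])
    ]
    assert not errs, "\n" + "\n".join(errs)
```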
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53999 from zhengruifeng/refactor_pd_udf_input.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +-
.../golden_pandas_udf_input_type_coercion_base.csv | 40 +++
.../golden_pandas_udf_input_type_coercion_base.md | 41 +++
.../test_pandas_udf_input_type.py} | 231 ++++++++------
python/pyspark/sql/tests/udf_type_tests/README.md | 13 -
.../pyspark/sql/tests/udf_type_tests/__init__.py | 16 -
.../golden_pandas_udf_input_types.txt | 43 ---
.../sql/tests/udf_type_tests/type_table_utils.py | 332 ---------------------
8 files changed, 215 insertions(+), 503 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f14d94251365..79ec8da96b93 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -594,7 +594,7 @@ pyspark_sql = Module(
"pyspark.sql.tests.plot.test_frame_plot",
"pyspark.sql.tests.plot.test_frame_plot_plotly",
"pyspark.sql.tests.test_connect_compatibility",
- "pyspark.sql.tests.udf_type_tests.test_udf_input_types",
+ "pyspark.sql.tests.coercion.test_pandas_udf_input_type",
"pyspark.sql.tests.coercion.test_python_udf_input_type",
"pyspark.sql.tests.coercion.test_pandas_udf_return_type",
"pyspark.sql.tests.coercion.test_python_udf_return_type",
diff --git a/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv
new file mode 100644
index 000000000000..965213ba4820
--- /dev/null
+++ b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.csv
@@ -0,0 +1,40 @@
+ Test Case Spark Type Spark Value Python Type Python Value
+0 byte_values tinyint [-128, 127, 0] ['int8', 'int8', 'int8'] [-128, 127, 0]
+1 byte_null tinyint [None, 42] ['Int8', 'Int8'] [None, 42]
+2 short_values smallint [-32768, 32767, 0] ['int16', 'int16', 'int16'] [-32768, 32767, 0]
+3 short_null smallint [None, 123] ['Int16', 'Int16'] [None, 123]
+4 int_values int [-2147483648, 2147483647, 0] ['int32', 'int32', 'int32'] [-2147483648, 2147483647, 0]
+5 int_null int [None, 456] ['Int32', 'Int32'] [None, 456]
+6 long_values bigint [-9223372036854775808, 9223372036854775807, 0] ['int64', 'int64', 'int64'] [-9223372036854775808, 9223372036854775807, 0]
+7 long_null bigint [None, 789] ['Int64', 'Int64'] [None, 789]
+8 float_values float [0.0, 1.0, 3.140000104904175] ['float32', 'float32', 'float32'] [0.0, 1.0, 3.140000104904175]
+9 float_null float [None, 3.140000104904175] ['float32', 'float32'] [None, 3.140000104904175]
+10 double_values double [0.0, 1.0, 0.3333333333333333] ['float64', 'float64', 'float64'] [0.0, 1.0, 0.3333333333333333]
+11 double_null double [None, 2.71] ['float64', 'float64'] [None, 2.71]
+12 decimal_values decimal(3,2) [Decimal('5.35'), Decimal('1.23')] ['object', 'object'] [Decimal('5.35'), Decimal('1.23')]
+13 decimal_null decimal(3,2) [None, Decimal('9.99')] ['object', 'object'] [None, Decimal('9.99')]
+14 string_values string ['abc', '', 'hello'] ['object', 'object', 'object'] ['abc', '', 'hello']
+15 string_null string [None, 'test'] ['object', 'object'] [None, 'test']
+16 binary_values binary [b'abc', b'', b'ABC'] ['object', 'object', 'object'] [b'abc', b'', b'ABC']
+17 binary_null binary [None, b'test'] ['object', 'object'] [None, b'test']
+18 boolean_values boolean [True, False] ['bool', 'bool'] [True, False]
+19 boolean_null boolean [None, True] ['object', 'object'] [None, True]
+20 date_values date [datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] ['object', 'object'] [datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)]
+21 date_null date [None, datetime.date(2023, 1, 1)] ['object', 'object'] [None, datetime.date(2023, 1, 1)]
+22 timestamp_values timestamp [datetime.datetime(2020, 2, 2, 12, 15, 16, 123000)] ['datetime64[ns]'] [datetime.datetime(2020, 2, 2, 12, 15, 16, 123000)]
+23 timestamp_null timestamp [None, datetime.datetime(2023, 1, 1, 12, 0)] ['datetime64[ns]', 'datetime64[ns]'] [None, datetime.datetime(2023, 1, 1, 12, 0)]
+24 array_int_values array<int> [[1, 2, 3], [], [1, None, 3]] ['object', 'object', 'object'] [[1, 2, 3], [], [1, None, 3]]
+25 array_int_null array<int> [None, [4, 5, 6]] ['object', 'object'] [None, [4, 5, 6]]
+26 map_str_int_values map<string,int> [{'world': 2, 'hello': 1}, {}] ['object', 'object'] [{'world': 2, 'hello': 1}, {}]
+27 map_str_int_null map<string,int> [None, {'test': 123}] ['object', 'object'] [None, {'test': 123}]
+28 struct_int_str_values struct<a1:int,a2:string> [Row(a1=1, a2='hello'), Row(a1=2, a2='world')] ['DataFrame', 'DataFrame'] [Row(a1=1, a2='hello'), Row(a1=2, a2='world')]
+29 struct_int_str_null struct<a1:int,a2:string> [None, Row(a1=99, a2='test')] ['DataFrame', 'DataFrame'] [Row(a1=None, a2=None), Row(a1=99, a2='test')]
+30 array_array_int array<array<int>> [[[1, 2, 3]], [[1], [2, 3]]] ['object', 'object'] [[[1, 2, 3]], [[1], [2, 3]]]
+31 array_map_str_int array<map<string,int>> [[{'world': 2, 'hello': 1}], [{'a': 1}, {'b': 2}]] ['object', 'object'] [[{'world': 2, 'hello': 1}], [{'a': 1}, {'b': 2}]]
+32 array_struct_int_str array<struct<a1:int,a2:string>> [[Row(a1=1, a2='hello')], [Row(a1=1, a2='hello'), Row(a1=2, a2='world')]] ['object', 'object'] [[Row(a1=1, a2='hello')], [Row(a1=1, a2='hello'), Row(a1=2, a2='world')]]
+33 map_int_array_int map<int,array<int>> [{1: [1, 2, 3]}, {1: [1], 2: [2, 3]}] ['object', 'object'] [{1: [1, 2, 3]}, {1: [1], 2: [2, 3]}]
+34 map_int_map_str_int map<int,map<string,int>> [{1: {'world': 2, 'hello': 1}}] ['object'] [{1: {'world': 2, 'hello': 1}}]
+35 map_int_struct_int_str map<int,struct<a1:int,a2:string>> [{1: Row(a1=1, a2='hello')}] ['object'] [{1: Row(a1=1, a2='hello')}]
+36 struct_int_array_int struct<a:int,b:array<int>> [Row(a=1, b=[1, 2, 3])] ['DataFrame'] [Row(a=1, b=[1, 2, 3])]
+37 struct_int_map_str_int struct<a:int,b:map<string,int>> [Row(a=1, b={'world': 2, 'hello': 1})] ['DataFrame'] [Row(a=1, b={'world': 2, 'hello': 1})]
+38 struct_int_struct_int_str struct<a:int,b:struct<a1:int,a2:string>> [Row(a=1, b=Row(a1=1, a2='hello'))] ['DataFrame'] [Row(a=1, b=Row(a1=1, a2='hello'))]
diff --git a/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md
new file mode 100644
index 000000000000..5240057fbfd8
--- /dev/null
+++ b/python/pyspark/sql/tests/coercion/golden_pandas_udf_input_type_coercion_base.md
@@ -0,0 +1,41 @@
+| | Test Case | Spark Type | Spark Value | Python Type | Python Value |
+|----|---------------------------|------------------------------------------|---------------------------------------------------------------------------|--------------------------------------|---------------------------------------------------------------------------|
+| 0 | byte_values | tinyint | [-128, 127, 0] | ['int8', 'int8', 'int8'] | [-128, 127, 0] |
+| 1 | byte_null | tinyint | [None, 42] | ['Int8', 'Int8'] | [None, 42] |
+| 2 | short_values | smallint | [-32768, 32767, 0] | ['int16', 'int16', 'int16'] | [-32768, 32767, 0] |
+| 3 | short_null | smallint | [None, 123] | ['Int16', 'Int16'] | [None, 123] |
+| 4 | int_values | int | [-2147483648, 2147483647, 0] | ['int32', 'int32', 'int32'] | [-2147483648, 2147483647, 0] |
+| 5 | int_null | int | [None, 456] | ['Int32', 'Int32'] | [None, 456] |
+| 6 | long_values | bigint | [-9223372036854775808, 9223372036854775807, 0] | ['int64', 'int64', 'int64'] | [-9223372036854775808, 9223372036854775807, 0] |
+| 7 | long_null | bigint | [None, 789] | ['Int64', 'Int64'] | [None, 789] |
+| 8 | float_values | float | [0.0, 1.0, 3.140000104904175] | ['float32', 'float32', 'float32'] | [0.0, 1.0, 3.140000104904175] |
+| 9 | float_null | float | [None, 3.140000104904175] | ['float32', 'float32'] | [None, 3.140000104904175] |
+| 10 | double_values | double | [0.0, 1.0, 0.3333333333333333] | ['float64', 'float64', 'float64'] | [0.0, 1.0, 0.3333333333333333] |
+| 11 | double_null | double | [None, 2.71] | ['float64', 'float64'] | [None, 2.71] |
+| 12 | decimal_values | decimal(3,2) | [Decimal('5.35'), Decimal('1.23')] | ['object', 'object'] | [Decimal('5.35'), Decimal('1.23')] |
+| 13 | decimal_null | decimal(3,2) | [None, Decimal('9.99')] | ['object', 'object'] | [None, Decimal('9.99')] |
+| 14 | string_values | string | ['abc', '', 'hello'] | ['object', 'object', 'object'] | ['abc', '', 'hello'] |
+| 15 | string_null | string | [None, 'test'] | ['object', 'object'] | [None, 'test'] |
+| 16 | binary_values | binary | [b'abc', b'', b'ABC'] | ['object', 'object', 'object'] | [b'abc', b'', b'ABC'] |
+| 17 | binary_null | binary | [None, b'test'] | ['object', 'object'] | [None, b'test'] |
+| 18 | boolean_values | boolean | [True, False] | ['bool', 'bool'] | [True, False] |
+| 19 | boolean_null | boolean | [None, True] | ['object', 'object'] | [None, True] |
+| 20 | date_values | date | [datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] | ['object', 'object'] | [datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |
+| 21 | date_null | date | [None, datetime.date(2023, 1, 1)] | ['object', 'object'] | [None, datetime.date(2023, 1, 1)] |
+| 22 | timestamp_values | timestamp | [datetime.datetime(2020, 2, 2, 12, 15, 16, 123000)] | ['datetime64[ns]'] | [datetime.datetime(2020, 2, 2, 12, 15, 16, 123000)] |
+| 23 | timestamp_null | timestamp | [None, datetime.datetime(2023, 1, 1, 12, 0)] | ['datetime64[ns]', 'datetime64[ns]'] | [None, datetime.datetime(2023, 1, 1, 12, 0)] |
+| 24 | array_int_values | array<int> | [[1, 2, 3], [], [1, None, 3]] | ['object', 'object', 'object'] | [[1, 2, 3], [], [1, None, 3]] |
+| 25 | array_int_null | array<int> | [None, [4, 5, 6]] | ['object', 'object'] | [None, [4, 5, 6]] |
+| 26 | map_str_int_values | map<string,int> | [{'world': 2, 'hello': 1}, {}] | ['object', 'object'] | [{'world': 2, 'hello': 1}, {}] |
+| 27 | map_str_int_null | map<string,int> | [None, {'test': 123}] | ['object', 'object'] | [None, {'test': 123}] |
+| 28 | struct_int_str_values | struct<a1:int,a2:string> | [Row(a1=1, a2='hello'), Row(a1=2, a2='world')] | ['DataFrame', 'DataFrame'] | [Row(a1=1, a2='hello'), Row(a1=2, a2='world')] |
+| 29 | struct_int_str_null | struct<a1:int,a2:string> | [None, Row(a1=99, a2='test')] | ['DataFrame', 'DataFrame'] | [Row(a1=None, a2=None), Row(a1=99, a2='test')] |
+| 30 | array_array_int | array<array<int>> | [[[1, 2, 3]], [[1], [2, 3]]] | ['object', 'object'] | [[[1, 2, 3]], [[1], [2, 3]]] |
+| 31 | array_map_str_int | array<map<string,int>> | [[{'world': 2, 'hello': 1}], [{'a': 1}, {'b': 2}]] | ['object', 'object'] | [[{'world': 2, 'hello': 1}], [{'a': 1}, {'b': 2}]] |
+| 32 | array_struct_int_str | array<struct<a1:int,a2:string>> | [[Row(a1=1, a2='hello')], [Row(a1=1, a2='hello'), Row(a1=2, a2='world')]] | ['object', 'object'] | [[Row(a1=1, a2='hello')], [Row(a1=1, a2='hello'), Row(a1=2, a2='world')]] |
+| 33 | map_int_array_int | map<int,array<int>> | [{1: [1, 2, 3]}, {1: [1], 2: [2, 3]}] | ['object', 'object'] | [{1: [1, 2, 3]}, {1: [1], 2: [2, 3]}] |
+| 34 | map_int_map_str_int | map<int,map<string,int>> | [{1: {'world': 2, 'hello': 1}}] | ['object'] | [{1: {'world': 2, 'hello': 1}}] |
+| 35 | map_int_struct_int_str | map<int,struct<a1:int,a2:string>> | [{1: Row(a1=1, a2='hello')}] | ['object'] | [{1: Row(a1=1, a2='hello')}] |
+| 36 | struct_int_array_int | struct<a:int,b:array<int>> | [Row(a=1, b=[1, 2, 3])] | ['DataFrame'] | [Row(a=1, b=[1, 2, 3])] |
+| 37 | struct_int_map_str_int | struct<a:int,b:map<string,int>> | [Row(a=1, b={'world': 2, 'hello': 1})] | ['DataFrame'] | [Row(a=1, b={'world': 2, 'hello': 1})] |
+| 38 | struct_int_struct_int_str | struct<a:int,b:struct<a1:int,a2:string>> | [Row(a=1, b=Row(a1=1, a2='hello'))] | ['DataFrame'] | [Row(a=1, b=Row(a1=1, a2='hello'))] |
\ No newline at end of file
diff --git a/python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py b/python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py
similarity index 70%
rename from python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py
rename to python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py
index e0d144892128..005a8f62945a 100644
--- a/python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py
+++ b/python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py
@@ -15,14 +15,15 @@
# limitations under the License.
#
+from decimal import Decimal
+import datetime
import os
-import platform
+import time
import unittest
-import pandas as pd
-from pyspark.sql import Row
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import (
+ Row,
ArrayType,
BinaryType,
BooleanType,
@@ -50,120 +51,57 @@ from pyspark.testing.utils import (
numpy_requirement_message,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
-from .type_table_utils import generate_table_diff, format_type_table
if have_numpy:
import numpy as np
+if have_pandas:
+ import pandas as pd
+
+# If you need to re-generate the golden files, you need to set the
+# SPARK_GENERATE_GOLDEN_FILES=1 environment variable before running this test,
+# e.g.:
+# SPARK_GENERATE_GOLDEN_FILES=1 python/run-tests -k
+# --testnames 'pyspark.sql.tests.coercion.test_pandas_udf_input_type'
+# If package tabulate https://pypi.org/project/tabulate/ is installed,
+# it will also re-generate the Markdown files.
@unittest.skipIf(
not have_pandas
or not have_pyarrow
or not have_numpy
- or LooseVersion(np.__version__) < LooseVersion("2.0.0")
- or platform.system() == "Darwin",
- pandas_requirement_message
- or pyarrow_requirement_message
- or numpy_requirement_message
- or "float128 not supported on macos",
+ or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
+ pandas_requirement_message or pyarrow_requirement_message or numpy_requirement_message,
)
-class UDFInputTypeTests(ReusedSQLTestCase):
+class PandasUDFInputTypeTests(ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
- def setUp(self):
- super().setUp()
-
- def test_pandas_udf_input(self):
- golden_file = os.path.join(os.path.dirname(__file__), "golden_pandas_udf_input_types.txt")
- results = self._generate_pandas_udf_input_type_coercion_results()
- actual_output = format_type_table(
- results,
- ["Test Case", "Spark Type", "Spark Value", "Python Type", "Python
Value"],
- column_width=85,
- )
- self._compare_or_create_golden_file(actual_output, golden_file, "Pandas UDF input types")
-
- def _generate_pandas_udf_input_type_coercion_results(self):
- results = []
- test_cases = self._get_input_type_test_cases()
-
- for test_name, spark_type, data_func in test_cases:
- input_df = data_func(spark_type).repartition(1)
- input_data = [row["value"] for row in input_df.collect()]
- result_row = [test_name, spark_type.simpleString(), str(input_data)]
-
- try:
-
- def type_pandas_udf(data):
- if hasattr(data, "dtype"):
- # Series case
- return pd.Series([str(data.dtype)] * len(data))
- else:
- # DataFrame case (for struct types)
- return pd.Series([str(type(data).__name__)] * len(data))
+ # Synchronize default timezone between Python and Java
+ cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
+ tz = "America/Los_Angeles"
+ os.environ["TZ"] = tz
+ time.tzset()
- def value_pandas_udf(series):
- return series
+ cls.sc.environment["TZ"] = tz
+ cls.spark.conf.set("spark.sql.session.timeZone", tz)
- type_test_pandas_udf = pandas_udf(type_pandas_udf, returnType=StringType())
- value_test_pandas_udf = pandas_udf(value_pandas_udf, returnType=spark_type)
-
- result_df = input_df.select(
- value_test_pandas_udf("value").alias("python_value"),
- type_test_pandas_udf("value").alias("python_type"),
- )
- results_data = result_df.collect()
- values = [row["python_value"] for row in results_data]
- types = [row["python_type"] for row in results_data]
-
- result_row.append(str(types))
- result_row.append(str(values).replace("\n", " "))
-
- except Exception as e:
- print("error_msg", e)
- error_msg = str(e).replace("\n", " ").replace("\r", " ")
- result_row.append(f"✗ {error_msg}")
-
- results.append(result_row)
-
- return results
-
- def _compare_or_create_golden_file(self, actual_output, golden_file, test_name):
- """Compare actual output with golden file or create golden file if it doesn't exist.
-
- Args:
- actual_output: The actual output to compare
- golden_file: Path to the golden file
- test_name: Name of the test for error messages
- """
- if os.path.exists(golden_file):
- with open(golden_file, "r") as f:
- expected_output = f.read()
-
- if actual_output != expected_output:
- diff_output = generate_table_diff(actual_output, expected_output, cell_width=85)
- self.fail(
- f"""
- Results don't match golden file for :{test_name}.\n
- Diff:\n{diff_output}
- """
- )
- else:
- with open(golden_file, "w") as f:
- f.write(actual_output)
- self.fail(f"Golden file created for {test_name}. Please review and
re-run the test.")
+ @classmethod
+ def tearDownClass(cls):
+ del os.environ["TZ"]
+ if cls.tz_prev is not None:
+ os.environ["TZ"] = cls.tz_prev
+ time.tzset()
- def _create_value_schema(self, data_type):
- """Helper to create a StructType schema with a single 'value' column
of the given type."""
- return StructType([StructField("value", data_type, True)])
+ super().tearDownClass()
- def _get_input_type_test_cases(self):
- from pyspark.sql.types import StructType, StructField
- import datetime
- from decimal import Decimal
+ @property
+ def prefix(self):
+ return "golden_pandas_udf_input_type_coercion"
+ @property
+ def test_cases(self):
def df(args):
def create_df(data_type):
# For StructType where the data contains Row objects (not wrapped in tuples)
@@ -317,6 +255,103 @@ class UDFInputTypeTests(ReusedSQLTestCase):
),
]
+ def test_pandas_input_type_coercion_vanilla(self):
+ self._run_pandas_udf_input_type_coercion(
+ golden_file=f"{self.prefix}_base",
+ test_name="Pandas UDF",
+ )
+
+ def _run_pandas_udf_input_type_coercion(self, golden_file, test_name):
+ self._compare_or_generate_golden(golden_file, test_name)
+
+ def _compare_or_generate_golden(self, golden_file, test_name):
+ testing = os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "?") != "1"
+
+ golden_csv = os.path.join(os.path.dirname(__file__), f"{golden_file}.csv")
+ golden_md = os.path.join(os.path.dirname(__file__), f"{golden_file}.md")
+
+ golden = None
+ if testing:
+ golden = pd.read_csv(
+ golden_csv,
+ sep="\t",
+ index_col=0,
+ dtype="str",
+ na_filter=False,
+ engine="python",
+ )
+
+ results = []
+ for idx, (test_name, spark_type, data_func) in enumerate(self.test_cases):
+ input_df = data_func(spark_type).repartition(1)
+ input_data = [row["value"] for row in input_df.collect()]
+ result = [test_name, spark_type.simpleString(), str(input_data)]
+
+ try:
+
+ def type_pandas_udf(data):
+ if hasattr(data, "dtype"):
+ # Series case
+ return pd.Series([str(data.dtype)] * len(data))
+ else:
+ # DataFrame case (for struct types)
+ return pd.Series([str(type(data).__name__)] * len(data))
+
+ def value_pandas_udf(series):
+ return series
+
+ type_test_pandas_udf = pandas_udf(type_pandas_udf, returnType=StringType())
+ value_test_pandas_udf = pandas_udf(value_pandas_udf, returnType=spark_type)
+
+ result_df = input_df.select(
+ value_test_pandas_udf("value").alias("python_value"),
+ type_test_pandas_udf("value").alias("python_type"),
+ )
+ results_data = result_df.collect()
+ values = [row["python_value"] for row in results_data]
+ types = [row["python_type"] for row in results_data]
+
+ result.append(str(types))
+ result.append(str(values).replace("\n", " "))
+
+ except Exception as e:
+ print("error_msg", e)
+ # Clean up exception message to remove newlines and extra whitespace
+ e = str(e).replace("\n", " ").replace("\r", " ").replace("\t", " ")
+ result.append(f"✗ {e}")
+
+ error_msg = None
+ if testing and result != list(golden.iloc[idx]):
+ error_msg = f"line mismatch: expects {list(golden.iloc[idx])}
but got {result}"
+
+ results.append((result, error_msg))
+
+ if testing:
+ errs = []
+ for _, err in results:
+ if err is not None:
+ errs.append(err)
+ self.assertTrue(len(errs) == 0, "\n" + "\n".join(errs) + "\n")
+
+ else:
+ new_golden = pd.DataFrame(
+ [res for res, _ in results],
+ columns=["Test Case", "Spark Type", "Spark Value", "Python
Type", "Python Value"],
+ )
+
+ # generating the CSV file as the golden file
+ new_golden.to_csv(golden_csv, sep="\t", header=True, index=True)
+
+ try:
+ # generating the GitHub flavored Markdown file
+ # package tabulate is required
+ new_golden.to_markdown(golden_md, index=True, tablefmt="github")
+ except Exception as e:
+ print(
+ f"{test_name} return type coercion: "
+ f"fail to write the markdown file due to {e}!"
+ )
+
if __name__ == "__main__":
from pyspark.testing import main
diff --git a/python/pyspark/sql/tests/udf_type_tests/README.md b/python/pyspark/sql/tests/udf_type_tests/README.md
deleted file mode 100644
index 74d1933c4951..000000000000
--- a/python/pyspark/sql/tests/udf_type_tests/README.md
+++ /dev/null
@@ -1,13 +0,0 @@
-These tests capture input/output type interfaces between python udfs and the engine. This internal documentation, not user-facing documentation. Please consider the type behavior "experimental", unless we specify otherwise. Parts of the type handling might change in the future.
-
-# Return type tests
-These generate tables with the returned 'Python Value' and the 'SQL Type' output type of the UDF. The 'SQL Type' fields are DDL formatted strings, which can be used as `returnType`s.
-- Note: The values inside the table are generated by `repr`. X' means it throws an exception during the conversion.
-- Note: Python 3.11.9, Pandas 2.2.3 and PyArrow 17.0.0 are used.
-
-# Input type tests
-These generate tables with 'Spark Type' and 'Spark Value', representing the engine-side input data. The UDF input data is captured in the 'Python type' and 'Python value' columns.
-
-# When this test fails:
-- Look at the diff in the test output
-- To regenerate golden files, simply delete the existing golden file and re-run the test.
\ No newline at end of file
diff --git a/python/pyspark/sql/tests/udf_type_tests/__init__.py b/python/pyspark/sql/tests/udf_type_tests/__init__.py
deleted file mode 100644
index cce3acad34a4..000000000000
--- a/python/pyspark/sql/tests/udf_type_tests/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt b/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt
deleted file mode 100644
index d3092c47bbda..000000000000
--- a/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt
+++ /dev/null
@@ -1,43 +0,0 @@
-+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+
-|Test Case |Spark Type |Spark Value |Python Type |Python Value |
-+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+
-|byte_values |tinyint |[-128, 127, 0] |['int8', 'int8', 'int8'] |[-128, 127, 0] |
-|byte_null |tinyint |[None, 42] |['Int8', 'Int8'] |[None, 42] |
-|short_values |smallint |[-32768, 32767, 0] |['int16', 'int16', 'int16'] |[-32768, 32767, 0] |
-|short_null |smallint |[None, 123] |['Int16', 'Int16'] |[None, 123] |
-|int_values |int |[-2147483648, 2147483647, 0] |['int32', 'int32', 'int32'] |[-2147483648, 2147483647, 0] |
-|int_null |int |[None, 456] |['Int32', 'Int32'] |[None, 456] |
-|long_values |bigint |[-9223372036854775808, 9223372036854775807, 0] |['int64', 'int64', 'int64'] |[-9223372036854775808, 9223372036854775807, 0] |
-|long_null |bigint |[None, 789] |['Int64', 'Int64'] |[None, 789] |
-|float_values |float |[0.0, 1.0, 3.140000104904175] |['float32', 'float32', 'float32'] |[0.0, 1.0, 3.140000104904175] |
-|float_null |float |[None, 3.140000104904175] |['float32', 'float32'] |[None, 3.140000104904175] |
-|double_values |double |[0.0, 1.0, 0.3333333333333333] |['float64', 'float64', 'float64'] |[0.0, 1.0, 0.3333333333333333] |
-|double_null |double |[None, 2.71] |['float64', 'float64'] |[None, 2.71] |
-|decimal_values |decimal(3,2) |[Decimal('5.35'), Decimal('1.23')] |['object', 'object'] |[Decimal('5.35'), Decimal('1.23')] |
-|decimal_null |decimal(3,2) |[None, Decimal('9.99')] |['object', 'object'] |[None, Decimal('9.99')] |
-|string_values |string |['abc', '', 'hello'] |['object', 'object', 'object'] |['abc', '', 'hello'] |
-|string_null |string |[None, 'test'] |['object', 'object'] |[None, 'test'] |
-|binary_values |binary |[b'abc', b'', b'ABC'] |['object', 'object', 'object'] |[b'abc', b'', b'ABC'] |
-|binary_null |binary |[None, b'test'] |['object', 'object'] |[None, b'test'] |
-|boolean_values |boolean |[True, False] |['bool', 'bool'] |[True, False] |
-|boolean_null |boolean |[None, True] |['object', 'object'] |[None, True] |
-|date_values |date |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |['object', 'object'] |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |
-|date_null |date |[None, datetime.date(2023, 1, 1)] |['object', 'object'] |[None, datetime.date(2023, 1, 1)] |
-|timestamp_values |timestamp |[datetime.datetime(2020, 2, 2, 12, 15, 16, 123000)] |['datetime64[ns]'] |[datetime.datetime(2020, 2, 2, 12, 15, 16, 123000)] |
-|timestamp_null |timestamp |[None, datetime.datetime(2023, 1, 1, 12, 0)] |['datetime64[ns]', 'datetime64[ns]'] |[None, datetime.datetime(2023, 1, 1, 12, 0)] |
-|array_int_values |array<int> |[[1, 2, 3], [], [1, None, 3]] |['object', 'object', 'object'] |[[1, 2, 3], [], [1, None, 3]] |
-|array_int_null |array<int> |[None, [4, 5, 6]] |['object', 'object'] |[None, [4, 5, 6]] |
-|map_str_int_values |map<string,int> |[{'world': 2, 'hello': 1}, {}] |['object', 'object'] |[{'world': 2, 'hello': 1}, {}] |
-|map_str_int_null |map<string,int> |[None, {'test': 123}] |['object', 'object'] |[None, {'test': 123}] |
-|struct_int_str_values |struct<a1:int,a2:string> |[Row(a1=1, a2='hello'), Row(a1=2, a2='world')] |['DataFrame', 'DataFrame'] |[Row(a1=1, a2='hello'), Row(a1=2, a2='world')] |
-|struct_int_str_null |struct<a1:int,a2:string> |[None, Row(a1=99, a2='test')] |['DataFrame', 'DataFrame'] |[Row(a1=None, a2=None), Row(a1=99, a2='test')] |
-|array_array_int |array<array<int>> |[[[1, 2, 3]], [[1], [2, 3]]] |['object', 'object'] |[[[1, 2, 3]], [[1], [2, 3]]] |
-|array_map_str_int |array<map<string,int>> |[[{'world': 2, 'hello': 1}], [{'a': 1}, {'b': 2}]] |['object', 'object'] |[[{'world': 2, 'hello': 1}], [{'a': 1}, {'b': 2}]] |
-|array_struct_int_str |array<struct<a1:int,a2:string>> |[[Row(a1=1, a2='hello')], [Row(a1=1, a2='hello'), Row(a1=2, a2='world')]] |['object', 'object'] |[[Row(a1=1, a2='hello')], [Row(a1=1, a2='hello'), Row(a1=2, a2='world')]] |
-|map_int_array_int |map<int,array<int>> |[{1: [1, 2, 3]}, {1: [1], 2: [2, 3]}] |['object', 'object'] |[{1: [1, 2, 3]}, {1: [1], 2: [2, 3]}] |
-|map_int_map_str_int |map<int,map<string,int>> |[{1: {'world': 2, 'hello': 1}}] |['object'] |[{1: {'world': 2, 'hello': 1}}] |
-|map_int_struct_int_str |map<int,struct<a1:int,a2:string>> |[{1: Row(a1=1, a2='hello')}] |['object'] |[{1: Row(a1=1, a2='hello')}] |
-|struct_int_array_int |struct<a:int,b:array<int>> |[Row(a=1, b=[1, 2, 3])] |['DataFrame'] |[Row(a=1, b=[1, 2, 3])] |
-|struct_int_map_str_int |struct<a:int,b:map<string,int>> |[Row(a=1, b={'world': 2, 'hello': 1})] |['DataFrame'] |[Row(a=1, b={'world': 2, 'hello': 1})] |
-|struct_int_struct_int_str |struct<a:int,b:struct<a1:int,a2:string>> |[Row(a=1, b=Row(a1=1, a2='hello'))] |['DataFrame'] |[Row(a=1, b=Row(a1=1, a2='hello'))] |
-+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------+
\ No newline at end of file
diff --git a/python/pyspark/sql/tests/udf_type_tests/type_table_utils.py b/python/pyspark/sql/tests/udf_type_tests/type_table_utils.py
deleted file mode 100755
index 88752027f670..000000000000
--- a/python/pyspark/sql/tests/udf_type_tests/type_table_utils.py
+++ /dev/null
@@ -1,332 +0,0 @@
-#!/usr/bin/env python3
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import os
-import sys
-import argparse
-import re
-from typing import List, Tuple, Optional
-
-CELL_WIDTH = 30
-
-
-class Colors:
- """ANSI color codes for terminal output"""
-
- RESET = "\033[0m"
- RED = "\033[91m"
- GREEN = "\033[92m"
- BOLD = "\033[1m"
- BG_RED = "\033[101m"
- BG_GREEN = "\033[102m"
-
-
-def parse_table_line(line: str) -> Optional[List[str]]:
- """Parse a table line and extract cell contents."""
- if not line.strip() or line.strip().startswith("+"):
- return None
-
- cells = [cell.strip() for cell in line.strip("|").split("|")]
- return cells
-
-
-def parse_table_content(content: str) -> Tuple[List[str], List[List[str]]]:
- """Parse table content and return header and rows."""
- lines = content.strip().split("\n")
- header = None
- rows = []
-
- for line in lines:
- cells = parse_table_line(line)
- if cells is not None:
- if header is None:
- header = cells
- else:
- rows.append(cells)
-
- return header, rows
-
-
-def highlight_cell_diff(
- expected_cell: str, actual_cell: str, use_colors: bool = True, cell_width: int = CELL_WIDTH
-) -> str:
- """Highlight differences within a single cell, showing inline diff with
both parts visible."""
- if expected_cell == actual_cell:
- return expected_cell
-
- if use_colors:
- format_overhead = 1
- else:
- format_overhead = 6
-
- total_content_length = len(expected_cell) + len(actual_cell) + format_overhead
-
- if total_content_length > cell_width:
- available_space = cell_width - format_overhead
- half_space = available_space // 2
-
- expected_truncated = expected_cell
- actual_truncated = actual_cell
-
- if len(expected_cell) > half_space:
- expected_truncated = expected_cell[: half_space - 3] + "..."
- if len(actual_cell) > half_space:
- actual_truncated = actual_cell[: half_space - 3] + "..."
- else:
- expected_truncated = expected_cell
- actual_truncated = actual_cell
-
- if use_colors:
- return (
- f"{Colors.BG_RED}{expected_truncated}{Colors.RESET}→"
- f"{Colors.BG_GREEN}{actual_truncated}{Colors.RESET}"
- )
- else:
- return f"[-{expected_truncated}-][+{actual_truncated}+]"
-
-
-def format_table_diff(
- header: List[str],
- expected_rows: List[List[str]],
- actual_rows: List[List[str]],
- use_colors: bool = True,
- cell_width: int = CELL_WIDTH,
-) -> str:
- """Format a table diff with cell-level highlighting."""
- output_lines = []
-
- title = "Table Comparison (Expected vs Actual)"
- output_lines.append(
- f"\n{Colors.BOLD if use_colors else ''}{title}{Colors.RESET if
use_colors else ''}"
- )
- output_lines.append("=" * len(title))
-
- col_widths = [cell_width] * len(header)
-
- def format_row(cells: List[str], prefix: str = "", color: str = "") -> str:
- """Format a single row with proper alignment."""
- formatted_cells = []
- for i, (cell, width) in enumerate(zip(cells, col_widths)):
- display_cell = str(cell)
-
- visible_length = len(re.sub(r"\x1b\[[0-9;]*m", "", display_cell))
-
- if visible_length > width:
- truncated = ""
- visible_count = 0
- i = 0
- while i < len(display_cell) and visible_count < width - 3:
- if display_cell[i : i + 1] == "\x1b":
- end = display_cell.find("m", i)
- if end != -1:
- truncated += display_cell[i : end + 1]
- i = end + 1
- else:
- i += 1
- else:
- truncated += display_cell[i]
- visible_count += 1
- i += 1
- display_cell = truncated + "..."
- visible_length = visible_count + 3
-
- padding = width - visible_length
- formatted_cells.append(display_cell + " " * padding)
-
- row_content = f"|{' |'.join(formatted_cells)} |"
- if color and use_colors:
- return f"{prefix}{color}{row_content}{Colors.RESET}"
- return f"{prefix}{row_content}"
-
- def create_border(char: str = "-") -> str:
- """Create a table border."""
- return "+" + "+".join(char * (width + 1) for width in col_widths) + "+"
-
- output_lines.append(create_border())
- output_lines.append(format_row(header))
- output_lines.append(create_border())
-
- max_rows = max(len(expected_rows), len(actual_rows))
- changes_found = False
-
- for row_idx in range(max_rows):
- expected_row = expected_rows[row_idx] if row_idx < len(expected_rows) else None
- actual_row = actual_rows[row_idx] if row_idx < len(actual_rows) else None
-
- if expected_row is None:
- display_row = actual_row[:] if actual_row else []
- while len(display_row) < len(header):
- display_row.append("")
- output_lines.append(format_row(display_row, "+ ", Colors.GREEN if use_colors else ""))
- changes_found = True
- elif actual_row is None:
- display_row = expected_row[:] if expected_row else []
- while len(display_row) < len(header):
- display_row.append("")
- output_lines.append(format_row(display_row, "- ", Colors.RED if use_colors else ""))
- changes_found = True
- else:
- row_has_changes = False
- diff_row = []
-
- for col_idx in range(len(header)):
- expected_cell = expected_row[col_idx] if col_idx < len(expected_row) else ""
- actual_cell = actual_row[col_idx] if col_idx < len(actual_row) else ""
-
- if expected_cell != actual_cell:
- row_has_changes = True
- diff_cell = highlight_cell_diff(
- expected_cell, actual_cell, use_colors, cell_width
- )
- diff_row.append(diff_cell)
- else:
- diff_row.append(expected_cell)
-
- while len(diff_row) < len(header):
- diff_row.append("")
-
- output_lines.append(format_row(diff_row))
- if row_has_changes:
- changes_found = True
-
- output_lines.append(create_border())
-
- if not changes_found:
- green_start = Colors.GREEN if use_colors else ""
- reset_end = Colors.RESET if use_colors else ""
- output_lines.append(f"\n{green_start}✓ Tables are
identical!{reset_end}")
- else:
- legend = "\nLegend:"
- if use_colors:
- legend += f"\n {Colors.BG_RED}Red background{Colors.RESET}:
Expected content (removed)"
- legend += f"\n {Colors.BG_GREEN}Green background{Colors.RESET}:
Actual content (added)"
- else:
- legend += "\n [-text-]: Expected content (removed)"
- legend += "\n [+text+]: Actual content (added)"
- legend += "\n Lines prefixed with '-': Expected only rows"
- legend += "\n Lines prefixed with '+': Actual only rows"
- output_lines.append(legend)
-
- return "\n".join(output_lines)
-
-
-def generate_table_diff(actual, expected, cell_width=CELL_WIDTH):
- """Generate a table-aware diff between actual and expected output."""
- try:
- expected_header, expected_rows = parse_table_content(expected)
- actual_header, actual_rows = parse_table_content(actual)
-
- if expected_header and actual_header:
- return format_table_diff(expected_header, expected_rows, actual_rows, True, cell_width)
- except Exception:
- pass
-
- return "Unable to parse content as table format."
-
-
-def format_type_table(results, header, column_width=30):
- """Format results into an ASCII table with the given header and column
width.
-
- Args:
- results: List of rows, where each row is a list of values
- header: List of header strings
- column_width: Width of each column (default: 30)
-
- Returns:
- String representation of the formatted table
- """
- column_widths = [column_width] * len(header)
- output_lines = []
-
- top_border = "+" + "+".join("-" * (width + 1) for width in column_widths)
+ "+"
- output_lines.append(top_border)
-
- header_line = (
- "|"
- + "|".join(f"{cell[:width]:<{width}} " for cell, width in zip(header,
column_widths))
- + "|"
- )
- output_lines.append(header_line)
- output_lines.append(top_border)
-
- for row in results:
- data_line = (
- "|"
- + "|".join(f"{str(cell)[:width]:<{width}} " for cell, width in
zip(row, column_widths))
- + "|"
- )
- output_lines.append(data_line)
-
- output_lines.append(top_border)
- return "\n".join(output_lines)
-
-
-def compare_files(file1_path, file2_path, cell_width=CELL_WIDTH):
- """Compare two files and show the differences."""
- if not os.path.exists(file1_path):
- print(f"Error: File '{file1_path}' does not exist")
- return False
-
- if not os.path.exists(file2_path):
- print(f"Error: File '{file2_path}' does not exist")
- return False
-
- try:
- with open(file1_path, "r") as f1:
- content1 = f1.read()
-
- with open(file2_path, "r") as f2:
- content2 = f2.read()
- except Exception as e:
- print(f"Error reading files: {e}")
- return False
-
- print(f"Comparing '{file1_path}' (expected) with '{file2_path}' (actual)")
- print("=" * 80)
-
- if content1 == content2:
- print("Files are identical!")
- return True
- else:
- print("Files differ. Generating word-wise diff...")
- print()
-
- diff_output = generate_table_diff(content2, content1, cell_width)
- print(diff_output)
- return False
-
-
-def main():
- parser = argparse.ArgumentParser(description="Compare two table files using word-wise diff")
- parser.add_argument("file1", help="First file (expected)")
- parser.add_argument("file2", help="Second file (actual)")
- parser.add_argument(
- "--cell-width",
- type=int,
- default=CELL_WIDTH,
- help=f"Width of each table cell (default: {CELL_WIDTH})",
- )
-
- args = parser.parse_args()
-
- success = compare_files(args.file1, args.file2, args.cell_width)
- sys.exit(0 if success else 1)
-
-
-if __name__ == "__main__":
- main()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]