asl3 commented on code in PR #52105: URL: https://github.com/apache/spark/pull/52105#discussion_r2298190428
##########
python/pyspark/sql/tests/udf_type_tests/test_udf_input_types.py:
##########
@@ -0,0 +1,403 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import array
+import datetime
+import os
+import tempfile
+import unittest
+from decimal import Decimal
+import numpy as np
+import pandas as pd
+
+from pyspark.sql import Row, SparkSession
+from pyspark.sql.functions import udf, pandas_udf
+from pyspark.sql.types import (
+    ArrayType,
+    BinaryType,
+    BooleanType,
+    ByteType,
+    DateType,
+    DecimalType,
+    DoubleType,
+    FloatType,
+    IntegerType,
+    LongType,
+    MapType,
+    ShortType,
+    StringType,
+    StructField,
+    StructType,
+    TimestampType,
+)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from .type_table_utils import generate_table_diff, format_type_table
+
+
+class UDFInputTypeTests(ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        super(UDFInputTypeTests, cls).setUpClass()
+
+    def setUp(self):
+        super(UDFInputTypeTests, self).setUp()
+
+    def test_udf_input_types_arrow_disabled(self):
+        golden_file = os.path.join(
+            os.path.dirname(__file__), "golden_udf_input_types_arrow_disabled.txt"
+        )
+        self._run_udf_input_type_coercion_test(
+            config={},
+            use_arrow=False,
+            golden_file=golden_file,
+            test_name="UDF input types - Arrow disabled",
+        )
+
+    def test_udf_input_types_arrow_legacy_pandas(self):
+        golden_file = os.path.join(
+            os.path.dirname(__file__), "golden_udf_input_types_arrow_legacy_pandas.txt"
+        )
+        self._run_udf_input_type_coercion_test(
+            config={"spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": "true"},
+            use_arrow=True,
+            golden_file=golden_file,
+            test_name="UDF input types - Arrow with legacy pandas",
+        )
+
+    def test_udf_input_types_arrow_enabled(self):
+        golden_file = os.path.join(
+            os.path.dirname(__file__), "golden_udf_input_types_arrow_enabled.txt"
+        )
+        self._run_udf_input_type_coercion_test(
+            config={"spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled": "false"},
+            use_arrow=True,
+            golden_file=golden_file,
+            test_name="UDF input types - Arrow enabled",
+        )
+
+    def _run_udf_input_type_coercion_test(self, config, use_arrow, golden_file, test_name):
+        with self.sql_conf(config):
+            results = self._generate_udf_input_type_coercion_results(use_arrow)
+        actual_output = format_type_table(
+            results,
+            ["Test Case", "Spark Type", "Spark Value", "Python Type", "Python Value"],
+            column_width=85,
+        )
+        self._compare_or_create_golden_file(actual_output, golden_file, test_name)
+
+    def _generate_udf_input_type_coercion_results(self, use_arrow):
+        results = []
+        test_cases = self._get_input_type_test_cases()
+
+        for test_name, spark_type, data_func in test_cases:
+            input_df = data_func(spark_type).repartition(1)
+            input_data = [row["value"] for row in input_df.collect()]
+            result_row = [test_name, spark_type.simpleString(), str(input_data)]
+
+            try:
+
+                def type_udf(x):
+                    if x is None:
+                        return "NoneType"
+                    else:
+                        return type(x).__name__
+
+                def value_udf(x):
+                    return x
+
+                def value_str(x):
+                    return str(x)
+
+                type_test_udf = udf(type_udf, returnType=StringType(), useArrow=use_arrow)
+                value_test_udf = udf(value_udf, returnType=spark_type, useArrow=use_arrow)
+                value_str_udf = udf(value_str, returnType=StringType(), useArrow=use_arrow)
+
+                result_df = input_df.select(
+                    value_test_udf("value").alias("python_value"),
+                    type_test_udf("value").alias("python_type"),
+                    value_str_udf("value").alias("python_value_str"),
+                )
+                results_data = result_df.collect()
+                values = [row["python_value"] for row in results_data]
+                types = [row["python_type"] for row in results_data]
+                values_str = [row["python_value_str"] for row in results_data]
+
+                # Assert that the UDF output values match the input values
+                assert values == input_data, f"Output {values} != input {input_data}"
+
+                result_row.append(str(types))
+                result_row.append(str(values_str).replace("\n", " "))
+
+            except Exception as e:
+                print("error_msg", e)
+                # Clean up exception message to remove newlines and extra whitespace
+                error_msg = str(e).replace("\n", " ").replace("\r", " ")
+                result_row.append(f"✗ {error_msg}")
+
+            results.append(result_row)
+
+        return results
+
+    def test_pandas_udf_input(self):
+        golden_file = os.path.join(
+            os.path.dirname(__file__), "golden_pandas_udf_input_types.txt"
+        )
+        results = self._generate_pandas_udf_input_type_coercion_results()
+        actual_output = format_type_table(
+            results,
+            ["Test Case", "Spark Type", "Spark Value", "Python Type", "Python Value"],
+            column_width=85,
+        )
+        self._compare_or_create_golden_file(actual_output, golden_file, "Pandas UDF input types")
+
+    def _generate_pandas_udf_input_type_coercion_results(self):
+        results = []
+        test_cases = self._get_input_type_test_cases()
+
+        for test_name, spark_type, data_func in test_cases:
+            input_df = data_func(spark_type).repartition(1)
+            input_data = [row["value"] for row in input_df.collect()]
+            result_row = [test_name, spark_type.simpleString(), str(input_data)]
+
+            try:
+                def type_pandas_udf(data):
+                    if hasattr(data, 'dtype'):
+                        # Series case
+                        return pd.Series([str(data.dtype)] * len(data))
+                    else:
+                        # DataFrame case (for struct types)
+                        return pd.Series([str(type(data).__name__)] * len(data))
+
+                def value_pandas_udf(series):
+                    return series
+
+                type_test_pandas_udf = pandas_udf(type_pandas_udf, returnType=StringType())
+                value_test_pandas_udf = pandas_udf(value_pandas_udf, returnType=spark_type)
+
+                result_df = input_df.select(
+                    value_test_pandas_udf("value").alias("python_value"),
+                    type_test_pandas_udf("value").alias("python_type"),
+                )
+                results_data = result_df.collect()
+                values = [row["python_value"] for row in results_data]
+                types = [row["python_type"] for row in results_data]
+
+                result_row.append(str(types))
+                result_row.append(str(values).replace("\n", " "))
+
+            except Exception as e:
+                print("error_msg", e)
+                error_msg = str(e).replace("\n", " ").replace("\r", " ")
+                result_row.append(f"✗ {error_msg}")
+
+            results.append(result_row)
+
+        return results
+
+    def _compare_or_create_golden_file(self, actual_output, golden_file, test_name):
+        """Compare actual output with the golden file, or create the golden file if it doesn't exist.
+
+        Args:
+            actual_output: The actual output to compare
+            golden_file: Path to the golden file
+            test_name: Name of the test for error messages
+        """
+        if os.path.exists(golden_file):
+            with open(golden_file, "r") as f:
+                expected_output = f.read()
+
+            if actual_output != expected_output:
+                diff_output = generate_table_diff(actual_output, expected_output, cell_width=85)
+                self.fail(
+                    f"Results don't match golden file for '{test_name}'.\nDiff:\n{diff_output}"
+                )
+        else:
+            with open(golden_file, "w") as f:
+                f.write(actual_output)
+            self.fail(f"Golden file created for {test_name}. Please review and re-run the test.")
+
+    def _create_value_schema(self, data_type):
+        """Helper to create a StructType schema with a single 'value' column of the given type."""
+        return StructType([StructField("value", data_type, True)])
+
+    def _get_input_type_test_cases(self):
+        from pyspark.sql.types import StructType, StructField
+        import datetime
+        from decimal import Decimal
+
+        def df(args):
+            def create_df(data_type):
+                # For StructType where the data contains Row objects (not wrapped in tuples)
+                if (
+                    isinstance(data_type, StructType)
+                    and len(args) > 0
+                    and args[0][0] is not None
+                    and hasattr(args[0][0], '_fields')
+                ):
+                    schema = data_type
+                else:
+                    # For all other types, wrap in a "value" column
+                    schema = StructType([StructField("value", data_type, True)])
+                return self.spark.createDataFrame(args, schema)
+
+            return create_df
+
+        return [
+            ("byte_values", ByteType(), df([(-128,), (127,), (0,)])),
+            ("byte_null", ByteType(), df([(None,), (42,)])),

Review Comment:
   Optional nit, but I meant we could clarify how the test cases for the return_types and input_types were decided.
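   For instance (a hypothetical sketch, not code from this PR; the categories below are illustrative assumptions about the selection criteria), the rationale could be recorded right next to `_get_input_type_test_cases`:

       def _get_input_type_test_cases(self):
           # Test-case selection: for each Spark SQL data type we aim to cover
           #   1. boundary values (e.g. -128/127 for ByteType),
           #   2. NULL mixed with non-NULL values, to exercise None handling on
           #      both the Arrow and non-Arrow paths,
           #   3. container types (ArrayType, MapType, StructType) with nested
           #      contents where applicable.
           ...

   Spelling the criteria out like this would also make the golden files easier to review when new types are added.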