allisonwang-db commented on code in PR #52140:
URL: https://github.com/apache/spark/pull/52140#discussion_r2311318194

##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -178,20 +178,26 @@ def eval(self) -> Iterator["pa.Table"]:
             result_df.collect()

     def test_arrow_udtf_error_mismatched_schema(self):
-        @arrow_udtf(returnType="x int, y string")
+
+        @arrow_udtf(returnType="x int, y int")
         class MismatchedSchemaUDTF:
             def eval(self) -> Iterator["pa.Table"]:
                 result_table = pa.table(
                     {
-                        "wrong_col": pa.array([1], type=pa.int32()),
-                        "another_wrong_col": pa.array([2.5], type=pa.float64()),
+                        "col_with_arrow_cast": pa.array([1], type=pa.int32()),
+                        "wrong_col": pa.array(["wrong_col"], type=pa.string()),
                     }
                 )
                 yield result_table

-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
-            result_df = MismatchedSchemaUDTF()
-            result_df.collect()
+        if self.spark.conf.get("spark.sql.execution.pythonUDTF.typeCoercion.enabled").lower() == "false":
+            with self.assertRaisesRegex(PythonException, "Arrow UDTFs require the return type to match the expected Arrow type. Expected: int32, but got: string."):
+                result_df = MismatchedSchemaUDTF()
+                result_df.collect()
+        else:
+            with self.assertRaisesRegex(PythonException, "Failed to parse string: 'wrong_col' as a scalar of type int32"):

Review Comment:
   Hmm, looks like the error message is better without arrow cast.

##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -4003,6 +4003,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)

+  val PYTHON_TABLE_UDF_TYPE_CORERION_ENABLED =
+    buildConf("spark.sql.execution.pythonUDTF.typeCoercion.enabled")

Review Comment:
   Let's enable Arrow cast for Arrow Python UDTFs by default so we don't need this config :)

##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -201,9 +201,26 @@ class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
     Serializer for PyArrow-native UDTFs that work directly with PyArrow
     RecordBatches and Arrays.
     """

-    def __init__(self, table_arg_offsets=None):
+    def __init__(self, table_arg_offsets=None, arrow_cast=True):
         super().__init__()
         self.table_arg_offsets = table_arg_offsets if table_arg_offsets else []
+        self._arrow_cast = arrow_cast
+
+    def _create_array(self, arr, arrow_type):
+        import pyarrow as pa
+
+        assert isinstance(arr, pa.Array)
+        assert isinstance(arrow_type, pa.DataType)
+
+        if arr.type == arrow_type:
+            return arr
+        elif self._arrow_cast:
+            return arr.cast(target_type=arrow_type, safe=True)

Review Comment:
   cc @zhengruifeng is this the same behavior as Arrow UDFs?

##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -178,20 +178,26 @@ def eval(self) -> Iterator["pa.Table"]:
             result_df.collect()

     def test_arrow_udtf_error_mismatched_schema(self):
-        @arrow_udtf(returnType="x int, y string")
+
+        @arrow_udtf(returnType="x int, y int")
         class MismatchedSchemaUDTF:
             def eval(self) -> Iterator["pa.Table"]:
                 result_table = pa.table(
                     {
-                        "wrong_col": pa.array([1], type=pa.int32()),
-                        "another_wrong_col": pa.array([2.5], type=pa.float64()),
+                        "col_with_arrow_cast": pa.array([1], type=pa.int32()),

Review Comment:
   What if the input is `int64` and the output is `int32`? Does arrow cast throw an exception in that case?
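
For reference, pyarrow's default safe cast answers this: an `int64` to `int32` cast with `safe=True` succeeds for in-range values and raises `ArrowInvalid` only on overflow. A minimal sketch in plain pyarrow, outside the UDTF machinery:

```python
import pyarrow as pa

# In-range values: the safe cast from int64 to int32 succeeds.
ok = pa.array([1, 2], type=pa.int64()).cast(pa.int32(), safe=True)

# Out-of-range values: the safe cast raises pyarrow.lib.ArrowInvalid
# (message along the lines of "Integer value ... not in range").
try:
    pa.array([2**40], type=pa.int64()).cast(pa.int32(), safe=True)
except pa.ArrowInvalid as e:
    print(e)
```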
##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -178,20 +178,26 @@ def eval(self) -> Iterator["pa.Table"]:
             result_df.collect()

     def test_arrow_udtf_error_mismatched_schema(self):
-        @arrow_udtf(returnType="x int, y string")
+
+        @arrow_udtf(returnType="x int, y int")
         class MismatchedSchemaUDTF:
             def eval(self) -> Iterator["pa.Table"]:
                 result_table = pa.table(
                     {
-                        "wrong_col": pa.array([1], type=pa.int32()),
-                        "another_wrong_col": pa.array([2.5], type=pa.float64()),
+                        "col_with_arrow_cast": pa.array([1], type=pa.int32()),
+                        "wrong_col": pa.array(["wrong_col"], type=pa.string()),
                     }
                 )
                 yield result_table

-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
-            result_df = MismatchedSchemaUDTF()
-            result_df.collect()
+        if self.spark.conf.get("spark.sql.execution.pythonUDTF.typeCoercion.enabled").lower() == "false":

Review Comment:
   you can use `with self.sql_conf("...")`
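
For reference, `sql_conf` is the context manager from PySpark's `SQLTestUtils` test mixin: it takes a dict of conf key/value pairs and restores the previous values on exit. A sketch of how the test body above could use it (the expected-error regex is elided here):

```python
# Sketch: pin the conf for just this block instead of reading it manually,
# assuming the test class mixes in pyspark.testing.sqlutils.SQLTestUtils.
with self.sql_conf({"spark.sql.execution.pythonUDTF.typeCoercion.enabled": "false"}):
    with self.assertRaisesRegex(PythonException, "..."):
        MismatchedSchemaUDTF().collect()
```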