allisonwang-db commented on code in PR #52140:
URL: https://github.com/apache/spark/pull/52140#discussion_r2311318194


##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -178,20 +178,26 @@ def eval(self) -> Iterator["pa.Table"]:
             result_df.collect()
 
     def test_arrow_udtf_error_mismatched_schema(self):
-        @arrow_udtf(returnType="x int, y string")
+
+        @arrow_udtf(returnType="x int, y int")
         class MismatchedSchemaUDTF:
             def eval(self) -> Iterator["pa.Table"]:
                 result_table = pa.table(
                     {
-                        "wrong_col": pa.array([1], type=pa.int32()),
-                        "another_wrong_col": pa.array([2.5], 
type=pa.float64()),
+                        "col_with_arrow_cast": pa.array([1], type=pa.int32()),
+                        "wrong_col": pa.array(["wrong_col"], type=pa.string()),
                     }
                 )
                 yield result_table
 
-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was 
different"):
-            result_df = MismatchedSchemaUDTF()
-            result_df.collect()
+        if 
self.spark.conf.get("spark.sql.execution.pythonUDTF.typeCoercion.enabled").lower()
 == "false":
+            with self.assertRaisesRegex(PythonException, "Arrow UDTFs require 
the return type to match the expected Arrow type. Expected: int32, but got: 
string."):
+                result_df = MismatchedSchemaUDTF()
+                result_df.collect()
+        else:
+            with self.assertRaisesRegex(PythonException, "Failed to parse 
string: 'wrong_col' as a scalar of type int32"):

Review Comment:
   Hmm, it looks like the error message is clearer without the Arrow cast.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -4003,6 +4003,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val PYTHON_TABLE_UDF_TYPE_CORERION_ENABLED =
+    buildConf("spark.sql.execution.pythonUDTF.typeCoercion.enabled")

Review Comment:
   Let's enable Arrow casting for Arrow Python UDTFs by default, so that we don't need this config. :)



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -201,9 +201,26 @@ class 
ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
     Serializer for PyArrow-native UDTFs that work directly with PyArrow 
RecordBatches and Arrays.
     """
 
-    def __init__(self, table_arg_offsets=None):
+    def __init__(self, table_arg_offsets=None, arrow_cast=True):
         super().__init__()
         self.table_arg_offsets = table_arg_offsets if table_arg_offsets else []
+        self._arrow_cast = arrow_cast
+
+    def _create_array(self, arr, arrow_type):
+        import pyarrow as pa
+
+        assert isinstance(arr, pa.Array)
+        assert isinstance(arrow_type, pa.DataType)
+
+        if arr.type == arrow_type:
+            return arr
+        elif self._arrow_cast:
+            return arr.cast(target_type=arrow_type, safe=True)

Review Comment:
   cc @zhengruifeng: is this the same behavior as Arrow UDFs (i.e., a safe cast to the declared return type whenever the returned array's type differs)?



##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -178,20 +178,26 @@ def eval(self) -> Iterator["pa.Table"]:
             result_df.collect()
 
     def test_arrow_udtf_error_mismatched_schema(self):
-        @arrow_udtf(returnType="x int, y string")
+
+        @arrow_udtf(returnType="x int, y int")
         class MismatchedSchemaUDTF:
             def eval(self) -> Iterator["pa.Table"]:
                 result_table = pa.table(
                     {
-                        "wrong_col": pa.array([1], type=pa.int32()),
-                        "another_wrong_col": pa.array([2.5], 
type=pa.float64()),
+                        "col_with_arrow_cast": pa.array([1], type=pa.int32()),

Review Comment:
   What if the returned column is `int64` but the declared output type is `int32`? Does the Arrow cast throw an exception in this case, for example when a value does not fit in `int32`?



##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -178,20 +178,26 @@ def eval(self) -> Iterator["pa.Table"]:
             result_df.collect()
 
     def test_arrow_udtf_error_mismatched_schema(self):
-        @arrow_udtf(returnType="x int, y string")
+
+        @arrow_udtf(returnType="x int, y int")
         class MismatchedSchemaUDTF:
             def eval(self) -> Iterator["pa.Table"]:
                 result_table = pa.table(
                     {
-                        "wrong_col": pa.array([1], type=pa.int32()),
-                        "another_wrong_col": pa.array([2.5], 
type=pa.float64()),
+                        "col_with_arrow_cast": pa.array([1], type=pa.int32()),
+                        "wrong_col": pa.array(["wrong_col"], type=pa.string()),
                     }
                 )
                 yield result_table
 
-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was 
different"):
-            result_df = MismatchedSchemaUDTF()
-            result_df.collect()
+        if 
self.spark.conf.get("spark.sql.execution.pythonUDTF.typeCoercion.enabled").lower()
 == "false":

Review Comment:
   You can use `with self.sql_conf({...}):` to set this config explicitly for each case, instead of branching on its current value.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to