This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new fe3754afcfe [SPARK-44559][PYTHON][3.5] Improve error messages for 
Python UDTF arrow cast
fe3754afcfe is described below

commit fe3754afcfe50032a4bf9fabf46dcdea47860626
Author: allisonwang-db <[email protected]>
AuthorDate: Wed Aug 2 14:28:30 2023 -0700

    [SPARK-44559][PYTHON][3.5] Improve error messages for Python UDTF arrow cast
    
    ### What changes were proposed in this pull request?
    This PR cherry-picks 
https://github.com/apache/spark/commit/5384f4601a4ba8daba76d67e945eaa6fc2b70b2c.
 It improves error messages when the output of an arrow-optimized Python UDTF 
cannot be cast to the specified return schema of the UDTF.
    
    ### Why are the changes needed?
    To make Python UDTFs more user-friendly.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Before this PR, when the output of a UDTF failed to cast to the desired 
schema, Spark threw this confusing error message:
    ```python
    @udtf(returnType="x: int")
    class TestUDTF:
        def eval(self):
            yield [1, 2],
    
    TestUDTF().collect()
    ```
    
    ```
      File "pyarrow/array.pxi", line 1044, in pyarrow.lib.Array.from_pandas
      File "pyarrow/array.pxi", line 316, in pyarrow.lib.array
      File "pyarrow/array.pxi", line 83, in pyarrow.lib._ndarray_to_array
      File "pyarrow/error.pxi", line 100, in pyarrow.lib.check_status
    pyarrow.lib.ArrowInvalid: Could not convert [1, 2] with type list: tried to 
convert to int32
    ```
    Now, after this PR, the error message will look like this:
    ```
    pyspark.errors.exceptions.base.PySparkRuntimeError: 
[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert the output value of the column 'x' 
with type 'object' to the specified return type of the column: 'int32'. Please 
check if the data types match and try again.
    ```
    
    ### How was this patch tested?
    
    New unit tests
    
    Closes #42290 from allisonwang-db/spark-44559-3.5.
    
    Authored-by: allisonwang-db <[email protected]>
    Signed-off-by: Takuya UESHIN <[email protected]>
---
 python/pyspark/errors/error_classes.py   |   5 +
 python/pyspark/sql/pandas/serializers.py |  69 +++++++-
 python/pyspark/sql/tests/test_udtf.py    | 259 +++++++++++++++++++++++++++++++
 3 files changed, 332 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index 554a25952b9..db80705e7d2 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -713,6 +713,11 @@ ERROR_CLASSES_JSON = """
       "Expected <expected> values for `<item>`, got <actual>."
     ]
   },
+  "UDTF_ARROW_TYPE_CAST_ERROR" : {
+    "message" : [
+      "Cannot convert the output value of the column '<col_name>' with type 
'<col_type>' to the specified return type of the column: '<arrow_type>'. Please 
check if the data types match and try again."
+    ]
+  },
   "UDTF_EXEC_ERROR" : {
     "message" : [
       "User defined table function encountered an error in the '<method_name>' 
method: <error>"
diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index 1d326928e23..993bacbed67 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -19,7 +19,7 @@
 Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for 
more details.
 """
 
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError, 
PySparkValueError
 from pyspark.serializers import Serializer, read_int, write_int, 
UTF8Deserializer, CPickleSerializer
 from pyspark.sql.pandas.types import (
     from_arrow_type,
@@ -538,6 +538,73 @@ class 
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
 
         return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in 
range(len(arrs))])
 
+    def _create_array(self, series, arrow_type, spark_type=None, 
arrow_cast=False):
+        """
+        Override the `_create_array` method in the superclass to create an 
Arrow Array
+        from a given pandas.Series and an arrow type. The difference here is 
that we always
+        use arrow cast when creating the arrow array. Also, the error messages 
are specific
+        to arrow-optimized Python UDTFs.
+
+        Parameters
+        ----------
+        series : pandas.Series
+            A single series
+        arrow_type : pyarrow.DataType, optional
+            If None, pyarrow's inferred type will be used
+        spark_type : DataType, optional
+            If None, spark type converted from arrow_type will be used
+        arrow_cast: bool, optional
+            Whether to apply Arrow casting when the user-specified return type 
mismatches the
+            actual return values.
+
+        Returns
+        -------
+        pyarrow.Array
+        """
+        import pyarrow as pa
+        from pandas.api.types import is_categorical_dtype
+
+        if is_categorical_dtype(series.dtype):
+            series = series.astype(series.dtypes.categories.dtype)
+
+        if arrow_type is not None:
+            dt = spark_type or from_arrow_type(arrow_type, 
prefer_timestamp_ntz=True)
+            # TODO(SPARK-43579): cache the converter for reuse
+            conv = _create_converter_from_pandas(
+                dt, timezone=self._timezone, 
error_on_duplicated_field_names=False
+            )
+            series = conv(series)
+
+        if hasattr(series.array, "__arrow_array__"):
+            mask = None
+        else:
+            mask = series.isnull()
+
+        try:
+            try:
+                return pa.Array.from_pandas(
+                    series, mask=mask, type=arrow_type, safe=self._safecheck
+                )
+            except pa.lib.ArrowException:
+                if arrow_cast:
+                    return pa.Array.from_pandas(series, mask=mask).cast(
+                        target_type=arrow_type, safe=self._safecheck
+                    )
+                else:
+                    raise
+        except pa.lib.ArrowException:
+            # Display the most user-friendly error messages instead of showing
+            # arrow's error message. This also works better with Spark Connect
+            # where the exception messages are by default truncated.
+            raise PySparkRuntimeError(
+                error_class="UDTF_ARROW_TYPE_CAST_ERROR",
+                message_parameters={
+                    "col_name": series.name,
+                    "col_type": str(series.dtype),
+                    "arrow_type": arrow_type,
+                },
+            ) from None
+
     def __repr__(self):
         return "ArrowStreamPandasUDTFSerializer"
 
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index b3e832b8b97..5c33cb14834 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -518,6 +518,130 @@ class BaseUDTFTestsMixin:
 
         assertDataFrameEqual(TestUDTF(), [Row()])
 
+    def _check_result_or_exception(self, func_handler, ret_type, expected):
+        func = udtf(func_handler, returnType=ret_type)
+        if not isinstance(expected, str):
+            assertDataFrameEqual(func(), expected)
+        else:
+            with self.assertRaisesRegex(PythonException, expected):
+                func().collect()
+
+    def test_numeric_output_type_casting(self):
+        class TestUDTF:
+            def eval(self):
+                yield 1,
+
+        for i, (ret_type, expected) in enumerate(
+            [
+                ("x: boolean", [Row(x=None)]),
+                ("x: tinyint", [Row(x=1)]),
+                ("x: smallint", [Row(x=1)]),
+                ("x: int", [Row(x=1)]),
+                ("x: bigint", [Row(x=1)]),
+                ("x: string", [Row(x="1")]),  # int to string is ok, but 
string to int is None
+                (
+                    "x: date",
+                    "AttributeError",
+                ),  # AttributeError: 'int' object has no attribute 'toordinal'
+                (
+                    "x: timestamp",
+                    "AttributeError",
+                ),  # AttributeError: 'int' object has no attribute 'tzinfo'
+                ("x: byte", [Row(x=1)]),
+                ("x: binary", [Row(x=None)]),
+                ("x: float", [Row(x=None)]),
+                ("x: double", [Row(x=None)]),
+                ("x: decimal(10, 0)", [Row(x=None)]),
+                ("x: array<int>", [Row(x=None)]),
+                ("x: map<string,int>", [Row(x=None)]),
+                ("x: struct<a:int>", "UNEXPECTED_TUPLE_WITH_STRUCT"),
+            ]
+        ):
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_numeric_string_output_type_casting(self):
+        class TestUDTF:
+            def eval(self):
+                yield "1",
+
+        for ret_type, expected in [
+            ("x: boolean", [Row(x=None)]),
+            ("x: tinyint", [Row(x=None)]),
+            ("x: smallint", [Row(x=None)]),
+            ("x: int", [Row(x=None)]),
+            ("x: bigint", [Row(x=None)]),
+            ("x: string", [Row(x="1")]),
+            ("x: date", "AttributeError"),
+            ("x: timestamp", "AttributeError"),
+            ("x: byte", [Row(x=None)]),
+            ("x: binary", [Row(x=bytearray(b"1"))]),
+            ("x: float", [Row(x=None)]),
+            ("x: double", [Row(x=None)]),
+            ("x: decimal(10, 0)", [Row(x=None)]),
+            ("x: array<int>", [Row(x=None)]),
+            ("x: map<string,int>", [Row(x=None)]),
+            ("x: struct<a:int>", "UNEXPECTED_TUPLE_WITH_STRUCT"),
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_string_output_type_casting(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello",
+
+        for ret_type, expected in [
+            ("x: boolean", [Row(x=None)]),
+            ("x: tinyint", [Row(x=None)]),
+            ("x: smallint", [Row(x=None)]),
+            ("x: int", [Row(x=None)]),
+            ("x: bigint", [Row(x=None)]),
+            ("x: string", [Row(x="hello")]),
+            ("x: date", "AttributeError"),
+            ("x: timestamp", "AttributeError"),
+            ("x: byte", [Row(x=None)]),
+            ("x: binary", [Row(x=bytearray(b"hello"))]),
+            ("x: float", [Row(x=None)]),
+            ("x: double", [Row(x=None)]),
+            ("x: decimal(10, 0)", [Row(x=None)]),
+            ("x: array<int>", [Row(x=None)]),
+            ("x: map<string,int>", [Row(x=None)]),
+            ("x: struct<a:int>", "UNEXPECTED_TUPLE_WITH_STRUCT"),
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_array_output_type_casting(self):
+        class TestUDTF:
+            def eval(self):
+                yield [1, 2],
+
+        for ret_type, expected in [
+            ("x: int", [Row(x=None)]),
+            ("x: array<int>", [Row(x=[1, 2])]),
+            ("x: array<double>", [Row(x=[None, None])]),
+            ("x: array<string>", [Row(x=["1", "2"])]),
+            ("x: array<boolean>", [Row(x=[None, None])]),
+            ("x: array<array<int>>", [Row(x=[None, None])]),
+            ("x: map<string,int>", [Row(x=None)]),
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_inconsistent_output_types(self):
+        class TestUDTF:
+            def eval(self):
+                yield 1,
+                yield [1, 2],
+
+        for ret_type, expected in [
+            ("x: int", [Row(x=1), Row(x=None)]),
+            ("x: array<int>", [Row(x=None), Row(x=[1, 2])]),
+        ]:
+            with self.subTest(ret_type=ret_type):
+                assertDataFrameEqual(udtf(TestUDTF, returnType=ret_type)(), 
expected)
+
     @unittest.skipIf(not have_pandas, pandas_requirement_message)
     def test_udtf_with_pandas_input_type(self):
         import pandas as pd
@@ -857,6 +981,141 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
         func = udtf(TestUDTF, returnType="a: int")
         self.assertEqual(func(lit(1)).collect(), [Row(a=1)])
 
+    def test_numeric_output_type_casting(self):
+        class TestUDTF:
+            def eval(self):
+                yield 1,
+
+        err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+        for ret_type, expected in [
+            ("x: boolean", [Row(x=True)]),
+            ("x: tinyint", [Row(x=1)]),
+            ("x: smallint", [Row(x=1)]),
+            ("x: int", [Row(x=1)]),
+            ("x: bigint", [Row(x=1)]),
+            ("x: string", [Row(x="1")]),  # require arrow.cast
+            ("x: date", err),
+            ("x: byte", [Row(x=1)]),
+            ("x: binary", [Row(x=bytearray(b"\x01"))]),
+            ("x: float", [Row(x=1.0)]),
+            ("x: double", [Row(x=1.0)]),
+            ("x: decimal(10, 0)", err),
+            ("x: array<int>", err),
+            # TODO(SPARK-44561): fix AssertionError in convert_map and 
convert_struct
+            # ("x: map<string,int>", None),
+            # ("x: struct<a:int>", None)
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_numeric_string_output_type_casting(self):
+        class TestUDTF:
+            def eval(self):
+                yield "1",
+
+        err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+        for ret_type, expected in [
+            ("x: boolean", [Row(x=True)]),
+            ("x: tinyint", [Row(x=1)]),
+            ("x: smallint", [Row(x=1)]),
+            ("x: int", [Row(x=1)]),
+            ("x: bigint", [Row(x=1)]),
+            ("x: string", [Row(x="1")]),
+            ("x: date", err),
+            ("x: timestamp", err),
+            ("x: byte", [Row(x=1)]),
+            ("x: binary", [Row(x=bytearray(b"1"))]),
+            ("x: float", [Row(x=1.0)]),
+            ("x: double", [Row(x=1.0)]),
+            ("x: decimal(10, 0)", [Row(x=1)]),
+            ("x: array<string>", [Row(x=["1"])]),
+            ("x: array<int>", err),
+            # TODO(SPARK-44561): fix AssertionError in convert_map and 
convert_struct
+            # ("x: map<string,int>", None),
+            # ("x: struct<a:int>", None)
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_string_output_type_casting(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello",
+
+        err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+        for ret_type, expected in [
+            ("x: boolean", err),
+            ("x: tinyint", err),
+            ("x: smallint", err),
+            ("x: int", err),
+            ("x: bigint", err),
+            ("x: string", [Row(x="hello")]),
+            ("x: date", err),
+            ("x: timestamp", err),
+            ("x: byte", err),
+            ("x: binary", [Row(x=bytearray(b"hello"))]),
+            ("x: float", err),
+            ("x: double", err),
+            ("x: decimal(10, 0)", err),
+            ("x: array<string>", [Row(x=["h", "e", "l", "l", "o"])]),
+            ("x: array<int>", err),
+            # TODO(SPARK-44561): fix AssertionError in convert_map and 
convert_struct
+            # ("x: map<string,int>", None),
+            # ("x: struct<a:int>", None)
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_array_output_type_casting(self):
+        class TestUDTF:
+            def eval(self):
+                yield [0, 1.1, 2],
+
+        err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+        for ret_type, expected in [
+            ("x: boolean", err),
+            ("x: tinyint", err),
+            ("x: smallint", err),
+            ("x: int", err),
+            ("x: bigint", err),
+            ("x: string", err),
+            ("x: date", err),
+            ("x: timestamp", err),
+            ("x: byte", err),
+            ("x: binary", err),
+            ("x: float", err),
+            ("x: double", err),
+            ("x: decimal(10, 0)", err),
+            ("x: array<string>", [Row(x=["0", "1.1", "2"])]),
+            ("x: array<boolean>", [Row(x=[False, True, True])]),
+            ("x: array<int>", [Row(x=[0, 1, 2])]),
+            ("x: array<float>", [Row(x=[0, 1.1, 2])]),
+            ("x: array<array<int>>", err),
+            # TODO(SPARK-44561): fix AssertionError in convert_map and 
convert_struct
+            # ("x: map<string,int>", None),
+            # ("x: struct<a:int>", None)
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_inconsistent_output_types(self):
+        class TestUDTF:
+            def eval(self):
+                yield 1,
+                yield [1, 2],
+
+        for ret_type in [
+            "x: int",
+            "x: array<int>",
+        ]:
+            with self.subTest(ret_type=ret_type):
+                with self.assertRaisesRegex(PythonException, 
"UDTF_ARROW_TYPE_CAST_ERROR"):
+                    udtf(TestUDTF, returnType=ret_type)().collect()
+
 
 class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to