This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new fe3754afcfe [SPARK-44559][PYTHON][3.5] Improve error messages for
Python UDTF arrow cast
fe3754afcfe is described below
commit fe3754afcfe50032a4bf9fabf46dcdea47860626
Author: allisonwang-db <[email protected]>
AuthorDate: Wed Aug 2 14:28:30 2023 -0700
[SPARK-44559][PYTHON][3.5] Improve error messages for Python UDTF arrow cast
### What changes were proposed in this pull request?
This PR cherry-picks
https://github.com/apache/spark/commit/5384f4601a4ba8daba76d67e945eaa6fc2b70b2c.
It improves error messages when the output of an arrow-optimized Python UDTF
cannot be cast to the specified return schema of the UDTF.
### Why are the changes needed?
To make Python UDTFs more user-friendly.
### Does this PR introduce _any_ user-facing change?
Yes, before this PR, when the output of a UDTF fails to cast to the desired
schema, Spark will throw this confusing error message:
```python
@udtf(returnType="x: int")
class TestUDTF:
def eval(self):
yield [1, 2],
TestUDTF().collect()
```
```
File "pyarrow/array.pxi", line 1044, in pyarrow.lib.Array.from_pandas
File "pyarrow/array.pxi", line 316, in pyarrow.lib.array
File "pyarrow/array.pxi", line 83, in pyarrow.lib._ndarray_to_array
File "pyarrow/error.pxi", line 100, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Could not convert [1, 2] with type list: tried to
convert to int32
```
Now, after this PR, the error message will look like this:
`pyspark.errors.exceptions.base.PySparkRuntimeError:
[UDTF_ARROW_TYPE_CAST_ERROR] Cannot convert the output value of the column 'x'
with type 'object' to the specified return type of the column: 'int32'. Please
check if the data types match and try again.
`
### How was this patch tested?
New unit tests
Closes #42290 from allisonwang-db/spark-44559-3.5.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
---
python/pyspark/errors/error_classes.py | 5 +
python/pyspark/sql/pandas/serializers.py | 69 +++++++-
python/pyspark/sql/tests/test_udtf.py | 259 +++++++++++++++++++++++++++++++
3 files changed, 332 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/errors/error_classes.py
b/python/pyspark/errors/error_classes.py
index 554a25952b9..db80705e7d2 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -713,6 +713,11 @@ ERROR_CLASSES_JSON = """
"Expected <expected> values for `<item>`, got <actual>."
]
},
+ "UDTF_ARROW_TYPE_CAST_ERROR" : {
+ "message" : [
+ "Cannot convert the output value of the column '<col_name>' with type
'<col_type>' to the specified return type of the column: '<arrow_type>'. Please
check if the data types match and try again."
+ ]
+ },
"UDTF_EXEC_ERROR" : {
"message" : [
"User defined table function encountered an error in the '<method_name>'
method: <error>"
diff --git a/python/pyspark/sql/pandas/serializers.py
b/python/pyspark/sql/pandas/serializers.py
index 1d326928e23..993bacbed67 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -19,7 +19,7 @@
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for
more details.
"""
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError,
PySparkValueError
from pyspark.serializers import Serializer, read_int, write_int,
UTF8Deserializer, CPickleSerializer
from pyspark.sql.pandas.types import (
from_arrow_type,
@@ -538,6 +538,73 @@ class
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in
range(len(arrs))])
+ def _create_array(self, series, arrow_type, spark_type=None,
arrow_cast=False):
+ """
+ Override the `_create_array` method in the superclass to create an
Arrow Array
+ from a given pandas.Series and an arrow type. The difference here is
that we always
+ use arrow cast when creating the arrow array. Also, the error messages
are specific
+ to arrow-optimized Python UDTFs.
+
+ Parameters
+ ----------
+ series : pandas.Series
+ A single series
+ arrow_type : pyarrow.DataType, optional
+ If None, pyarrow's inferred type will be used
+ spark_type : DataType, optional
+ If None, spark type converted from arrow_type will be used
+ arrow_cast: bool, optional
+ Whether to apply Arrow casting when the user-specified return type
mismatches the
+ actual return values.
+
+ Returns
+ -------
+ pyarrow.Array
+ """
+ import pyarrow as pa
+ from pandas.api.types import is_categorical_dtype
+
+ if is_categorical_dtype(series.dtype):
+ series = series.astype(series.dtypes.categories.dtype)
+
+ if arrow_type is not None:
+ dt = spark_type or from_arrow_type(arrow_type,
prefer_timestamp_ntz=True)
+ # TODO(SPARK-43579): cache the converter for reuse
+ conv = _create_converter_from_pandas(
+ dt, timezone=self._timezone,
error_on_duplicated_field_names=False
+ )
+ series = conv(series)
+
+ if hasattr(series.array, "__arrow_array__"):
+ mask = None
+ else:
+ mask = series.isnull()
+
+ try:
+ try:
+ return pa.Array.from_pandas(
+ series, mask=mask, type=arrow_type, safe=self._safecheck
+ )
+ except pa.lib.ArrowException:
+ if arrow_cast:
+ return pa.Array.from_pandas(series, mask=mask).cast(
+ target_type=arrow_type, safe=self._safecheck
+ )
+ else:
+ raise
+ except pa.lib.ArrowException:
+ # Display the most user-friendly error messages instead of showing
+ # arrow's error message. This also works better with Spark Connect
+ # where the exception messages are by default truncated.
+ raise PySparkRuntimeError(
+ error_class="UDTF_ARROW_TYPE_CAST_ERROR",
+ message_parameters={
+ "col_name": series.name,
+ "col_type": str(series.dtype),
+ "arrow_type": arrow_type,
+ },
+ ) from None
+
def __repr__(self):
return "ArrowStreamPandasUDTFSerializer"
diff --git a/python/pyspark/sql/tests/test_udtf.py
b/python/pyspark/sql/tests/test_udtf.py
index b3e832b8b97..5c33cb14834 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -518,6 +518,130 @@ class BaseUDTFTestsMixin:
assertDataFrameEqual(TestUDTF(), [Row()])
+ def _check_result_or_exception(self, func_handler, ret_type, expected):
+ func = udtf(func_handler, returnType=ret_type)
+ if not isinstance(expected, str):
+ assertDataFrameEqual(func(), expected)
+ else:
+ with self.assertRaisesRegex(PythonException, expected):
+ func().collect()
+
+ def test_numeric_output_type_casting(self):
+ class TestUDTF:
+ def eval(self):
+ yield 1,
+
+ for i, (ret_type, expected) in enumerate(
+ [
+ ("x: boolean", [Row(x=None)]),
+ ("x: tinyint", [Row(x=1)]),
+ ("x: smallint", [Row(x=1)]),
+ ("x: int", [Row(x=1)]),
+ ("x: bigint", [Row(x=1)]),
+ ("x: string", [Row(x="1")]), # int to string is ok, but
string to int is None
+ (
+ "x: date",
+ "AttributeError",
+ ), # AttributeError: 'int' object has no attribute 'toordinal'
+ (
+ "x: timestamp",
+ "AttributeError",
+ ), # AttributeError: 'int' object has no attribute 'tzinfo'
+ ("x: byte", [Row(x=1)]),
+ ("x: binary", [Row(x=None)]),
+ ("x: float", [Row(x=None)]),
+ ("x: double", [Row(x=None)]),
+ ("x: decimal(10, 0)", [Row(x=None)]),
+ ("x: array<int>", [Row(x=None)]),
+ ("x: map<string,int>", [Row(x=None)]),
+ ("x: struct<a:int>", "UNEXPECTED_TUPLE_WITH_STRUCT"),
+ ]
+ ):
+ with self.subTest(ret_type=ret_type):
+ self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+ def test_numeric_string_output_type_casting(self):
+ class TestUDTF:
+ def eval(self):
+ yield "1",
+
+ for ret_type, expected in [
+ ("x: boolean", [Row(x=None)]),
+ ("x: tinyint", [Row(x=None)]),
+ ("x: smallint", [Row(x=None)]),
+ ("x: int", [Row(x=None)]),
+ ("x: bigint", [Row(x=None)]),
+ ("x: string", [Row(x="1")]),
+ ("x: date", "AttributeError"),
+ ("x: timestamp", "AttributeError"),
+ ("x: byte", [Row(x=None)]),
+ ("x: binary", [Row(x=bytearray(b"1"))]),
+ ("x: float", [Row(x=None)]),
+ ("x: double", [Row(x=None)]),
+ ("x: decimal(10, 0)", [Row(x=None)]),
+ ("x: array<int>", [Row(x=None)]),
+ ("x: map<string,int>", [Row(x=None)]),
+ ("x: struct<a:int>", "UNEXPECTED_TUPLE_WITH_STRUCT"),
+ ]:
+ with self.subTest(ret_type=ret_type):
+ self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+ def test_string_output_type_casting(self):
+ class TestUDTF:
+ def eval(self):
+ yield "hello",
+
+ for ret_type, expected in [
+ ("x: boolean", [Row(x=None)]),
+ ("x: tinyint", [Row(x=None)]),
+ ("x: smallint", [Row(x=None)]),
+ ("x: int", [Row(x=None)]),
+ ("x: bigint", [Row(x=None)]),
+ ("x: string", [Row(x="hello")]),
+ ("x: date", "AttributeError"),
+ ("x: timestamp", "AttributeError"),
+ ("x: byte", [Row(x=None)]),
+ ("x: binary", [Row(x=bytearray(b"hello"))]),
+ ("x: float", [Row(x=None)]),
+ ("x: double", [Row(x=None)]),
+ ("x: decimal(10, 0)", [Row(x=None)]),
+ ("x: array<int>", [Row(x=None)]),
+ ("x: map<string,int>", [Row(x=None)]),
+ ("x: struct<a:int>", "UNEXPECTED_TUPLE_WITH_STRUCT"),
+ ]:
+ with self.subTest(ret_type=ret_type):
+ self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+ def test_array_output_type_casting(self):
+ class TestUDTF:
+ def eval(self):
+ yield [1, 2],
+
+ for ret_type, expected in [
+ ("x: int", [Row(x=None)]),
+ ("x: array<int>", [Row(x=[1, 2])]),
+ ("x: array<double>", [Row(x=[None, None])]),
+ ("x: array<string>", [Row(x=["1", "2"])]),
+ ("x: array<boolean>", [Row(x=[None, None])]),
+ ("x: array<array<int>>", [Row(x=[None, None])]),
+ ("x: map<string,int>", [Row(x=None)]),
+ ]:
+ with self.subTest(ret_type=ret_type):
+ self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+ def test_inconsistent_output_types(self):
+ class TestUDTF:
+ def eval(self):
+ yield 1,
+ yield [1, 2],
+
+ for ret_type, expected in [
+ ("x: int", [Row(x=1), Row(x=None)]),
+ ("x: array<int>", [Row(x=None), Row(x=[1, 2])]),
+ ]:
+ with self.subTest(ret_type=ret_type):
+ assertDataFrameEqual(udtf(TestUDTF, returnType=ret_type)(),
expected)
+
@unittest.skipIf(not have_pandas, pandas_requirement_message)
def test_udtf_with_pandas_input_type(self):
import pandas as pd
@@ -857,6 +981,141 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
func = udtf(TestUDTF, returnType="a: int")
self.assertEqual(func(lit(1)).collect(), [Row(a=1)])
+ def test_numeric_output_type_casting(self):
+ class TestUDTF:
+ def eval(self):
+ yield 1,
+
+ err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+ for ret_type, expected in [
+ ("x: boolean", [Row(x=True)]),
+ ("x: tinyint", [Row(x=1)]),
+ ("x: smallint", [Row(x=1)]),
+ ("x: int", [Row(x=1)]),
+ ("x: bigint", [Row(x=1)]),
+ ("x: string", [Row(x="1")]), # require arrow.cast
+ ("x: date", err),
+ ("x: byte", [Row(x=1)]),
+ ("x: binary", [Row(x=bytearray(b"\x01"))]),
+ ("x: float", [Row(x=1.0)]),
+ ("x: double", [Row(x=1.0)]),
+ ("x: decimal(10, 0)", err),
+ ("x: array<int>", err),
+ # TODO(SPARK-44561): fix AssertionError in convert_map and
convert_struct
+ # ("x: map<string,int>", None),
+ # ("x: struct<a:int>", None)
+ ]:
+ with self.subTest(ret_type=ret_type):
+ self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+ def test_numeric_string_output_type_casting(self):
+ class TestUDTF:
+ def eval(self):
+ yield "1",
+
+ err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+ for ret_type, expected in [
+ ("x: boolean", [Row(x=True)]),
+ ("x: tinyint", [Row(x=1)]),
+ ("x: smallint", [Row(x=1)]),
+ ("x: int", [Row(x=1)]),
+ ("x: bigint", [Row(x=1)]),
+ ("x: string", [Row(x="1")]),
+ ("x: date", err),
+ ("x: timestamp", err),
+ ("x: byte", [Row(x=1)]),
+ ("x: binary", [Row(x=bytearray(b"1"))]),
+ ("x: float", [Row(x=1.0)]),
+ ("x: double", [Row(x=1.0)]),
+ ("x: decimal(10, 0)", [Row(x=1)]),
+ ("x: array<string>", [Row(x=["1"])]),
+ ("x: array<int>", err),
+ # TODO(SPARK-44561): fix AssertionError in convert_map and
convert_struct
+ # ("x: map<string,int>", None),
+ # ("x: struct<a:int>", None)
+ ]:
+ with self.subTest(ret_type=ret_type):
+ self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+ def test_string_output_type_casting(self):
+ class TestUDTF:
+ def eval(self):
+ yield "hello",
+
+ err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+ for ret_type, expected in [
+ ("x: boolean", err),
+ ("x: tinyint", err),
+ ("x: smallint", err),
+ ("x: int", err),
+ ("x: bigint", err),
+ ("x: string", [Row(x="hello")]),
+ ("x: date", err),
+ ("x: timestamp", err),
+ ("x: byte", err),
+ ("x: binary", [Row(x=bytearray(b"hello"))]),
+ ("x: float", err),
+ ("x: double", err),
+ ("x: decimal(10, 0)", err),
+ ("x: array<string>", [Row(x=["h", "e", "l", "l", "o"])]),
+ ("x: array<int>", err),
+ # TODO(SPARK-44561): fix AssertionError in convert_map and
convert_struct
+ # ("x: map<string,int>", None),
+ # ("x: struct<a:int>", None)
+ ]:
+ with self.subTest(ret_type=ret_type):
+ self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+ def test_array_output_type_casting(self):
+ class TestUDTF:
+ def eval(self):
+ yield [0, 1.1, 2],
+
+ err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+ for ret_type, expected in [
+ ("x: boolean", err),
+ ("x: tinyint", err),
+ ("x: smallint", err),
+ ("x: int", err),
+ ("x: bigint", err),
+ ("x: string", err),
+ ("x: date", err),
+ ("x: timestamp", err),
+ ("x: byte", err),
+ ("x: binary", err),
+ ("x: float", err),
+ ("x: double", err),
+ ("x: decimal(10, 0)", err),
+ ("x: array<string>", [Row(x=["0", "1.1", "2"])]),
+ ("x: array<boolean>", [Row(x=[False, True, True])]),
+ ("x: array<int>", [Row(x=[0, 1, 2])]),
+ ("x: array<float>", [Row(x=[0, 1.1, 2])]),
+ ("x: array<array<int>>", err),
+ # TODO(SPARK-44561): fix AssertionError in convert_map and
convert_struct
+ # ("x: map<string,int>", None),
+ # ("x: struct<a:int>", None)
+ ]:
+ with self.subTest(ret_type=ret_type):
+ self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+ def test_inconsistent_output_types(self):
+ class TestUDTF:
+ def eval(self):
+ yield 1,
+ yield [1, 2],
+
+ for ret_type in [
+ "x: int",
+ "x: array<int>",
+ ]:
+ with self.subTest(ret_type=ret_type):
+ with self.assertRaisesRegex(PythonException,
"UDTF_ARROW_TYPE_CAST_ERROR"):
+ udtf(TestUDTF, returnType=ret_type)().collect()
+
class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
@classmethod
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]