This is an automated email from the ASF dual-hosted git repository. ueshin pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new ea6b41cb398 [SPARK-44561][PYTHON] Fix AssertionError when converting UDTF output to a complex type ea6b41cb398 is described below commit ea6b41cb3989996e45102b1930b1498324761093 Author: Takuya UESHIN <ues...@databricks.com> AuthorDate: Mon Aug 7 11:48:24 2023 -0700 [SPARK-44561][PYTHON] Fix AssertionError when converting UDTF output to a complex type ### What changes were proposed in this pull request? Fixes AssertionError when converting UDTF output to a complex type by ignoring assertions in `_create_converter_from_pandas` so that Arrow raises the error instead. ### Why are the changes needed? There is an assertion in `_create_converter_from_pandas`, but it should not be applied in the Python UDTF case. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added/modified the related tests. Closes #42310 from ueshin/issues/SPARK-44561/udtf_complex_types. Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> (cherry picked from commit f1a161cb39504bd625ea7fa50d2cc72a1a2a59e9) Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- python/pyspark/sql/pandas/serializers.py | 5 +- python/pyspark/sql/pandas/types.py | 108 ++++++--- .../pyspark/sql/tests/connect/test_parity_udtf.py | 3 + python/pyspark/sql/tests/test_udtf.py | 247 +++++++++++++++++++-- 4 files changed, 314 insertions(+), 49 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index f3037c8b39c..d1a3babb1fd 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -571,7 +571,10 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer): dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) # TODO(SPARK-43579): cache the converter for reuse conv = _create_converter_from_pandas( - dt, timezone=self._timezone, 
error_on_duplicated_field_names=False + dt, + timezone=self._timezone, + error_on_duplicated_field_names=False, + ignore_unexpected_complex_type_values=True, + ) series = conv(series) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 53362047604..b02a003e632 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -21,7 +21,7 @@ pandas instances during the type conversion. """ import datetime import itertools -from typing import Any, Callable, List, Optional, Union, TYPE_CHECKING +from typing import Any, Callable, Iterable, List, Optional, Union, TYPE_CHECKING from pyspark.sql.types import ( cast, @@ -750,6 +750,7 @@ def _create_converter_from_pandas( *, timezone: Optional[str], error_on_duplicated_field_names: bool = True, + ignore_unexpected_complex_type_values: bool = False, ) -> Callable[["pd.Series"], "pd.Series"]: """ Create a converter of pandas Series to create Spark DataFrame with Arrow optimization. @@ -763,6 +764,17 @@ def _create_converter_from_pandas( error_on_duplicated_field_names : bool, optional Whether raise an exception when there are duplicated field names. (default ``True``) + ignore_unexpected_complex_type_values : bool, optional + Whether to ignore the case where unexpected values are given for complex types. + If ``False``, each complex type expects: + + * array type: :class:`Iterable` + * map type: :class:`dict` + * struct type: :class:`dict` or :class:`tuple` + + and an AssertionError is raised when the given value is not of the expected type. + If ``True``, the check is skipped and the given value is returned as-is. 
+ (default ``False``) Returns ------- @@ -781,15 +793,26 @@ def _create_converter_from_pandas( def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]: if isinstance(dt, ArrayType): - _element_conv = _converter(dt.elementType) - if _element_conv is None: - return None + _element_conv = _converter(dt.elementType) or (lambda x: x) - def convert_array(value: Any) -> Any: - if value is None: - return None - else: - return [_element_conv(v) for v in value] # type: ignore[misc] + if ignore_unexpected_complex_type_values: + + def convert_array(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, Iterable): + return [_element_conv(v) for v in value] + else: + return value + + else: + + def convert_array(value: Any) -> Any: + if value is None: + return None + else: + assert isinstance(value, Iterable) + return [_element_conv(v) for v in value] return convert_array @@ -797,12 +820,24 @@ def _create_converter_from_pandas( _key_conv = _converter(dt.keyType) or (lambda x: x) _value_conv = _converter(dt.valueType) or (lambda x: x) - def convert_map(value: Any) -> Any: - if value is None: - return None - else: - assert isinstance(value, dict) - return [(_key_conv(k), _value_conv(v)) for k, v in value.items()] + if ignore_unexpected_complex_type_values: + + def convert_map(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + return [(_key_conv(k), _value_conv(v)) for k, v in value.items()] + else: + return value + + else: + + def convert_map(value: Any) -> Any: + if value is None: + return None + else: + assert isinstance(value, dict) + return [(_key_conv(k), _value_conv(v)) for k, v in value.items()] return convert_map @@ -820,17 +855,38 @@ def _create_converter_from_pandas( field_convs = [_converter(f.dataType) or (lambda x: x) for f in dt.fields] - def convert_struct(value: Any) -> Any: - if value is None: - return None - elif isinstance(value, dict): - return { - dedup_field_names[i]: 
field_convs[i](value.get(key, None)) - for i, key in enumerate(field_names) - } - else: - assert isinstance(value, tuple) - return {dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)} + if ignore_unexpected_complex_type_values: + + def convert_struct(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + return { + dedup_field_names[i]: field_convs[i](value.get(key, None)) + for i, key in enumerate(field_names) + } + elif isinstance(value, tuple): + return { + dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value) + } + else: + return value + + else: + + def convert_struct(value: Any) -> Any: + if value is None: + return None + elif isinstance(value, dict): + return { + dedup_field_names[i]: field_convs[i](value.get(key, None)) + for i, key in enumerate(field_names) + } + else: + assert isinstance(value, tuple) + return { + dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value) + } return convert_struct diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index 355f5288d2c..1222b1bb5b4 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -43,6 +43,9 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase): # TODO: use PySpark error classes instead of SparkConnectGrpcException + def test_struct_output_type_casting_row(self): + self.check_struct_output_type_casting_row(SparkConnectGrpcException) + def test_udtf_with_invalid_return_type(self): @udtf(returnType="int") class TestUDTF: diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 9384a6bc011..0540ecddde7 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -17,9 +17,10 @@ import os import tempfile import unittest - from typing import Iterator +from py4j.protocol import Py4JJavaError + from 
pyspark.errors import ( PySparkAttributeError, PythonException, @@ -558,12 +559,14 @@ class BaseUDTFTestsMixin: assertDataFrameEqual(TestUDTF(), [Row()]) - def _check_result_or_exception(self, func_handler, ret_type, expected): + def _check_result_or_exception( + self, func_handler, ret_type, expected, *, err_type=PythonException + ): func = udtf(func_handler, returnType=ret_type) if not isinstance(expected, str): assertDataFrameEqual(func(), expected) else: - with self.assertRaisesRegex(PythonException, expected): + with self.assertRaisesRegex(err_type, expected): func().collect() def test_numeric_output_type_casting(self): @@ -655,20 +658,129 @@ class BaseUDTFTestsMixin: def test_array_output_type_casting(self): class TestUDTF: def eval(self): - yield [1, 2], + yield [0, 1.1, 2], for ret_type, expected in [ + ("x: boolean", [Row(x=None)]), + ("x: tinyint", [Row(x=None)]), + ("x: smallint", [Row(x=None)]), ("x: int", [Row(x=None)]), - ("x: array<int>", [Row(x=[1, 2])]), - ("x: array<double>", [Row(x=[None, None])]), - ("x: array<string>", [Row(x=["1", "2"])]), - ("x: array<boolean>", [Row(x=[None, None])]), - ("x: array<array<int>>", [Row(x=[None, None])]), + ("x: bigint", [Row(x=None)]), + ("x: string", [Row(x="[0, 1.1, 2]")]), + ("x: date", "AttributeError"), + ("x: timestamp", "AttributeError"), + ("x: byte", [Row(x=None)]), + ("x: binary", [Row(x=None)]), + ("x: float", [Row(x=None)]), + ("x: double", [Row(x=None)]), + ("x: decimal(10, 0)", [Row(x=None)]), + ("x: array<int>", [Row(x=[0, None, 2])]), + ("x: array<double>", [Row(x=[None, 1.1, None])]), + ("x: array<string>", [Row(x=["0", "1.1", "2"])]), + ("x: array<boolean>", [Row(x=[None, None, None])]), + ("x: array<array<int>>", [Row(x=[None, None, None])]), ("x: map<string,int>", [Row(x=None)]), + ("x: struct<a:int,b:int,c:int>", [Row(x=Row(a=0, b=None, c=2))]), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def 
test_map_output_type_casting(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + for ret_type, expected in [ + ("x: boolean", [Row(x=None)]), + ("x: tinyint", [Row(x=None)]), + ("x: smallint", [Row(x=None)]), + ("x: int", [Row(x=None)]), + ("x: bigint", [Row(x=None)]), + ("x: string", [Row(x="{a=0, b=1.1, c=2}")]), + ("x: date", "AttributeError"), + ("x: timestamp", "AttributeError"), + ("x: byte", [Row(x=None)]), + ("x: binary", [Row(x=None)]), + ("x: float", [Row(x=None)]), + ("x: double", [Row(x=None)]), + ("x: decimal(10, 0)", [Row(x=None)]), + ("x: array<string>", [Row(x=None)]), + ("x: map<string,string>", [Row(x={"a": "0", "b": "1.1", "c": "2"})]), + ("x: map<string,boolean>", [Row(x={"a": None, "b": None, "c": None})]), + ("x: map<string,int>", [Row(x={"a": 0, "b": None, "c": 2})]), + ("x: map<string,float>", [Row(x={"a": None, "b": 1.1, "c": None})]), + ("x: map<string,map<string,int>>", [Row(x={"a": None, "b": None, "c": None})]), + ("x: struct<a:int>", [Row(x=Row(a=0))]), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_struct_output_type_casting_dict(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + for ret_type, expected in [ + ("x: boolean", [Row(x=None)]), + ("x: tinyint", [Row(x=None)]), + ("x: smallint", [Row(x=None)]), + ("x: int", [Row(x=None)]), + ("x: bigint", [Row(x=None)]), + ("x: string", [Row(x="{a=0, b=1.1, c=2}")]), + ("x: date", "AttributeError"), + ("x: timestamp", "AttributeError"), + ("x: byte", [Row(x=None)]), + ("x: binary", [Row(x=None)]), + ("x: float", [Row(x=None)]), + ("x: double", [Row(x=None)]), + ("x: decimal(10, 0)", [Row(x=None)]), + ("x: array<string>", [Row(x=None)]), + ("x: map<string,string>", [Row(x={"a": "0", "b": "1.1", "c": "2"})]), + ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=None, c=2))]), + 
("x: struct<a:float,b:float,c:float>", [Row(Row(a=None, b=1.1, c=None))]), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) + def test_struct_output_type_casting_row(self): + self.check_struct_output_type_casting_row(Py4JJavaError) + + def check_struct_output_type_casting_row(self, error_type): + class TestUDTF: + def eval(self): + yield Row(a=0, b=1.1, c=2), + + err = ("PickleException", error_type) + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", "ValueError"), + ("x: timestamp", "ValueError"), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array<string>", err), + ("x: map<string,string>", err), + ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=None, c=2))]), + ("x: struct<a:float,b:float,c:float>", [Row(Row(a=None, b=1.1, c=None))]), + ]: + with self.subTest(ret_type=ret_type): + if isinstance(expected, tuple): + self._check_result_or_exception( + TestUDTF, ret_type, expected[0], err_type=expected[1] + ) + else: + self._check_result_or_exception(TestUDTF, ret_type, expected) + def test_inconsistent_output_types(self): class TestUDTF: def eval(self): @@ -1084,9 +1196,8 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin): ("x: double", [Row(x=1.0)]), ("x: decimal(10, 0)", err), ("x: array<int>", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map<string,int>", None), - # ("x: struct<a:int>", None) + ("x: map<string,int>", err), + ("x: struct<a:int>", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) @@ -1113,10 +1224,9 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin): ("x: double", [Row(x=1.0)]), ("x: decimal(10, 0)", [Row(x=1)]), 
("x: array<string>", [Row(x=["1"])]), - ("x: array<int>", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map<string,int>", None), - # ("x: struct<a:int>", None) + ("x: array<int>", [Row(x=[1])]), + ("x: map<string,int>", err), + ("x: struct<a:int>", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) @@ -1144,9 +1254,8 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin): ("x: decimal(10, 0)", err), ("x: array<string>", [Row(x=["h", "e", "l", "l", "o"])]), ("x: array<int>", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map<string,int>", None), - # ("x: struct<a:int>", None) + ("x: map<string,int>", err), + ("x: struct<a:int>", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) @@ -1177,9 +1286,103 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin): ("x: array<int>", [Row(x=[0, 1, 2])]), ("x: array<float>", [Row(x=[0, 1.1, 2])]), ("x: array<array<int>>", err), - # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct - # ("x: map<string,int>", None), - # ("x: struct<a:int>", None) + ("x: map<string,int>", err), + ("x: struct<a:int>", err), + ("x: struct<a:int,b:int,c:int>", err), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_map_output_type_casting(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + err = "UDTF_ARROW_TYPE_CAST_ERROR" + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", err), + ("x: timestamp", err), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array<string>", [Row(x=["a", "b", "c"])]), + ("x: map<string,string>", err), + ("x: 
map<string,boolean>", err), + ("x: map<string,int>", [Row(x={"a": 0, "b": 1, "c": 2})]), + ("x: map<string,float>", [Row(x={"a": 0, "b": 1.1, "c": 2})]), + ("x: map<string,map<string,int>>", err), + ("x: struct<a:int>", [Row(x=Row(a=0))]), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_struct_output_type_casting_dict(self): + class TestUDTF: + def eval(self): + yield {"a": 0, "b": 1.1, "c": 2}, + + err = "UDTF_ARROW_TYPE_CAST_ERROR" + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", err), + ("x: timestamp", err), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array<string>", [Row(x=["a", "b", "c"])]), + ("x: map<string,string>", err), + ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=1, c=2))]), + ("x: struct<a:float,b:float,c:float>", [Row(Row(a=0, b=1.1, c=2))]), + ("x: struct<a:struct<>,b:struct<>,c:struct<>>", err), + ]: + with self.subTest(ret_type=ret_type): + self._check_result_or_exception(TestUDTF, ret_type, expected) + + def test_struct_output_type_casting_row(self): + class TestUDTF: + def eval(self): + yield Row(a=0, b=1.1, c=2), + + err = "UDTF_ARROW_TYPE_CAST_ERROR" + + for ret_type, expected in [ + ("x: boolean", err), + ("x: tinyint", err), + ("x: smallint", err), + ("x: int", err), + ("x: bigint", err), + ("x: string", err), + ("x: date", err), + ("x: timestamp", err), + ("x: byte", err), + ("x: binary", err), + ("x: float", err), + ("x: double", err), + ("x: decimal(10, 0)", err), + ("x: array<string>", [Row(x=["0", "1.1", "2"])]), + ("x: map<string,string>", err), + ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", c="2"))]), + ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, 
b=1, c=2))]), + ("x: struct<a:float,b:float,c:float>", [Row(Row(a=0, b=1.1, c=2))]), + ("x: struct<a:struct<>,b:struct<>,c:struct<>>", err), ]: with self.subTest(ret_type=ret_type): self._check_result_or_exception(TestUDTF, ret_type, expected) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org