This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new ea6b41cb398 [SPARK-44561][PYTHON] Fix AssertionError when converting UDTF output to a complex type
ea6b41cb398 is described below

commit ea6b41cb3989996e45102b1930b1498324761093
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Mon Aug 7 11:48:24 2023 -0700

    [SPARK-44561][PYTHON] Fix AssertionError when converting UDTF output to a complex type
    
    ### What changes were proposed in this pull request?
    
    Fixes the AssertionError raised when converting UDTF output to a complex type by skipping the assertions in `_create_converter_from_pandas`, so that Arrow raises the error instead.
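    
    For illustration only, a minimal sketch of the kind of case this affects (assuming the `pyspark.sql.functions.udtf` decorator, an active SparkSession, and Arrow-optimized Python UDTFs enabled; the names are hypothetical):
    
        from pyspark.sql.functions import udtf
    
        @udtf(returnType="x: map<string,int>")
        class TestUDTF:
            def eval(self):
                # A list is yielded where a dict is expected for the map type.
                # This used to trip the assertion in _create_converter_from_pandas;
                # with this change the conversion fails on the Arrow side with
                # UDTF_ARROW_TYPE_CAST_ERROR instead.
                yield [0, 1.1, 2],
    
        TestUDTF().collect()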
    
    ### Why are the changes needed?
    
    There is an assertion in `_create_converter_from_pandas`, but it should not be applied in the Python UDTF case.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added/modified the related tests.
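    
    For reference, a hedged sketch mirroring one of the added non-Arrow (BaseUDTFTestsMixin) cases; `assertDataFrameEqual` from `pyspark.testing` and an active SparkSession are assumed:
    
        from pyspark.sql import Row
        from pyspark.sql.functions import udtf
        from pyspark.testing import assertDataFrameEqual
    
        @udtf(returnType="x: array<string>")
        class TestUDTF:
            def eval(self):
                # A dict is unexpected for an array column; rather than hitting
                # an AssertionError, the value passes through and the column
                # comes back as NULL on this code path.
                yield {"a": 0, "b": 1.1, "c": 2},
    
        assertDataFrameEqual(TestUDTF(), [Row(x=None)])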
    
    Closes #42310 from ueshin/issues/SPARK-44561/udtf_complex_types.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
    (cherry picked from commit f1a161cb39504bd625ea7fa50d2cc72a1a2a59e9)
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/sql/pandas/serializers.py           |   5 +-
 python/pyspark/sql/pandas/types.py                 | 108 ++++++---
 .../pyspark/sql/tests/connect/test_parity_udtf.py  |   3 +
 python/pyspark/sql/tests/test_udtf.py              | 247 +++++++++++++++++++--
 4 files changed, 314 insertions(+), 49 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index f3037c8b39c..d1a3babb1fd 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -571,7 +571,10 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
             dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
             # TODO(SPARK-43579): cache the converter for reuse
             conv = _create_converter_from_pandas(
-                dt, timezone=self._timezone, error_on_duplicated_field_names=False
+                dt,
+                timezone=self._timezone,
+                error_on_duplicated_field_names=False,
+                ignore_unexpected_complex_type_values=True,
             )
             series = conv(series)
 
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 53362047604..b02a003e632 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -21,7 +21,7 @@ pandas instances during the type conversion.
 """
 import datetime
 import itertools
-from typing import Any, Callable, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Callable, Iterable, List, Optional, Union, TYPE_CHECKING
 
 from pyspark.sql.types import (
     cast,
@@ -750,6 +750,7 @@ def _create_converter_from_pandas(
     *,
     timezone: Optional[str],
     error_on_duplicated_field_names: bool = True,
+    ignore_unexpected_complex_type_values: bool = False,
 ) -> Callable[["pd.Series"], "pd.Series"]:
     """
     Create a converter of pandas Series to create Spark DataFrame with Arrow optimization.
@@ -763,6 +764,17 @@ def _create_converter_from_pandas(
     error_on_duplicated_field_names : bool, optional
         Whether raise an exception when there are duplicated field names.
         (default ``True``)
+    ignore_unexpected_complex_type_values : bool, optional
+        Whether to ignore the case where unexpected values are given for complex types.
+        If ``False``, each complex type expects:
+
+        * array type: :class:`Iterable`
+        * map type: :class:`dict`
+        * struct type: :class:`dict` or :class:`tuple`
+
+        and an AssertionError is raised when the given value is not of the expected type.
+        If ``True``, the check is skipped and the given value is returned as is.
+        (default ``False``)
 
     Returns
     -------
@@ -781,15 +793,26 @@ def _create_converter_from_pandas(
     def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]:
 
         if isinstance(dt, ArrayType):
-            _element_conv = _converter(dt.elementType)
-            if _element_conv is None:
-                return None
+            _element_conv = _converter(dt.elementType) or (lambda x: x)
 
-            def convert_array(value: Any) -> Any:
-                if value is None:
-                    return None
-                else:
-                    return [_element_conv(v) for v in value]  # type: ignore[misc]
+            if ignore_unexpected_complex_type_values:
+
+                def convert_array(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    elif isinstance(value, Iterable):
+                        return [_element_conv(v) for v in value]
+                    else:
+                        return value
+
+            else:
+
+                def convert_array(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    else:
+                        assert isinstance(value, Iterable)
+                        return [_element_conv(v) for v in value]
 
             return convert_array
 
@@ -797,12 +820,24 @@ def _create_converter_from_pandas(
             _key_conv = _converter(dt.keyType) or (lambda x: x)
             _value_conv = _converter(dt.valueType) or (lambda x: x)
 
-            def convert_map(value: Any) -> Any:
-                if value is None:
-                    return None
-                else:
-                    assert isinstance(value, dict)
-                    return [(_key_conv(k), _value_conv(v)) for k, v in value.items()]
+            if ignore_unexpected_complex_type_values:
+
+                def convert_map(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    elif isinstance(value, dict):
+                        return [(_key_conv(k), _value_conv(v)) for k, v in value.items()]
+                    else:
+                        return value
+
+            else:
+
+                def convert_map(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    else:
+                        assert isinstance(value, dict)
+                        return [(_key_conv(k), _value_conv(v)) for k, v in value.items()]
 
             return convert_map
 
@@ -820,17 +855,38 @@ def _create_converter_from_pandas(
 
             field_convs = [_converter(f.dataType) or (lambda x: x) for f in dt.fields]
 
-            def convert_struct(value: Any) -> Any:
-                if value is None:
-                    return None
-                elif isinstance(value, dict):
-                    return {
-                        dedup_field_names[i]: field_convs[i](value.get(key, None))
-                        for i, key in enumerate(field_names)
-                    }
-                else:
-                    assert isinstance(value, tuple)
-                    return {dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)}
+            if ignore_unexpected_complex_type_values:
+
+                def convert_struct(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    elif isinstance(value, dict):
+                        return {
+                            dedup_field_names[i]: field_convs[i](value.get(key, None))
+                            for i, key in enumerate(field_names)
+                        }
+                    elif isinstance(value, tuple):
+                        return {
+                            dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)
+                        }
+                    else:
+                        return value
+
+            else:
+
+                def convert_struct(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    elif isinstance(value, dict):
+                        return {
+                            dedup_field_names[i]: field_convs[i](value.get(key, None))
+                            for i, key in enumerate(field_names)
+                        }
+                    else:
+                        assert isinstance(value, tuple)
+                        return {
+                            dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)
+                        }
 
             return convert_struct
 
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index 355f5288d2c..1222b1bb5b4 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -43,6 +43,9 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
 
     # TODO: use PySpark error classes instead of SparkConnectGrpcException
 
+    def test_struct_output_type_casting_row(self):
+        self.check_struct_output_type_casting_row(SparkConnectGrpcException)
+
     def test_udtf_with_invalid_return_type(self):
         @udtf(returnType="int")
         class TestUDTF:
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 9384a6bc011..0540ecddde7 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -17,9 +17,10 @@
 import os
 import tempfile
 import unittest
-
 from typing import Iterator
 
+from py4j.protocol import Py4JJavaError
+
 from pyspark.errors import (
     PySparkAttributeError,
     PythonException,
@@ -558,12 +559,14 @@ class BaseUDTFTestsMixin:
 
         assertDataFrameEqual(TestUDTF(), [Row()])
 
-    def _check_result_or_exception(self, func_handler, ret_type, expected):
+    def _check_result_or_exception(
+        self, func_handler, ret_type, expected, *, err_type=PythonException
+    ):
         func = udtf(func_handler, returnType=ret_type)
         if not isinstance(expected, str):
             assertDataFrameEqual(func(), expected)
         else:
-            with self.assertRaisesRegex(PythonException, expected):
+            with self.assertRaisesRegex(err_type, expected):
                 func().collect()
 
     def test_numeric_output_type_casting(self):
@@ -655,20 +658,129 @@ class BaseUDTFTestsMixin:
     def test_array_output_type_casting(self):
         class TestUDTF:
             def eval(self):
-                yield [1, 2],
+                yield [0, 1.1, 2],
 
         for ret_type, expected in [
+            ("x: boolean", [Row(x=None)]),
+            ("x: tinyint", [Row(x=None)]),
+            ("x: smallint", [Row(x=None)]),
             ("x: int", [Row(x=None)]),
-            ("x: array<int>", [Row(x=[1, 2])]),
-            ("x: array<double>", [Row(x=[None, None])]),
-            ("x: array<string>", [Row(x=["1", "2"])]),
-            ("x: array<boolean>", [Row(x=[None, None])]),
-            ("x: array<array<int>>", [Row(x=[None, None])]),
+            ("x: bigint", [Row(x=None)]),
+            ("x: string", [Row(x="[0, 1.1, 2]")]),
+            ("x: date", "AttributeError"),
+            ("x: timestamp", "AttributeError"),
+            ("x: byte", [Row(x=None)]),
+            ("x: binary", [Row(x=None)]),
+            ("x: float", [Row(x=None)]),
+            ("x: double", [Row(x=None)]),
+            ("x: decimal(10, 0)", [Row(x=None)]),
+            ("x: array<int>", [Row(x=[0, None, 2])]),
+            ("x: array<double>", [Row(x=[None, 1.1, None])]),
+            ("x: array<string>", [Row(x=["0", "1.1", "2"])]),
+            ("x: array<boolean>", [Row(x=[None, None, None])]),
+            ("x: array<array<int>>", [Row(x=[None, None, None])]),
             ("x: map<string,int>", [Row(x=None)]),
+            ("x: struct<a:int,b:int,c:int>", [Row(x=Row(a=0, b=None, c=2))]),
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_map_output_type_casting(self):
+        class TestUDTF:
+            def eval(self):
+                yield {"a": 0, "b": 1.1, "c": 2},
+
+        for ret_type, expected in [
+            ("x: boolean", [Row(x=None)]),
+            ("x: tinyint", [Row(x=None)]),
+            ("x: smallint", [Row(x=None)]),
+            ("x: int", [Row(x=None)]),
+            ("x: bigint", [Row(x=None)]),
+            ("x: string", [Row(x="{a=0, b=1.1, c=2}")]),
+            ("x: date", "AttributeError"),
+            ("x: timestamp", "AttributeError"),
+            ("x: byte", [Row(x=None)]),
+            ("x: binary", [Row(x=None)]),
+            ("x: float", [Row(x=None)]),
+            ("x: double", [Row(x=None)]),
+            ("x: decimal(10, 0)", [Row(x=None)]),
+            ("x: array<string>", [Row(x=None)]),
+            ("x: map<string,string>", [Row(x={"a": "0", "b": "1.1", "c": 
"2"})]),
+            ("x: map<string,boolean>", [Row(x={"a": None, "b": None, "c": 
None})]),
+            ("x: map<string,int>", [Row(x={"a": 0, "b": None, "c": 2})]),
+            ("x: map<string,float>", [Row(x={"a": None, "b": 1.1, "c": 
None})]),
+            ("x: map<string,map<string,int>>", [Row(x={"a": None, "b": None, 
"c": None})]),
+            ("x: struct<a:int>", [Row(x=Row(a=0))]),
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_struct_output_type_casting_dict(self):
+        class TestUDTF:
+            def eval(self):
+                yield {"a": 0, "b": 1.1, "c": 2},
+
+        for ret_type, expected in [
+            ("x: boolean", [Row(x=None)]),
+            ("x: tinyint", [Row(x=None)]),
+            ("x: smallint", [Row(x=None)]),
+            ("x: int", [Row(x=None)]),
+            ("x: bigint", [Row(x=None)]),
+            ("x: string", [Row(x="{a=0, b=1.1, c=2}")]),
+            ("x: date", "AttributeError"),
+            ("x: timestamp", "AttributeError"),
+            ("x: byte", [Row(x=None)]),
+            ("x: binary", [Row(x=None)]),
+            ("x: float", [Row(x=None)]),
+            ("x: double", [Row(x=None)]),
+            ("x: decimal(10, 0)", [Row(x=None)]),
+            ("x: array<string>", [Row(x=None)]),
+            ("x: map<string,string>", [Row(x={"a": "0", "b": "1.1", "c": 
"2"})]),
+            ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", 
c="2"))]),
+            ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=None, c=2))]),
+            ("x: struct<a:float,b:float,c:float>", [Row(Row(a=None, b=1.1, 
c=None))]),
         ]:
             with self.subTest(ret_type=ret_type):
                 self._check_result_or_exception(TestUDTF, ret_type, expected)
 
+    def test_struct_output_type_casting_row(self):
+        self.check_struct_output_type_casting_row(Py4JJavaError)
+
+    def check_struct_output_type_casting_row(self, error_type):
+        class TestUDTF:
+            def eval(self):
+                yield Row(a=0, b=1.1, c=2),
+
+        err = ("PickleException", error_type)
+
+        for ret_type, expected in [
+            ("x: boolean", err),
+            ("x: tinyint", err),
+            ("x: smallint", err),
+            ("x: int", err),
+            ("x: bigint", err),
+            ("x: string", err),
+            ("x: date", "ValueError"),
+            ("x: timestamp", "ValueError"),
+            ("x: byte", err),
+            ("x: binary", err),
+            ("x: float", err),
+            ("x: double", err),
+            ("x: decimal(10, 0)", err),
+            ("x: array<string>", err),
+            ("x: map<string,string>", err),
+            ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", 
c="2"))]),
+            ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=None, c=2))]),
+            ("x: struct<a:float,b:float,c:float>", [Row(Row(a=None, b=1.1, 
c=None))]),
+        ]:
+            with self.subTest(ret_type=ret_type):
+                if isinstance(expected, tuple):
+                    self._check_result_or_exception(
+                        TestUDTF, ret_type, expected[0], err_type=expected[1]
+                    )
+                else:
+                    self._check_result_or_exception(TestUDTF, ret_type, expected)
+
     def test_inconsistent_output_types(self):
         class TestUDTF:
             def eval(self):
@@ -1084,9 +1196,8 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
             ("x: double", [Row(x=1.0)]),
             ("x: decimal(10, 0)", err),
             ("x: array<int>", err),
-            # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct
-            # ("x: map<string,int>", None),
-            # ("x: struct<a:int>", None)
+            ("x: map<string,int>", err),
+            ("x: struct<a:int>", err),
         ]:
             with self.subTest(ret_type=ret_type):
                 self._check_result_or_exception(TestUDTF, ret_type, expected)
@@ -1113,10 +1224,9 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
             ("x: double", [Row(x=1.0)]),
             ("x: decimal(10, 0)", [Row(x=1)]),
             ("x: array<string>", [Row(x=["1"])]),
-            ("x: array<int>", err),
-            # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct
-            # ("x: map<string,int>", None),
-            # ("x: struct<a:int>", None)
+            ("x: array<int>", [Row(x=[1])]),
+            ("x: map<string,int>", err),
+            ("x: struct<a:int>", err),
         ]:
             with self.subTest(ret_type=ret_type):
                 self._check_result_or_exception(TestUDTF, ret_type, expected)
@@ -1144,9 +1254,8 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
             ("x: decimal(10, 0)", err),
             ("x: array<string>", [Row(x=["h", "e", "l", "l", "o"])]),
             ("x: array<int>", err),
-            # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct
-            # ("x: map<string,int>", None),
-            # ("x: struct<a:int>", None)
+            ("x: map<string,int>", err),
+            ("x: struct<a:int>", err),
         ]:
             with self.subTest(ret_type=ret_type):
                 self._check_result_or_exception(TestUDTF, ret_type, expected)
@@ -1177,9 +1286,103 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
             ("x: array<int>", [Row(x=[0, 1, 2])]),
             ("x: array<float>", [Row(x=[0, 1.1, 2])]),
             ("x: array<array<int>>", err),
-            # TODO(SPARK-44561): fix AssertionError in convert_map and convert_struct
-            # ("x: map<string,int>", None),
-            # ("x: struct<a:int>", None)
+            ("x: map<string,int>", err),
+            ("x: struct<a:int>", err),
+            ("x: struct<a:int,b:int,c:int>", err),
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_map_output_type_casting(self):
+        class TestUDTF:
+            def eval(self):
+                yield {"a": 0, "b": 1.1, "c": 2},
+
+        err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+        for ret_type, expected in [
+            ("x: boolean", err),
+            ("x: tinyint", err),
+            ("x: smallint", err),
+            ("x: int", err),
+            ("x: bigint", err),
+            ("x: string", err),
+            ("x: date", err),
+            ("x: timestamp", err),
+            ("x: byte", err),
+            ("x: binary", err),
+            ("x: float", err),
+            ("x: double", err),
+            ("x: decimal(10, 0)", err),
+            ("x: array<string>", [Row(x=["a", "b", "c"])]),
+            ("x: map<string,string>", err),
+            ("x: map<string,boolean>", err),
+            ("x: map<string,int>", [Row(x={"a": 0, "b": 1, "c": 2})]),
+            ("x: map<string,float>", [Row(x={"a": 0, "b": 1.1, "c": 2})]),
+            ("x: map<string,map<string,int>>", err),
+            ("x: struct<a:int>", [Row(x=Row(a=0))]),
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_struct_output_type_casting_dict(self):
+        class TestUDTF:
+            def eval(self):
+                yield {"a": 0, "b": 1.1, "c": 2},
+
+        err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+        for ret_type, expected in [
+            ("x: boolean", err),
+            ("x: tinyint", err),
+            ("x: smallint", err),
+            ("x: int", err),
+            ("x: bigint", err),
+            ("x: string", err),
+            ("x: date", err),
+            ("x: timestamp", err),
+            ("x: byte", err),
+            ("x: binary", err),
+            ("x: float", err),
+            ("x: double", err),
+            ("x: decimal(10, 0)", err),
+            ("x: array<string>", [Row(x=["a", "b", "c"])]),
+            ("x: map<string,string>", err),
+            ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", 
c="2"))]),
+            ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=1, c=2))]),
+            ("x: struct<a:float,b:float,c:float>", [Row(Row(a=0, b=1.1, 
c=2))]),
+            ("x: struct<a:struct<>,b:struct<>,c:struct<>>", err),
+        ]:
+            with self.subTest(ret_type=ret_type):
+                self._check_result_or_exception(TestUDTF, ret_type, expected)
+
+    def test_struct_output_type_casting_row(self):
+        class TestUDTF:
+            def eval(self):
+                yield Row(a=0, b=1.1, c=2),
+
+        err = "UDTF_ARROW_TYPE_CAST_ERROR"
+
+        for ret_type, expected in [
+            ("x: boolean", err),
+            ("x: tinyint", err),
+            ("x: smallint", err),
+            ("x: int", err),
+            ("x: bigint", err),
+            ("x: string", err),
+            ("x: date", err),
+            ("x: timestamp", err),
+            ("x: byte", err),
+            ("x: binary", err),
+            ("x: float", err),
+            ("x: double", err),
+            ("x: decimal(10, 0)", err),
+            ("x: array<string>", [Row(x=["0", "1.1", "2"])]),
+            ("x: map<string,string>", err),
+            ("x: struct<a:string,b:string,c:string>", [Row(Row(a="0", b="1.1", 
c="2"))]),
+            ("x: struct<a:int,b:int,c:int>", [Row(Row(a=0, b=1, c=2))]),
+            ("x: struct<a:float,b:float,c:float>", [Row(Row(a=0, b=1.1, 
c=2))]),
+            ("x: struct<a:struct<>,b:struct<>,c:struct<>>", err),
         ]:
             with self.subTest(ret_type=ret_type):
                 self._check_result_or_exception(TestUDTF, ret_type, expected)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
