This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 32c6cead0eb [SPARK-41879][CONNECT][PYTHON] Make `DataFrame.collect` support nested types
32c6cead0eb is described below

commit 32c6cead0eb460717fd988fb22a4d9c6c35993a3
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jan 11 11:35:43 2023 +0800

    [SPARK-41879][CONNECT][PYTHON] Make `DataFrame.collect` support nested types
    
    ### What changes were proposed in this pull request?
    Make `DataFrame.collect` support nested types by introducing a new data converter.
    
    Note that duplicated field names are not supported in this PR, since we cannot even read the batches on the client side.
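
    As a rough illustration of the new behavior (a minimal sketch, assuming a Spark Connect
    session available as `spark`; the variable name is hypothetical and the snippet is not
    part of this patch), nested values now come back as plain Python containers and `Row`s,
    matching PySpark:

        df = spark.sql("SELECT ARRAY(1, NULL, 3) AS g, MAP(1, 2) AS h, STRUCT(1 AS x) AS s")
        df.collect()
        # e.g. [Row(g=[1, None, 3], h={1: 2}, s=Row(x=1))]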
    
    ### Why are the changes needed?
    To be consistent with PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    Added unit tests and enabled doctests.
    
    Closes #39462 from zhengruifeng/connect_nested_converter.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py            |  82 +++++++-----
 python/pyspark/sql/connect/functions.py            |   7 -
 python/pyspark/sql/connect/types.py                | 114 ++++++++++++++++-
 .../sql/tests/connect/test_connect_basic.py        | 142 +++++++++++++++++++++
 4 files changed, 308 insertions(+), 37 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c9cd65b93b7..e9e7c086bd7 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -33,20 +33,34 @@ from typing import (
 import sys
 import random
 import pandas
-import datetime
 import json
 import warnings
 from collections.abc import Iterable
 
 from pyspark import _NoValue
 from pyspark._globals import _NoValueType
-from pyspark.sql.types import StructType, Row
+from pyspark.sql.types import (
+    _create_row,
+    Row,
+    StructType,
+    ArrayType,
+    MapType,
+    TimestampType,
+    TimestampNTZType,
+)
+from pyspark.sql.dataframe import (
+    DataFrame as PySparkDataFrame,
+    DataFrameNaFunctions as PySparkDataFrameNaFunctions,
+    DataFrameStatFunctions as PySparkDataFrameStatFunctions,
+)
+from pyspark.sql.pandas.types import from_arrow_schema
 
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.group import GroupedData
 from pyspark.sql.connect.readwriter import DataFrameWriter
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import UnresolvedRegex
+from pyspark.sql.connect.types import _create_converter
 from pyspark.sql.connect.functions import (
     _to_col,
     _invoke_function,
@@ -54,11 +68,6 @@ from pyspark.sql.connect.functions import (
     lit,
     expr as sql_expression,
 )
-from pyspark.sql.dataframe import (
-    DataFrame as PySparkDataFrame,
-    DataFrameNaFunctions as PySparkDataFrameNaFunctions,
-    DataFrameStatFunctions as PySparkDataFrameStatFunctions,
-)
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
@@ -1223,29 +1232,44 @@ class DataFrame:
         query = self._plan.to_proto(self._session.client)
         table = self._session.client.to_table(query)
 
+        # We first try the schema inferred from the PyArrow Table instead of always fetching
+        # the Connect DataFrame schema via 'self.schema', for two reasons:
+        # 1, the schema may be quite simple, in which case we can save an RPC;
+        # 2, if we always invoke 'self.schema' here, all catalog functions based on
+        # 'dataframe.collect' will be invoked twice (1, collect data; 2, fetch schema),
+        # and then some of them (e.g. "CREATE DATABASE") fail due to the second invocation.
+
+        schema: Optional[StructType] = None
+        try:
+            schema = from_arrow_schema(table.schema)
+        except Exception:
+            # may fail because 'from_arrow_schema' does not support nested structs
+            schema = None
+
+        if schema is None:
+            schema = self.schema
+        else:
+            if any(
+                isinstance(
+                    f.dataType, (StructType, ArrayType, MapType, TimestampType, TimestampNTZType)
+                )
+                for f in schema.fields
+            ):
+                schema = self.schema
+
+        assert schema is not None and isinstance(schema, StructType)
+
+        field_converters = [_create_converter(f.dataType) for f in schema.fields]
+
+        # table.to_pylist() automatically removes columns with duplicated names;
+        # to avoid this, use columnar lists here.
+        # TODO: support duplicated field names in one struct, e.g. SF.struct("a", "a")
+        columnar_data = [column.to_pylist() for column in table.columns]
+
         rows: List[Row] = []
-        columns = [column.to_pylist() for column in table.columns]
-        i = 0
-        while i < table.num_rows:
-            values: List[Any] = []
-            j = 0
-            while j < table.num_columns:
-                v = columns[j][i]
-                if isinstance(v, bytes):
-                    values.append(bytearray(v))
-                elif isinstance(v, datetime.datetime) and v.tzinfo is not None:
-                    # TODO: Should be controlled by "spark.sql.timestampType"
-                    # always remove the time zone for now
-                    values.append(v.replace(tzinfo=None))
-                elif isinstance(v, dict):
-                    values.append(Row(**v))
-                else:
-                    values.append(v)
-                j += 1
-            new_row = Row(*values)
-            new_row.__fields__ = table.column_names
-            rows.append(new_row)
-            i += 1
+        for i in range(0, table.num_rows):
+            values = [field_converters[j](columnar_data[j][i]) for j in range(0, table.num_columns)]
+            rows.append(_create_row(fields=table.column_names, values=values))
         return rows
 
     collect.__doc__ = PySparkDataFrame.collect.__doc__
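
For context on the columnar-lists workaround above, here is a minimal sketch (illustrative only,
not part of this patch) of how pyarrow's Table.to_pylist() silently collapses duplicated column
names, which is why the new code iterates table.columns instead:

    import pyarrow as pa

    # A table with duplicated column names, e.g. what a plan like struct("a", "a") could produce.
    table = pa.Table.from_arrays([pa.array([1, 2]), pa.array([3, 4])], names=["a", "a"])

    # to_pylist() builds one dict per row, so the duplicated key collides and only one 'a' survives.
    table.to_pylist()                             # e.g. [{'a': 3}, {'a': 4}]

    # Iterating the columns keeps every column, duplicated names included.
    [col.to_pylist() for col in table.columns]    # [[1, 2], [3, 4]]
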
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 993fd30e7f0..8f9676dfe47 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -2357,13 +2357,6 @@ def _test() -> None:
         # Spark Connect does not support Spark Context but the test depends on that.
         del pyspark.sql.connect.functions.monotonically_increasing_id.__doc__
 
-        # TODO(SPARK-41880): Function `from_json` should support non-literal expression
-        # TODO(SPARK-41879): `DataFrame.collect` should support nested types
-        del pyspark.sql.connect.functions.struct.__doc__
-        del pyspark.sql.connect.functions.create_map.__doc__
-        del pyspark.sql.connect.functions.from_csv.__doc__
-        del pyspark.sql.connect.functions.from_json.__doc__
-
         # TODO(SPARK-41834): implement Dataframe.conf
         del pyspark.sql.connect.functions.from_unixtime.__doc__
         del pyspark.sql.connect.functions.timestamp_seconds.__doc__
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 2f4abcec9b3..6f5d5971ef7 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -15,11 +15,13 @@
 # limitations under the License.
 #
 
+import datetime
 import json
 
-from typing import Optional
+from typing import Any, Optional, Callable
 
 from pyspark.sql.types import (
+    Row,
     DataType,
     ByteType,
     ShortType,
@@ -184,3 +186,113 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
         )
     else:
         raise Exception(f"Unsupported data type {schema}")
+
+
+def _need_converter(dataType: DataType) -> bool:
+    if isinstance(dataType, NullType):
+        return True
+    elif isinstance(dataType, StructType):
+        return True
+    elif isinstance(dataType, ArrayType):
+        return _need_converter(dataType.elementType)
+    elif isinstance(dataType, MapType):
+        # Different from PySpark, conversion is always needed here,
+        # since the input from Arrow is a list of tuples.
+        return True
+    elif isinstance(dataType, BinaryType):
+        return True
+    elif isinstance(dataType, (TimestampType, TimestampNTZType)):
+        # Always remove the time zone info for now
+        return True
+    else:
+        return False
+
+
+def _create_converter(dataType: DataType) -> Callable:
+    assert dataType is not None and isinstance(dataType, DataType)
+
+    if not _need_converter(dataType):
+        return lambda value: value
+
+    if isinstance(dataType, NullType):
+        return lambda value: None
+
+    elif isinstance(dataType, StructType):
+
+        field_convs = {f.name: _create_converter(f.dataType) for f in dataType.fields}
+        need_conv = any(_need_converter(f.dataType) for f in dataType.fields)
+
+        def convert_struct(value: Any) -> Row:
+            if value is None:
+                return Row()
+            else:
+                assert isinstance(value, dict)
+
+                if need_conv:
+                    _dict = {}
+                    for k, v in value.items():
+                        assert isinstance(k, str)
+                        _dict[k] = field_convs[k](v)
+                    return Row(**_dict)
+                else:
+                    return Row(**value)
+
+        return convert_struct
+
+    elif isinstance(dataType, ArrayType):
+
+        element_conv = _create_converter(dataType.elementType)
+
+        def convert_array(value: Any) -> Any:
+            if value is None:
+                return None
+            else:
+                assert isinstance(value, list)
+                return [element_conv(v) for v in value]
+
+        return convert_array
+
+    elif isinstance(dataType, MapType):
+
+        key_conv = _create_converter(dataType.keyType)
+        value_conv = _create_converter(dataType.valueType)
+
+        def convert_map(value: Any) -> Any:
+            if value is None:
+                return None
+            else:
+                assert isinstance(value, list)
+                assert all(isinstance(t, tuple) and len(t) == 2 for t in value)
+                return dict((key_conv(t[0]), value_conv(t[1])) for t in value)
+
+        return convert_map
+
+    elif isinstance(dataType, BinaryType):
+
+        def convert_binary(value: Any) -> Any:
+            if value is None:
+                return None
+            else:
+                assert isinstance(value, bytes)
+                return bytearray(value)
+
+        return convert_binary
+
+    elif isinstance(dataType, (TimestampType, TimestampNTZType)):
+
+        def convert_timestamp(value: Any) -> Any:
+            if value is None:
+                return None
+            else:
+                assert isinstance(value, datetime.datetime)
+                if value.tzinfo is not None:
+                    # always remove the time zone for now
+                    return value.replace(tzinfo=None)
+                else:
+                    return value
+
+        return convert_timestamp
+
+    else:
+
+        return lambda value: value
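
A quick usage sketch of the new converter (illustrative only, not part of this patch): for a
MapType, Arrow hands the client a list of (key, value) tuples, and the converter rebuilds the
dict that PySpark's collect() returns:

    from pyspark.sql.types import MapType, IntegerType, StringType
    from pyspark.sql.connect.types import _create_converter

    conv = _create_converter(MapType(IntegerType(), StringType()))
    conv([(1, "a"), (2, "b")])   # {1: 'a', 2: 'b'}
    conv(None)                   # None
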
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 2b03ce70638..45a9e7b3e17 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2095,6 +2095,148 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         ):
             cdf.withMetadata(columnName="name", metadata=["magic"])
 
+    def test_collect_nested_type(self):
+        query = """
+            SELECT * FROM VALUES
+            (1, 4, 0, 8, true, true, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
+            (2, 5, -1, NULL, false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
+            (3, 6, NULL, 0, false, NULL, ARRAY(NULL), NULL)
+            AS tab(a, b, c, d, e, f, g, h)
+            """
+
+        # +---+---+----+----+-----+----+------------+-------------------+
+        # |  a|  b|   c|   d|    e|   f|           g|                  h|
+        # +---+---+----+----+-----+----+------------+-------------------+
+        # |  1|  4|   0|   8| true|true|[1, null, 3]|   {1 -> 2, 3 -> 4}|
+        # |  2|  5|  -1|null|false|null|      [1, 3]|{1 -> null, 3 -> 4}|
+        # |  3|  6|null|   0|false|null|      [null]|               null|
+        # +---+---+----+----+-----+----+------------+-------------------+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # test collect array
+        # +--------------+-------------+------------+
+        # |array(a, b, c)|  array(e, f)|           g|
+        # +--------------+-------------+------------+
+        # |     [1, 4, 0]| [true, true]|[1, null, 3]|
+        # |    [2, 5, -1]|[false, null]|      [1, 3]|
+        # |  [3, 6, null]|[false, null]|      [null]|
+        # +--------------+-------------+------------+
+        self.assertEqual(
+            cdf.select(CF.array("a", "b", "c"), CF.array("e", "f"), CF.col("g")).collect(),
+            sdf.select(SF.array("a", "b", "c"), SF.array("e", "f"), SF.col("g")).collect(),
+        )
+
+        # test collect nested array
+        # +-----------------------------------+-------------------------+
+        # |array(array(a), array(b), array(c))|array(array(e), array(f))|
+        # +-----------------------------------+-------------------------+
+        # |                    [[1], [4], [0]]|         [[true], [true]]|
+        # |                   [[2], [5], [-1]]|        [[false], [null]]|
+        # |                 [[3], [6], [null]]|        [[false], [null]]|
+        # +-----------------------------------+-------------------------+
+        self.assertEqual(
+            cdf.select(
+                CF.array(CF.array("a"), CF.array("b"), CF.array("c")),
+                CF.array(CF.array("e"), CF.array("f")),
+            ).collect(),
+            sdf.select(
+                SF.array(SF.array("a"), SF.array("b"), SF.array("c")),
+                SF.array(SF.array("e"), SF.array("f")),
+            ).collect(),
+        )
+
+        # test collect array of struct, map
+        # +----------------+---------------------+
+        # |array(struct(a))|             array(h)|
+        # +----------------+---------------------+
+        # |           [{1}]|   [{1 -> 2, 3 -> 4}]|
+        # |           [{2}]|[{1 -> null, 3 -> 4}]|
+        # |           [{3}]|               [null]|
+        # +----------------+---------------------+
+        self.assertEqual(
+            cdf.select(CF.array(CF.struct("a")), CF.array("h")).collect(),
+            sdf.select(SF.array(SF.struct("a")), SF.array("h")).collect(),
+        )
+
+        # test collect map
+        # +-------------------+-------------------+
+        # |                  h|    map(a, b, b, c)|
+        # +-------------------+-------------------+
+        # |   {1 -> 2, 3 -> 4}|   {1 -> 4, 4 -> 0}|
+        # |{1 -> null, 3 -> 4}|  {2 -> 5, 5 -> -1}|
+        # |               null|{3 -> 6, 6 -> null}|
+        # +-------------------+-------------------+
+        self.assertEqual(
+            cdf.select(CF.col("h"), CF.create_map("a", "b", "b", "c")).collect(),
+            sdf.select(SF.col("h"), SF.create_map("a", "b", "b", "c")).collect(),
+        )
+
+        # test collect map of struct, array
+        # +-------------------+------------------------+
+        # |          map(a, g)|    map(a, struct(b, g))|
+        # +-------------------+------------------------+
+        # |{1 -> [1, null, 3]}|{1 -> {4, [1, null, 3]}}|
+        # |      {2 -> [1, 3]}|      {2 -> {5, [1, 3]}}|
+        # |      {3 -> [null]}|      {3 -> {6, [null]}}|
+        # +-------------------+------------------------+
+        self.assertEqual(
+            cdf.select(CF.create_map("a", "g"), CF.create_map("a", CF.struct("b", "g"))).collect(),
+            sdf.select(SF.create_map("a", "g"), SF.create_map("a", SF.struct("b", "g"))).collect(),
+        )
+
+        # test collect struct
+        # +------------------+--------------------------+
+        # |struct(a, b, c, d)|           struct(e, f, g)|
+        # +------------------+--------------------------+
+        # |      {1, 4, 0, 8}|{true, true, [1, null, 3]}|
+        # |  {2, 5, -1, null}|     {false, null, [1, 3]}|
+        # |   {3, 6, null, 0}|     {false, null, [null]}|
+        # +------------------+--------------------------+
+        self.assertEqual(
+            cdf.select(CF.struct("a", "b", "c", "d"), CF.struct("e", "f", "g")).collect(),
+            sdf.select(SF.struct("a", "b", "c", "d"), SF.struct("e", "f", "g")).collect(),
+        )
+
+        # test collect nested struct
+        # +------------------------------------------+--------------------------+----------------------------+ # noqa
+        # |struct(a, struct(a, struct(c, struct(d))))|struct(a, b, struct(c, d))|     struct(e, f, struct(g))| # noqa
+        # +------------------------------------------+--------------------------+----------------------------+ # noqa
+        # |                        {1, {1, {0, {8}}}}|            {1, 4, {0, 8}}|{true, true, {[1, null, 3]}}| # noqa
+        # |                    {2, {2, {-1, {null}}}}|        {2, 5, {-1, null}}|     {false, null, {[1, 3]}}| # noqa
+        # |                     {3, {3, {null, {0}}}}|         {3, 6, {null, 0}}|     {false, null, {[null]}}| # noqa
+        # +------------------------------------------+--------------------------+----------------------------+ # noqa
+        self.assertEqual(
+            cdf.select(
+                CF.struct("a", CF.struct("a", CF.struct("c", CF.struct("d")))),
+                CF.struct("a", "b", CF.struct("c", "d")),
+                CF.struct("e", "f", CF.struct("g")),
+            ).collect(),
+            sdf.select(
+                SF.struct("a", SF.struct("a", SF.struct("c", SF.struct("d")))),
+                SF.struct("a", "b", SF.struct("c", "d")),
+                SF.struct("e", "f", SF.struct("g")),
+            ).collect(),
+        )
+
+        # test collect struct containing array, map
+        # +--------------------------------------------+
+        # |  struct(a, struct(a, struct(g, struct(h))))|
+        # +--------------------------------------------+
+        # |{1, {1, {[1, null, 3], {{1 -> 2, 3 -> 4}}}}}|
+        # |   {2, {2, {[1, 3], {{1 -> null, 3 -> 4}}}}}|
+        # |                  {3, {3, {[null], {null}}}}|
+        # +--------------------------------------------+
+        self.assertEqual(
+            cdf.select(
+                CF.struct("a", CF.struct("a", CF.struct("g", CF.struct("h")))),
+            ).collect(),
+            sdf.select(
+                SF.struct("a", SF.struct("a", SF.struct("g", SF.struct("h")))),
+            ).collect(),
+        )
+
     def test_unsupported_functions(self):
         # SPARK-41225: Disable unsupported functions.
         df = self.connect.read.table(self.tbl_name)

