This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 32c6cead0eb [SPARK-41879][CONNECT][PYTHON] Make `DataFrame.collect` support nested types
32c6cead0eb is described below
commit 32c6cead0eb460717fd988fb22a4d9c6c35993a3
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jan 11 11:35:43 2023 +0800
[SPARK-41879][CONNECT][PYTHON] Make `DataFrame.collect` support nested types
### What changes were proposed in this pull request?
Make `DataFrame.collect` support nested types by introducing a new data converter.
Note that duplicated field names are not supported in this PR, since we cannot even read the batches on the client side.
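
A rough sketch of the intended behavior (a hypothetical query against an existing Spark Connect session `spark`; the column names and values are illustrative, not taken from this patch):

    df = spark.sql("SELECT array(1, NULL, 3) AS arr, map(1, 2) AS m, struct(1 AS x, 'a' AS y) AS s")
    df.collect()
    # nested values now come back as Python lists / dicts / Rows, e.g.
    # [Row(arr=[1, None, 3], m={1: 2}, s=Row(x=1, y='a'))]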
### Why are the changes needed?
To be consistent with PySpark.
### Does this PR introduce _any_ user-facing change?
yes
### How was this patch tested?
Added a unit test and enabled doctests.
Closes #39462 from zhengruifeng/connect_nested_converter.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/dataframe.py | 82 +++++++-----
python/pyspark/sql/connect/functions.py | 7 -
python/pyspark/sql/connect/types.py | 114 ++++++++++++++++-
.../sql/tests/connect/test_connect_basic.py | 142 +++++++++++++++++++++
4 files changed, 308 insertions(+), 37 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index c9cd65b93b7..e9e7c086bd7 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -33,20 +33,34 @@ from typing import (
import sys
import random
import pandas
-import datetime
import json
import warnings
from collections.abc import Iterable
from pyspark import _NoValue
from pyspark._globals import _NoValueType
-from pyspark.sql.types import StructType, Row
+from pyspark.sql.types import (
+ _create_row,
+ Row,
+ StructType,
+ ArrayType,
+ MapType,
+ TimestampType,
+ TimestampNTZType,
+)
+from pyspark.sql.dataframe import (
+ DataFrame as PySparkDataFrame,
+ DataFrameNaFunctions as PySparkDataFrameNaFunctions,
+ DataFrameStatFunctions as PySparkDataFrameStatFunctions,
+)
+from pyspark.sql.pandas.types import from_arrow_schema
import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.group import GroupedData
from pyspark.sql.connect.readwriter import DataFrameWriter
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import UnresolvedRegex
+from pyspark.sql.connect.types import _create_converter
from pyspark.sql.connect.functions import (
_to_col,
_invoke_function,
@@ -54,11 +68,6 @@ from pyspark.sql.connect.functions import (
lit,
expr as sql_expression,
)
-from pyspark.sql.dataframe import (
- DataFrame as PySparkDataFrame,
- DataFrameNaFunctions as PySparkDataFrameNaFunctions,
- DataFrameStatFunctions as PySparkDataFrameStatFunctions,
-)
if TYPE_CHECKING:
    from pyspark.sql.connect._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
@@ -1223,29 +1232,44 @@ class DataFrame:
query = self._plan.to_proto(self._session.client)
table = self._session.client.to_table(query)
+    # We first try the schema inferred from the PyArrow Table instead of always fetching
+    # the Connect DataFrame schema via 'self.schema', for two reasons:
+    # 1, the schema may be quite simple, so we can save an RPC;
+    # 2, if we always invoke 'self.schema' here, all catalog functions based on
+    #    'dataframe.collect' will be invoked twice (1, collect data; 2, fetch schema),
+    #    and then some of them (e.g. "CREATE DATABASE") fail due to the second invocation.
+
+ schema: Optional[StructType] = None
+ try:
+ schema = from_arrow_schema(table.schema)
+ except Exception:
+ # may fail due to 'from_arrow_schema' not supporting nested struct
+ schema = None
+
+ if schema is None:
+ schema = self.schema
+ else:
+ if any(
+ isinstance(
+                f.dataType, (StructType, ArrayType, MapType, TimestampType, TimestampNTZType)
+ )
+ for f in schema.fields
+ ):
+ schema = self.schema
+
+ assert schema is not None and isinstance(schema, StructType)
+
+    field_converters = [_create_converter(f.dataType) for f in schema.fields]
+
+    # table.to_pylist() automatically removes columns with duplicated names;
+    # to avoid this, use columnar lists here.
+    # TODO: support duplicated field names in one struct, e.g. SF.struct("a", "a")
+ columnar_data = [column.to_pylist() for column in table.columns]
+
rows: List[Row] = []
- columns = [column.to_pylist() for column in table.columns]
- i = 0
- while i < table.num_rows:
- values: List[Any] = []
- j = 0
- while j < table.num_columns:
- v = columns[j][i]
- if isinstance(v, bytes):
- values.append(bytearray(v))
- elif isinstance(v, datetime.datetime) and v.tzinfo is not None:
- # TODO: Should be controlled by "spark.sql.timestampType"
- # always remove the time zone for now
- values.append(v.replace(tzinfo=None))
- elif isinstance(v, dict):
- values.append(Row(**v))
- else:
- values.append(v)
- j += 1
- new_row = Row(*values)
- new_row.__fields__ = table.column_names
- rows.append(new_row)
- i += 1
+ for i in range(0, table.num_rows):
+        values = [field_converters[j](columnar_data[j][i]) for j in range(0, table.num_columns)]
+ rows.append(_create_row(fields=table.column_names, values=values))
return rows
collect.__doc__ = PySparkDataFrame.collect.__doc__
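
As an aside on the `table.to_pylist()` comment in the hunk above, here is a minimal sketch (assumed PyArrow behavior; the printed outputs are illustrative) of why collect() now iterates per column instead of per row:

    import pyarrow as pa

    # two columns that share the name "a"
    table = pa.table([pa.array([1, 2]), pa.array([3, 4])], names=["a", "a"])

    # row-wise conversion collapses the duplicated name into one dict key,
    # so one "a" column is silently dropped
    table.to_pylist()                                   # e.g. [{'a': 3}, {'a': 4}]

    # column-wise conversion keeps every column, which is what the new code relies on
    [column.to_pylist() for column in table.columns]    # [[1, 2], [3, 4]]
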
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 993fd30e7f0..8f9676dfe47 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -2357,13 +2357,6 @@ def _test() -> None:
    # Spark Connect does not support Spark Context but the test depends on that.
del pyspark.sql.connect.functions.monotonically_increasing_id.__doc__
-    # TODO(SPARK-41880): Function `from_json` should support non-literal expression
- # TODO(SPARK-41879): `DataFrame.collect` should support nested types
- del pyspark.sql.connect.functions.struct.__doc__
- del pyspark.sql.connect.functions.create_map.__doc__
- del pyspark.sql.connect.functions.from_csv.__doc__
- del pyspark.sql.connect.functions.from_json.__doc__
-
# TODO(SPARK-41834): implement Dataframe.conf
del pyspark.sql.connect.functions.from_unixtime.__doc__
del pyspark.sql.connect.functions.timestamp_seconds.__doc__
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 2f4abcec9b3..6f5d5971ef7 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -15,11 +15,13 @@
# limitations under the License.
#
+import datetime
import json
-from typing import Optional
+from typing import Any, Optional, Callable
from pyspark.sql.types import (
+ Row,
DataType,
ByteType,
ShortType,
@@ -184,3 +186,113 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
)
else:
raise Exception(f"Unsupported data type {schema}")
+
+
+def _need_converter(dataType: DataType) -> bool:
+ if isinstance(dataType, NullType):
+ return True
+ elif isinstance(dataType, StructType):
+ return True
+ elif isinstance(dataType, ArrayType):
+ return _need_converter(dataType.elementType)
+ elif isinstance(dataType, MapType):
+        # Different from PySpark, a conversion is always needed here,
+        # since the input from Arrow is a list of tuples.
+ return True
+ elif isinstance(dataType, BinaryType):
+ return True
+ elif isinstance(dataType, (TimestampType, TimestampNTZType)):
+ # Always remove the time zone info for now
+ return True
+ else:
+ return False
+
+
+def _create_converter(dataType: DataType) -> Callable:
+ assert dataType is not None and isinstance(dataType, DataType)
+
+ if not _need_converter(dataType):
+ return lambda value: value
+
+ if isinstance(dataType, NullType):
+ return lambda value: None
+
+ elif isinstance(dataType, StructType):
+
+        field_convs = {f.name: _create_converter(f.dataType) for f in dataType.fields}
+ need_conv = any(_need_converter(f.dataType) for f in dataType.fields)
+
+ def convert_struct(value: Any) -> Row:
+ if value is None:
+ return Row()
+ else:
+ assert isinstance(value, dict)
+
+ if need_conv:
+ _dict = {}
+ for k, v in value.items():
+ assert isinstance(k, str)
+ _dict[k] = field_convs[k](v)
+ return Row(**_dict)
+ else:
+ return Row(**value)
+
+ return convert_struct
+
+ elif isinstance(dataType, ArrayType):
+
+ element_conv = _create_converter(dataType.elementType)
+
+ def convert_array(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ assert isinstance(value, list)
+ return [element_conv(v) for v in value]
+
+ return convert_array
+
+ elif isinstance(dataType, MapType):
+
+ key_conv = _create_converter(dataType.keyType)
+ value_conv = _create_converter(dataType.valueType)
+
+ def convert_map(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ assert isinstance(value, list)
+ assert all(isinstance(t, tuple) and len(t) == 2 for t in value)
+ return dict((key_conv(t[0]), value_conv(t[1])) for t in value)
+
+ return convert_map
+
+ elif isinstance(dataType, BinaryType):
+
+ def convert_binary(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ assert isinstance(value, bytes)
+ return bytearray(value)
+
+ return convert_binary
+
+ elif isinstance(dataType, (TimestampType, TimestampNTZType)):
+
+ def convert_timestample(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ assert isinstance(value, datetime.datetime)
+ if value.tzinfo is not None:
+ # always remove the time zone for now
+ return value.replace(tzinfo=None)
+ else:
+ return value
+
+ return convert_timestample
+
+ else:
+
+ return lambda value: value
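
A short usage sketch of the converters added above (the input values mimic what PyArrow's to_pylist() is assumed to produce for map and struct columns; the types and data are illustrative):

    from pyspark.sql.types import IntegerType, MapType, StructField, StructType
    from pyspark.sql.connect.types import _create_converter

    # Arrow hands map values back as lists of (key, value) tuples;
    # the converter turns them into Python dicts
    map_conv = _create_converter(MapType(IntegerType(), IntegerType()))
    map_conv([(1, 2), (3, 4)])          # {1: 2, 3: 4}

    # struct values arrive as dicts and are rebuilt into Rows
    struct_conv = _create_converter(StructType([StructField("x", IntegerType())]))
    struct_conv({"x": 1})               # Row(x=1)
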
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 2b03ce70638..45a9e7b3e17 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2095,6 +2095,148 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
):
cdf.withMetadata(columnName="name", metadata=["magic"])
+ def test_collect_nested_type(self):
+ query = """
+ SELECT * FROM VALUES
+ (1, 4, 0, 8, true, true, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
+ (2, 5, -1, NULL, false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
+ (3, 6, NULL, 0, false, NULL, ARRAY(NULL), NULL)
+ AS tab(a, b, c, d, e, f, g, h)
+ """
+
+ # +---+---+----+----+-----+----+------------+-------------------+
+ # | a| b| c| d| e| f| g| h|
+ # +---+---+----+----+-----+----+------------+-------------------+
+ # | 1| 4| 0| 8| true|true|[1, null, 3]| {1 -> 2, 3 -> 4}|
+ # | 2| 5| -1|null|false|null| [1, 3]|{1 -> null, 3 -> 4}|
+ # | 3| 6|null| 0|false|null| [null]| null|
+ # +---+---+----+----+-----+----+------------+-------------------+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ # test collect array
+ # +--------------+-------------+------------+
+ # |array(a, b, c)| array(e, f)| g|
+ # +--------------+-------------+------------+
+ # | [1, 4, 0]| [true, true]|[1, null, 3]|
+ # | [2, 5, -1]|[false, null]| [1, 3]|
+ # | [3, 6, null]|[false, null]| [null]|
+ # +--------------+-------------+------------+
+ self.assertEqual(
+            cdf.select(CF.array("a", "b", "c"), CF.array("e", "f"), CF.col("g")).collect(),
+            sdf.select(SF.array("a", "b", "c"), SF.array("e", "f"), SF.col("g")).collect(),
+ )
+
+ # test collect nested array
+ # +-----------------------------------+-------------------------+
+ # |array(array(a), array(b), array(c))|array(array(e), array(f))|
+ # +-----------------------------------+-------------------------+
+ # | [[1], [4], [0]]| [[true], [true]]|
+ # | [[2], [5], [-1]]| [[false], [null]]|
+ # | [[3], [6], [null]]| [[false], [null]]|
+ # +-----------------------------------+-------------------------+
+ self.assertEqual(
+ cdf.select(
+ CF.array(CF.array("a"), CF.array("b"), CF.array("c")),
+ CF.array(CF.array("e"), CF.array("f")),
+ ).collect(),
+ sdf.select(
+ SF.array(SF.array("a"), SF.array("b"), SF.array("c")),
+ SF.array(SF.array("e"), SF.array("f")),
+ ).collect(),
+ )
+
+ # test collect array of struct, map
+ # +----------------+---------------------+
+ # |array(struct(a))| array(h)|
+ # +----------------+---------------------+
+ # | [{1}]| [{1 -> 2, 3 -> 4}]|
+ # | [{2}]|[{1 -> null, 3 -> 4}]|
+ # | [{3}]| [null]|
+ # +----------------+---------------------+
+ self.assertEqual(
+ cdf.select(CF.array(CF.struct("a")), CF.array("h")).collect(),
+ sdf.select(SF.array(SF.struct("a")), SF.array("h")).collect(),
+ )
+
+ # test collect map
+ # +-------------------+-------------------+
+ # | h| map(a, b, b, c)|
+ # +-------------------+-------------------+
+ # | {1 -> 2, 3 -> 4}| {1 -> 4, 4 -> 0}|
+ # |{1 -> null, 3 -> 4}| {2 -> 5, 5 -> -1}|
+ # | null|{3 -> 6, 6 -> null}|
+ # +-------------------+-------------------+
+ self.assertEqual(
+            cdf.select(CF.col("h"), CF.create_map("a", "b", "b", "c")).collect(),
+            sdf.select(SF.col("h"), SF.create_map("a", "b", "b", "c")).collect(),
+ )
+
+ # test collect map of struct, array
+ # +-------------------+------------------------+
+ # | map(a, g)| map(a, struct(b, g))|
+ # +-------------------+------------------------+
+ # |{1 -> [1, null, 3]}|{1 -> {4, [1, null, 3]}}|
+ # | {2 -> [1, 3]}| {2 -> {5, [1, 3]}}|
+ # | {3 -> [null]}| {3 -> {6, [null]}}|
+ # +-------------------+------------------------+
+ self.assertEqual(
+            cdf.select(CF.create_map("a", "g"), CF.create_map("a", CF.struct("b", "g"))).collect(),
+            sdf.select(SF.create_map("a", "g"), SF.create_map("a", SF.struct("b", "g"))).collect(),
+ )
+
+ # test collect struct
+ # +------------------+--------------------------+
+ # |struct(a, b, c, d)| struct(e, f, g)|
+ # +------------------+--------------------------+
+ # | {1, 4, 0, 8}|{true, true, [1, null, 3]}|
+ # | {2, 5, -1, null}| {false, null, [1, 3]}|
+ # | {3, 6, null, 0}| {false, null, [null]}|
+ # +------------------+--------------------------+
+ self.assertEqual(
+            cdf.select(CF.struct("a", "b", "c", "d"), CF.struct("e", "f", "g")).collect(),
+            sdf.select(SF.struct("a", "b", "c", "d"), SF.struct("e", "f", "g")).collect(),
+ )
+
+ # test collect nested struct
+        # +------------------------------------------+--------------------------+----------------------------+  # noqa
+        # |struct(a, struct(a, struct(c, struct(d))))|struct(a, b, struct(c, d))|     struct(e, f, struct(g))|  # noqa
+        # +------------------------------------------+--------------------------+----------------------------+  # noqa
+        # |                        {1, {1, {0, {8}}}}|            {1, 4, {0, 8}}|{true, true, {[1, null, 3]}}|  # noqa
+        # |                    {2, {2, {-1, {null}}}}|        {2, 5, {-1, null}}|     {false, null, {[1, 3]}}|  # noqa
+        # |                     {3, {3, {null, {0}}}}|         {3, 6, {null, 0}}|     {false, null, {[null]}}|  # noqa
+        # +------------------------------------------+--------------------------+----------------------------+  # noqa
+ self.assertEqual(
+ cdf.select(
+ CF.struct("a", CF.struct("a", CF.struct("c", CF.struct("d")))),
+ CF.struct("a", "b", CF.struct("c", "d")),
+ CF.struct("e", "f", CF.struct("g")),
+ ).collect(),
+ sdf.select(
+ SF.struct("a", SF.struct("a", SF.struct("c", SF.struct("d")))),
+ SF.struct("a", "b", SF.struct("c", "d")),
+ SF.struct("e", "f", SF.struct("g")),
+ ).collect(),
+ )
+
+ # test collect struct containing array, map
+ # +--------------------------------------------+
+ # | struct(a, struct(a, struct(g, struct(h))))|
+ # +--------------------------------------------+
+ # |{1, {1, {[1, null, 3], {{1 -> 2, 3 -> 4}}}}}|
+ # | {2, {2, {[1, 3], {{1 -> null, 3 -> 4}}}}}|
+ # | {3, {3, {[null], {null}}}}|
+ # +--------------------------------------------+
+ self.assertEqual(
+ cdf.select(
+ CF.struct("a", CF.struct("a", CF.struct("g", CF.struct("h")))),
+ ).collect(),
+ sdf.select(
+ SF.struct("a", SF.struct("a", SF.struct("g", SF.struct("h")))),
+ ).collect(),
+ )
+
def test_unsupported_functions(self):
# SPARK-41225: Disable unsupported functions.
df = self.connect.read.table(self.tbl_name)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]