This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5e0a82e590d [SPARK-41746][SPARK-41838][SPARK-41837][SPARK-41835][SPARK-41836][SPARK-41847][CONNECT][PYTHON] Make `createDataFrame(rows/lists/tuples/dicts)` support nested types
5e0a82e590d is described below
commit 5e0a82e590d1c3c3c5fa7f347dddf450fabbf772
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Jan 13 12:51:42 2023 +0800
[SPARK-41746][SPARK-41838][SPARK-41837][SPARK-41835][SPARK-41836][SPARK-41847][CONNECT][PYTHON] Make `createDataFrame(rows/lists/tuples/dicts)` support nested types
### What changes were proposed in this pull request?
Make `createDataFrame` support nested types when the input data are rows, lists, tuples, or dicts.
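For example, a rough sketch of what now works on the Connect client (illustrative only; `spark` is assumed to be a Spark Connect session):

```python
from pyspark.sql import Row

# Nested Rows, lists, and dicts are now handled when building the local relation.
data = [Row(a=1, b=Row(c=2, d=[1, 2, 3], e={"x": "y"}))]
df = spark.createDataFrame(data)
df.printSchema()
# root
#  |-- a: long (nullable = true)
#  |-- b: struct (nullable = true)
#  |    |-- c: long (nullable = true)
#  |    |-- d: array (nullable = true)
#  |    |    |-- element: long (containsNull = true)
#  |    |-- e: map (nullable = true)
#  |    |    |-- key: string
#  |    |    |-- value: string (valueContainsNull = true)
```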
### Why are the changes needed?
To be consistent with PySpark.
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
1, enabled doctests (4 of them still fail due to a different ordering in `show`, see [SPARK-42032](https://issues.apache.org/jira/browse/SPARK-42032))
2, added UT
Closes #39535 from zhengruifeng/connect_cdf_nested.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/conversion.py | 208 +++++++++++++++++++++
python/pyspark/sql/connect/functions.py | 21 +--
python/pyspark/sql/connect/session.py | 83 ++++----
python/pyspark/sql/connect/types.py | 64 +++++++
.../sql/tests/connect/test_connect_basic.py | 134 ++++++++++++-
5 files changed, 441 insertions(+), 69 deletions(-)
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
new file mode 100644
index 00000000000..8df45be7eb3
--- /dev/null
+++ b/python/pyspark/sql/connect/conversion.py
@@ -0,0 +1,208 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+
+import pyarrow as pa
+
+from pyspark.sql.types import (
+ Row,
+ DataType,
+ TimestampType,
+ TimestampNTZType,
+ MapType,
+ StructType,
+ ArrayType,
+ BinaryType,
+ NullType,
+)
+
+from pyspark.sql.connect.types import to_arrow_schema
+
+from typing import (
+ Any,
+ Callable,
+ Sequence,
+)
+
+
+class LocalDataToArrowConversion:
+ """
+ Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow.
+ Currently, only :class:`SparkSession` in Spark Connect can use this class.
+ """
+
+ @staticmethod
+ def _need_converter(dataType: DataType) -> bool:
+ if isinstance(dataType, NullType):
+ return True
+ elif isinstance(dataType, StructType):
+ # Structs may be Rows and should be converted to dicts.
+ return True
+ elif isinstance(dataType, ArrayType):
+ return LocalDataToArrowConversion._need_converter(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ # Unlike PySpark, this always needs conversion,
+ # since an Arrow map requires a list of tuples.
+ return True
+ elif isinstance(dataType, BinaryType):
+ return True
+ elif isinstance(dataType, (TimestampType, TimestampNTZType)):
+ # Always truncate
+ return True
+ else:
+ return False
+
+ @staticmethod
+ def _create_converter(dataType: DataType) -> Callable:
+ assert dataType is not None and isinstance(dataType, DataType)
+
+ if not LocalDataToArrowConversion._need_converter(dataType):
+ return lambda value: value
+
+ if isinstance(dataType, NullType):
+ return lambda value: None
+
+ elif isinstance(dataType, StructType):
+
+ field_names = dataType.fieldNames()
+
+ field_convs = {
+ field.name: LocalDataToArrowConversion._create_converter(field.dataType)
+ for field in dataType.fields
+ }
+
+ def convert_struct(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ assert isinstance(value, (Row, dict)), f"{type(value)} {value}"
+
+ _dict = {}
+ if isinstance(value, dict):
+ for k, v in value.items():
+ assert isinstance(k, str)
+ _dict[k] = field_convs[k](v)
+ elif isinstance(value, Row) and hasattr(value, "__fields__"):
+ for k, v in value.asDict(recursive=False).items():
+ assert isinstance(k, str)
+ _dict[k] = field_convs[k](v)
+ else:
+ i = 0
+ for v in value:
+ field_name = field_names[i]
+ _dict[field_name] = field_convs[field_name](v)
+ i += 1
+
+ return _dict
+
+ return convert_struct
+
+ elif isinstance(dataType, ArrayType):
+
+ element_conv = LocalDataToArrowConversion._create_converter(dataType.elementType)
+
+ def convert_array(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ assert isinstance(value, list)
+ return [element_conv(v) for v in value]
+
+ return convert_array
+
+ elif isinstance(dataType, MapType):
+
+ key_conv = LocalDataToArrowConversion._create_converter(dataType.keyType)
+ value_conv = LocalDataToArrowConversion._create_converter(dataType.valueType)
+
+ def convert_map(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ assert isinstance(value, dict)
+
+ _tuples = []
+ for k, v in value.items():
+ _tuples.append((key_conv(k), value_conv(v)))
+
+ return _tuples
+
+ return convert_map
+
+ elif isinstance(dataType, BinaryType):
+
+ def convert_binary(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ assert isinstance(value, (bytes, bytearray))
+ return bytes(value)
+
+ return convert_binary
+
+ elif isinstance(dataType, (TimestampType, TimestampNTZType)):
+
+ def convert_timestamp(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ assert isinstance(value, datetime.datetime)
+ return value.astimezone(datetime.timezone.utc)
+
+ return convert_timestamp
+
+ else:
+
+ return lambda value: value
+
+ @staticmethod
+ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
+ assert isinstance(data, list) and len(data) > 0
+
+ assert schema is not None and isinstance(schema, StructType)
+
+ pa_schema = to_arrow_schema(schema)
+
+ column_names = schema.fieldNames()
+
+ column_convs = {
+ field.name: LocalDataToArrowConversion._create_converter(field.dataType)
+ for field in schema.fields
+ }
+
+ pylist = []
+
+ for item in data:
+ _dict = {}
+
+ if isinstance(item, dict):
+ for col, value in item.items():
+ _dict[col] = column_convs[col](value)
+ elif isinstance(item, Row) and hasattr(item, "__fields__"):
+ for col, value in item.asDict(recursive=False).items():
+ _dict[col] = column_convs[col](value)
+ else:
+ i = 0
+ for value in item:
+ col = column_names[i]
+ _dict[col] = column_convs[col](value)
+ i += 1
+
+ pylist.append(_dict)
+
+ return pa.Table.from_pylist(pylist, schema=pa_schema)
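For reference, a minimal sketch of how this new helper behaves on its own (it is an internal Connect API; the sketch assumes a local pyarrow installation and an explicit StructType schema):

```python
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, LongType, ArrayType, MapType, StringType
from pyspark.sql.connect.conversion import LocalDataToArrowConversion

schema = StructType([
    StructField("a", LongType()),
    StructField("b", ArrayType(LongType())),
    StructField("c", MapType(StringType(), StringType())),
])
data = [Row(a=1, b=[1, 2, 3], c={"x": "y"})]

# Nested values are normalized to plain dicts/lists/tuples, then
# pa.Table.from_pylist builds the table against the Arrow schema.
table = LocalDataToArrowConversion.convert(data, schema)
print(table.schema)
```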
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 8f9676dfe47..e1286f7d66e 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -2365,19 +2365,11 @@ def _test() -> None:
# TODO(SPARK-41757): Fix String representation for Column class
del pyspark.sql.connect.functions.col.__doc__
- # TODO(SPARK-41838): fix dataset.show
- del pyspark.sql.connect.functions.posexplode_outer.__doc__
- del pyspark.sql.connect.functions.explode_outer.__doc__
-
- # TODO(SPARK-41837): createDataFrame datatype conversion error
- del pyspark.sql.connect.functions.to_csv.__doc__
- del pyspark.sql.connect.functions.to_json.__doc__
-
- # TODO(SPARK-41835): Fix `transform_keys` function
+ # TODO(SPARK-42032): different key order in DF.show
del pyspark.sql.connect.functions.transform_keys.__doc__
-
- # TODO(SPARK-41836): Implement `transform_values` function
del pyspark.sql.connect.functions.transform_values.__doc__
+ del pyspark.sql.connect.functions.map_filter.__doc__
+ del pyspark.sql.connect.functions.map_zip_with.__doc__
# TODO(SPARK-41812): Proper column names after join
del pyspark.sql.connect.functions.count_distinct.__doc__
@@ -2388,13 +2380,6 @@ def _test() -> None:
# TODO(SPARK-41845): Fix count bug
del pyspark.sql.connect.functions.count.__doc__
- # TODO(SPARK-41847): mapfield,structlist invalid type
- del pyspark.sql.connect.functions.element_at.__doc__
- del pyspark.sql.connect.functions.explode.__doc__
- del pyspark.sql.connect.functions.map_filter.__doc__
- del pyspark.sql.connect.functions.map_zip_with.__doc__
- del pyspark.sql.connect.functions.posexplode.__doc__
-
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.functions tests")
.remote("local[4]")
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 9ed9a7bd2da..76073fc2717 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -192,9 +192,14 @@ class SparkSession:
_schema: Optional[Union[AtomicType, StructType]] = None
_schema_str: Optional[str] = None
_cols: Optional[List[str]] = None
+ _num_cols: Optional[int] = None
if isinstance(schema, (AtomicType, StructType)):
_schema = schema
+ if isinstance(schema, StructType):
+ _num_cols = len(schema.fields)
+ else:
+ _num_cols = 1
elif isinstance(schema, str):
_schema_str = schema
@@ -202,6 +207,7 @@ class SparkSession:
elif isinstance(schema, (list, tuple)):
# Must re-encode any unicode strings to be consistent with StructField names
_cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
+ _num_cols = len(_cols)
if isinstance(data, Sized) and len(data) == 0:
if _schema is not None:
@@ -289,65 +295,48 @@ class SparkSession:
else:
_data = list(data)
- if _schema is None and isinstance(_data[0], (Row, dict)):
- if isinstance(_data[0], dict):
- # Sort the data to respect inferred schema.
- # For dictionaries, we sort the schema in alphabetical order.
- _data = [dict(sorted(d.items())) for d in _data]
+ if isinstance(_data[0], dict):
+ # Sort the data to respect inferred schema.
+ # For dictionaries, we sort the schema in alphabetical order.
+ _data = [dict(sorted(d.items())) for d in _data]
+
+ elif not isinstance(_data[0], (Row, tuple, list, dict)):
+ # input data can be [1, 2, 3]
+ # we need to convert it to [[1], [2], [3]] to be able to infer the schema.
+ _data = [[d] for d in _data]
+ try:
_inferred_schema = self._inferSchemaFromList(_data, _cols)
- if _cols is not None:
- for i, name in enumerate(_cols):
- _inferred_schema.fields[i].name = name
- _inferred_schema.names[i] = name
+ except Exception:
+ # For cases like createDataFrame([("Alice", None, 80.1)], schema)
+ # we can not infer the schema from the data itself.
+ warnings.warn("failed to infer the schema from data")
+ if _schema is None or not isinstance(_schema, StructType):
+ raise ValueError(
+ "Some of types cannot be determined after inferring, "
+ "a StructType Schema is required in this case"
+ )
+ _inferred_schema = _schema
- if _cols is None:
- if _schema is None and _inferred_schema is None:
- if isinstance(_data[0], (list, tuple)):
- _cols = ["_%s" % i for i in range(1, len(_data[0]) +
1)]
- else:
- _cols = ["_1"]
- elif _schema is not None and isinstance(_schema, StructType):
- _cols = _schema.names
- elif _inferred_schema is not None:
- _cols = _inferred_schema.names
- else:
- _cols = ["value"]
+ from pyspark.sql.connect.conversion import LocalDataToArrowConversion
- if isinstance(_data[0], Row):
- _table = pa.Table.from_pylist([row.asDict(recursive=True) for row in _data])
- elif isinstance(_data[0], dict):
- _table = pa.Table.from_pylist(_data)
- elif isinstance(_data[0], (list, tuple)):
- _table = pa.Table.from_pylist([dict(zip(_cols, list(item))) for item in _data])
- else:
- # input data can be [1, 2, 3]
- _table = pa.Table.from_pylist([dict(zip(_cols, [item])) for item in _data])
-
- # Validate number of columns
- num_cols = _table.shape[1]
- if (
- _schema is not None
- and isinstance(_schema, StructType)
- and len(_schema.fields) != num_cols
- ):
- raise ValueError(
- f"Length mismatch: Expected axis has {num_cols} elements, "
- f"new values have {len(_schema.fields)} elements"
- )
+ # Spark Connect will try its best to build the Arrow table with the
+ # inferred schema on the client side, and then rename the columns and
+ # cast the datatypes on the server side.
+ _table = LocalDataToArrowConversion.convert(_data, _inferred_schema)
- if _cols is not None and len(_cols) != num_cols:
+ # TODO: Besides the validation on the number of columns, we should also check
+ # whether the Arrow schema is compatible with the user-provided schema.
+ if _num_cols is not None and _num_cols != _table.shape[1]:
raise ValueError(
- f"Length mismatch: Expected axis has {num_cols} elements, "
- f"new values have {len(_cols)} elements"
+ f"Length mismatch: Expected axis has {_num_cols} elements, "
+ f"new values have {_table.shape[1]} elements"
)
if _schema is not None:
return DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
elif _schema_str is not None:
return DataFrame.withPlan(LocalRelation(_table, schema=_schema_str), self)
- elif _inferred_schema is not None:
- return DataFrame.withPlan(LocalRelation(_table, schema=_inferred_schema.json()), self)
elif _cols is not None and len(_cols) > 0:
return DataFrame.withPlan(LocalRelation(_table), self).toDF(*_cols)
else:
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 6f5d5971ef7..a1d864c7429 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -18,6 +18,8 @@
import datetime
import json
+import pyarrow as pa
+
from typing import Any, Optional, Callable
from pyspark.sql.types import (
@@ -188,6 +190,68 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
raise Exception(f"Unsupported data type {schema}")
+def to_arrow_type(dt: DataType) -> "pa.DataType":
+ """
+ Convert Spark data type to pyarrow type.
+
+ This function refers to 'pyspark.sql.pandas.types.to_arrow_type' but relaxes the restriction,
+ e.g. it supports nested StructType.
+ """
+ if type(dt) == BooleanType:
+ arrow_type = pa.bool_()
+ elif type(dt) == ByteType:
+ arrow_type = pa.int8()
+ elif type(dt) == ShortType:
+ arrow_type = pa.int16()
+ elif type(dt) == IntegerType:
+ arrow_type = pa.int32()
+ elif type(dt) == LongType:
+ arrow_type = pa.int64()
+ elif type(dt) == FloatType:
+ arrow_type = pa.float32()
+ elif type(dt) == DoubleType:
+ arrow_type = pa.float64()
+ elif type(dt) == DecimalType:
+ arrow_type = pa.decimal128(dt.precision, dt.scale)
+ elif type(dt) == StringType:
+ arrow_type = pa.string()
+ elif type(dt) == BinaryType:
+ arrow_type = pa.binary()
+ elif type(dt) == DateType:
+ arrow_type = pa.date32()
+ elif type(dt) == TimestampType:
+ # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
+ arrow_type = pa.timestamp("us", tz="UTC")
+ elif type(dt) == TimestampNTZType:
+ arrow_type = pa.timestamp("us", tz=None)
+ elif type(dt) == DayTimeIntervalType:
+ arrow_type = pa.duration("us")
+ elif type(dt) == ArrayType:
+ arrow_type = pa.list_(to_arrow_type(dt.elementType))
+ elif type(dt) == MapType:
+ arrow_type = pa.map_(to_arrow_type(dt.keyType), to_arrow_type(dt.valueType))
+ elif type(dt) == StructType:
+ fields = [
+ pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+ for field in dt
+ ]
+ arrow_type = pa.struct(fields)
+ elif type(dt) == NullType:
+ arrow_type = pa.null()
+ else:
+ raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
+ return arrow_type
+
+
+def to_arrow_schema(schema: StructType) -> "pa.Schema":
+ """Convert a schema from Spark to Arrow"""
+ fields = [
+ pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+ for field in schema
+ ]
+ return pa.schema(fields)
+
+
def _need_converter(dataType: DataType) -> bool:
if isinstance(dataType, NullType):
return True
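For reference, a minimal sketch of the new `to_arrow_schema` helper on a nested schema (internal Connect API; the sketch assumes pyarrow is available):

```python
from pyspark.sql.types import StructType, StructField, LongType, StringType
from pyspark.sql.connect.types import to_arrow_schema

schema = StructType([
    StructField("id", LongType()),
    StructField("info", StructType([StructField("name", StringType())])),
])

# Nested StructTypes are mapped to pa.struct fields, which the stricter
# pyspark.sql.pandas.types.to_arrow_type does not support.
pa_schema = to_arrow_schema(schema)
print(pa_schema)  # an Arrow schema with id: int64 and info: struct<name: string>
```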
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0aea2c81c58..8d9becf259a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -439,7 +439,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
with self.assertRaisesRegex(
ValueError,
- "Length mismatch: Expected axis has 4 elements, new values have 5
elements",
+ "Length mismatch: Expected axis has 5 elements, new values have 4
elements",
):
self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
@@ -600,9 +600,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
]
),
]:
- print(schema)
- print(schema)
- print(schema)
cdf = self.connect.createDataFrame(data=[], schema=schema)
sdf = self.spark.createDataFrame(data=[], schema=schema)
@@ -615,6 +612,135 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
):
self.connect.createDataFrame(data=[])
+ def test_timestamp_create_from_rows(self):
+ data = [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)]
+
+ cdf = self.connect.createDataFrame(data, ["date", "val"])
+ sdf = self.spark.createDataFrame(data, ["date", "val"])
+
+ self.assertEqual(cdf.schema, sdf.schema)
+ self.assertEqual(cdf.collect(), sdf.collect())
+
+ def test_nested_type_create_from_rows(self):
+ data1 = [Row(a=1, b=Row(c=2, d=Row(e=3, f=Row(g=4, h=Row(i=5)))))]
+ # root
+ # |-- a: long (nullable = true)
+ # |-- b: struct (nullable = true)
+ # | |-- c: long (nullable = true)
+ # | |-- d: struct (nullable = true)
+ # | | |-- e: long (nullable = true)
+ # | | |-- f: struct (nullable = true)
+ # | | | |-- g: long (nullable = true)
+ # | | | |-- h: struct (nullable = true)
+ # | | | | |-- i: long (nullable = true)
+
+ data2 = [
+ (
+ 1,
+ "a",
+ Row(
+ a=1,
+ b=[1, 2, 3],
+ c={"a": "b"},
+ d=Row(x=1, y="y", z=Row(o=1, p=2, q=Row(g=1.5))),
+ ),
+ )
+ ]
+ # root
+ # |-- _1: long (nullable = true)
+ # |-- _2: string (nullable = true)
+ # |-- _3: struct (nullable = true)
+ # | |-- a: long (nullable = true)
+ # | |-- b: array (nullable = true)
+ # | | |-- element: long (containsNull = true)
+ # | |-- c: map (nullable = true)
+ # | | |-- key: string
+ # | | |-- value: string (valueContainsNull = true)
+ # | |-- d: struct (nullable = true)
+ # | | |-- x: long (nullable = true)
+ # | | |-- y: string (nullable = true)
+ # | | |-- z: struct (nullable = true)
+ # | | | |-- o: long (nullable = true)
+ # | | | |-- p: long (nullable = true)
+ # | | | |-- q: struct (nullable = true)
+ # | | | | |-- g: double (nullable = true)
+
+ data3 = [
+ Row(
+ a=1,
+ b=[1, 2, 3],
+ c={"a": "b"},
+ d=Row(x=1, y="y", z=Row(1, 2, 3)),
+ e=list("hello connect"),
+ )
+ ]
+ # root
+ # |-- a: long (nullable = true)
+ # |-- b: array (nullable = true)
+ # | |-- element: long (containsNull = true)
+ # |-- c: map (nullable = true)
+ # | |-- key: string
+ # | |-- value: string (valueContainsNull = true)
+ # |-- d: struct (nullable = true)
+ # | |-- x: long (nullable = true)
+ # | |-- y: string (nullable = true)
+ # | |-- z: struct (nullable = true)
+ # | | |-- _1: long (nullable = true)
+ # | | |-- _2: long (nullable = true)
+ # | | |-- _3: long (nullable = true)
+ # |-- e: array (nullable = true)
+ # | |-- element: string (containsNull = true)
+
+ data4 = [
+ {
+ "a": 1,
+ "b": Row(x=1, y=Row(z=2)),
+ "c": {"x": -1, "y": 2},
+ "d": [1, 2, 3, 4, 5],
+ }
+ ]
+ # root
+ # |-- a: long (nullable = true)
+ # |-- b: struct (nullable = true)
+ # | |-- x: long (nullable = true)
+ # | |-- y: struct (nullable = true)
+ # | | |-- z: long (nullable = true)
+ # |-- c: map (nullable = true)
+ # | |-- key: string
+ # | |-- value: long (valueContainsNull = true)
+ # |-- d: array (nullable = true)
+ # | |-- element: long (containsNull = true)
+
+ data5 = [
+ {
+ "a": [Row(x=1, y="2"), Row(x=-1, y="-2")],
+ "b": [[1, 2, 3], [4, 5], [6]],
+ "c": {3: {4: {5: 6}}, 7: {8: {9: 0}}},
+ }
+ ]
+ # root
+ # |-- a: array (nullable = true)
+ # | |-- element: struct (containsNull = true)
+ # | | |-- x: long (nullable = true)
+ # | | |-- y: string (nullable = true)
+ # |-- b: array (nullable = true)
+ # | |-- element: array (containsNull = true)
+ # | | |-- element: long (containsNull = true)
+ # |-- c: map (nullable = true)
+ # | |-- key: long
+ # | |-- value: map (valueContainsNull = true)
+ # | | |-- key: long
+ # | | |-- value: map (valueContainsNull = true)
+ # | | | |-- key: long
+ # | | | |-- value: long (valueContainsNull = true)
+
+ for data in [data1, data2, data3, data4, data5]:
+ cdf = self.connect.createDataFrame(data)
+ sdf = self.spark.createDataFrame(data)
+
+ self.assertEqual(cdf.schema, sdf.schema)
+ self.assertEqual(cdf.collect(), sdf.collect())
+
def test_simple_explain_string(self):
df = self.connect.read.table(self.tbl_name).limit(10)
result = df._explain_string()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]