This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 916b0d3de97 [SPARK-43817][SPARK-43702][PYTHON] Support UserDefinedType in createDataFrame from pandas DataFrame and toPandas
916b0d3de97 is described below
commit 916b0d3de973b8b30a8ede3d56b9f8a711110512
Author: Takuya UESHIN <[email protected]>
AuthorDate: Sun May 28 08:47:35 2023 +0800
[SPARK-43817][SPARK-43702][PYTHON] Support UserDefinedType in createDataFrame from pandas DataFrame and toPandas
### What changes were proposed in this pull request?
Support `UserDefinedType` in `createDataFrame` from pandas DataFrame and `toPandas`.
For the following schema and pandas DataFrame:
```py
schema = (
    StructType()
    .add("point", ExamplePointUDT())
    .add("struct", StructType().add("point", ExamplePointUDT()))
    .add("array", ArrayType(ExamplePointUDT()))
    .add("map", MapType(StringType(), ExamplePointUDT()))
)
data = [
    Row(
        ExamplePoint(1.0, 2.0),
        Row(ExamplePoint(3.0, 4.0)),
        [ExamplePoint(5.0, 6.0)],
        dict(point=ExamplePoint(7.0, 8.0)),
    )
]
df = spark.createDataFrame(data, schema)
pdf = pd.DataFrame.from_records(data, columns=schema.names)
```
##### `spark.createDataFrame()`
With this change, all code paths (with or without Arrow, and in Spark Connect) return the same result:
```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
+----------+------------+------------+---------------------+
|point |struct |array |map |
+----------+------------+------------+---------------------+
|(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
+----------+------------+------------+---------------------+
```
##### `df.toPandas()`
```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> df.toPandas()
point struct array map
0 (1.0,2.0) ((3.0,4.0),) [(5.0,6.0)] {'point': (7.0,8.0)}
```
### Why are the changes needed?
Currently, `UserDefinedType` is not supported in `spark.createDataFrame()` with a pandas DataFrame or in `df.toPandas()` when Arrow is enabled, nor in Spark Connect.
##### `spark.createDataFrame()`
Works without Arrow:
```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
+----------+------------+------------+---------------------+
|point |struct |array |map |
+----------+------------+------------+---------------------+
|(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
+----------+------------+------------+---------------------+
```
whereas:
- With Arrow enabled, it only works via the fallback path:
```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
/.../python/pyspark/sql/pandas/conversion.py:351: UserWarning:
createDataFrame attempted Arrow optimization because
'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by
the reason below:
[UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not
supported in conversion to Arrow.
Attempting non-optimization as
'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
warn(msg)
+----------+------------+------------+---------------------+
|point |struct |array |map |
+----------+------------+------------+---------------------+
|(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
+----------+------------+------------+---------------------+
```
- In Spark Connect, it fails:
```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
Traceback (most recent call last):
...
pyspark.errors.exceptions.base.PySparkTypeError:
[UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not supported
in conversion to Arrow.
```
##### `df.toPandas()`
Works without Arrow:
```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> df.toPandas()
point struct array map
0 (1.0,2.0) ((3.0,4.0),) [(5.0,6.0)] {'point': (7.0,8.0)}
```
whereas:
- With Arrow enabled, it only works via the fallback path:
```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> df.toPandas()
/.../python/pyspark/sql/pandas/conversion.py:111: UserWarning: toPandas
attempted Arrow optimization because
'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by
the reason below:
[UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not
supported in conversion to Arrow.
Attempting non-optimization as
'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
warn(msg)
point struct array map
0 (1.0,2.0) ((3.0,4.0),) [(5.0,6.0)] {'point': (7.0,8.0)}
```
- In Spark Connect, the results contain the internal (serialized) representation instead of the UDT values:
```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> df.toPandas()
point struct array map
0 [1.0, 2.0] ([3.0, 4.0],) [[5.0, 6.0]] {'point': [7.0, 8.0]}
```
### Does this PR introduce _any_ user-facing change?
Yes. Users will be able to use `UserDefinedType` in `createDataFrame` from pandas DataFrame and in `toPandas`, with Arrow enabled and in Spark Connect.
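For illustration, a user-defined type along the lines of the `ExamplePoint`/`ExamplePointUDT` test classes (from `pyspark.testing.sqlutils`) used in the examples above might look like the sketch below. This is a hedged sketch only: `Point2D` and `Point2DUDT` are hypothetical names, not classes added by this commit.
```py
# Hedged sketch: Point2D/Point2DUDT are illustrative names, not part of this change.
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType


class Point2DUDT(UserDefinedType):
    """Maps Point2D objects to an internal array<double> representation."""

    @classmethod
    def sqlType(cls):
        # The internal Spark SQL type that serialized values are stored as.
        return ArrayType(DoubleType(), containsNull=False)

    @classmethod
    def module(cls):
        return "__main__"

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return Point2D(datum[0], datum[1])


class Point2D:
    """A plain Python object that Spark stores and ships as Point2DUDT."""

    __UDT__ = Point2DUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        return isinstance(other, Point2D) and other.x == self.x and other.y == self.y
```
With this change, a pandas DataFrame holding such objects can be passed to `spark.createDataFrame` with a `Point2DUDT()` column and round-tripped through `df.toPandas()` with Arrow enabled or over Spark Connect, instead of falling back or surfacing the internal `array<double>` values.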
### How was this patch tested?
Added the related tests.
Closes #41333 from ueshin/issues/SPARK-43817/udt.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../connect/data_type_ops/test_parity_udt_ops.py | 48 +------
python/pyspark/sql/connect/client/core.py | 4 +-
python/pyspark/sql/connect/conversion.py | 7 +-
python/pyspark/sql/connect/dataframe.py | 2 +-
python/pyspark/sql/connect/session.py | 22 +++-
python/pyspark/sql/connect/types.py | 146 ---------------------
python/pyspark/sql/pandas/conversion.py | 17 ++-
python/pyspark/sql/pandas/serializers.py | 28 ++--
python/pyspark/sql/pandas/types.py | 58 ++++++--
.../pyspark/sql/tests/connect/test_parity_arrow.py | 6 +
python/pyspark/sql/tests/test_arrow.py | 67 +++++++++-
11 files changed, 171 insertions(+), 234 deletions(-)
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py
index 81511829c06..70a79e4cd3f 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py
@@ -25,53 +25,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
class UDTOpsParityTests(
    UDTOpsTestsMixin, PandasOnSparkTestUtils, OpsTestBase, ReusedConnectTestCase
):
- @unittest.skip(
- "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work
with Spark Connect."
- )
- def test_eq(self):
- super().test_eq()
-
- @unittest.skip(
- "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work
with Spark Connect."
- )
- def test_from_to_pandas(self):
- super().test_from_to_pandas()
-
- @unittest.skip(
- "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work
with Spark Connect."
- )
- def test_ge(self):
- super().test_ge()
-
- @unittest.skip(
- "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work
with Spark Connect."
- )
- def test_gt(self):
- super().test_gt()
-
- @unittest.skip(
- "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work
with Spark Connect."
- )
- def test_isnull(self):
- super().test_isnull()
-
- @unittest.skip(
- "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work
with Spark Connect."
- )
- def test_le(self):
- super().test_le()
-
- @unittest.skip(
- "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work
with Spark Connect."
- )
- def test_lt(self):
- super().test_lt()
-
- @unittest.skip(
- "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work
with Spark Connect."
- )
- def test_ne(self):
- super().test_ne()
+ pass
if __name__ == "__main__":
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 544ed5d4183..a0f790b2992 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -75,7 +75,7 @@ from pyspark.sql.connect.expressions import (
CommonInlineUserDefinedFunction,
JavaUDF,
)
-from pyspark.sql.pandas.types import _create_converter_to_pandas
+from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema
from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
from pyspark.rdd import PythonEvalType
from pyspark.storagelevel import StorageLevel
@@ -717,7 +717,7 @@ class SparkConnectClient(object):
table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req)
assert table is not None
- schema = schema or types.from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
+ schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
assert schema is not None and isinstance(schema, StructType)
# Rename columns to avoid duplicated column names.
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 3cc301c38ea..cdbc3a1e39c 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -42,9 +42,8 @@ from pyspark.sql.types import (
)
from pyspark.storagelevel import StorageLevel
-from pyspark.sql.connect.types import to_arrow_schema
import pyspark.sql.connect.proto as pb2
-from pyspark.sql.pandas.types import _dedup_names, _deduplicate_field_names
+from pyspark.sql.pandas.types import to_arrow_schema, _dedup_names, _deduplicate_field_names
from typing import (
Any,
@@ -246,7 +245,7 @@ class LocalDataToArrowConversion:
elif isinstance(dataType, UserDefinedType):
udt: UserDefinedType = dataType
- conv = LocalDataToArrowConversion._create_converter(dataType.sqlType())
+ conv = LocalDataToArrowConversion._create_converter(udt.sqlType())
def convert_udt(value: Any) -> Any:
if value is None:
@@ -428,7 +427,7 @@ class ArrowTableToRowsConversion:
elif isinstance(dataType, UserDefinedType):
udt: UserDefinedType = dataType
- conv = ArrowTableToRowsConversion._create_converter(dataType.sqlType())
+ conv = ArrowTableToRowsConversion._create_converter(udt.sqlType())
def convert_udt(value: Any) -> Any:
if value is None:
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 70aa53ed73e..46218bb4dc0 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -76,7 +76,7 @@ from pyspark.sql.connect.functions import (
lit,
expr as sql_expression,
)
-from pyspark.sql.connect.types import from_arrow_schema
+from pyspark.sql.pandas.types import from_arrow_schema
if TYPE_CHECKING:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 7932ab54081..2d58ce1daf0 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -331,8 +331,12 @@ class SparkSession:
# Determine arrow types to coerce data when creating batches
arrow_schema: Optional[pa.Schema] = None
+ spark_types: List[Optional[DataType]]
+ arrow_types: List[Optional[pa.DataType]]
if isinstance(schema, StructType):
- arrow_schema = to_arrow_schema(cast(StructType, _deduplicate_field_names(schema)))
+ deduped_schema = cast(StructType, _deduplicate_field_names(schema))
+ spark_types = [field.dataType for field in deduped_schema.fields]
+ arrow_schema = to_arrow_schema(deduped_schema)
arrow_types = [field.type for field in arrow_schema]
_cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
elif isinstance(schema, DataType):
@@ -342,14 +346,15 @@ class SparkSession:
)
else:
# Any timestamps must be coerced to be compatible with Spark
- arrow_types = [
- to_arrow_type(TimestampType())
+ spark_types = [
+ TimestampType()
if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
- else to_arrow_type(DayTimeIntervalType())
+ else DayTimeIntervalType()
if is_timedelta64_dtype(t)
else None
for t in data.dtypes
]
+ arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
timezone, safecheck = self._client.get_configs(
"spark.sql.session.timeZone",
"spark.sql.execution.pandas.convertToArrowArraySafely"
@@ -358,7 +363,14 @@ class SparkSession:
ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
_table = pa.Table.from_batches(
- [ser._create_batch([(c, t) for (_, c), t in zip(data.items(), arrow_types)])]
+ [
+ ser._create_batch(
+ [
+ (c, at, st)
+ for (_, c), at, st in zip(data.items(), arrow_types, spark_types)
+ ]
+ )
+ ]
)
if isinstance(schema, StructType):
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index fa8f9f5f8ff..2a21cdf0675 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -20,8 +20,6 @@ check_dependencies(__name__)
import json
-import pyarrow as pa
-
from typing import Any, Dict, Optional
from pyspark.sql.types import (
@@ -299,147 +297,3 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
return UserDefinedType.fromJson(json_value)
else:
raise Exception(f"Unsupported data type {schema}")
-
-
-def to_arrow_type(dt: DataType) -> "pa.DataType":
- """
- Convert Spark data type to pyarrow type.
-
- This function refers to 'pyspark.sql.pandas.types.to_arrow_type' but relax the restriction,
- e.g. it supports nested StructType.
- """
- if type(dt) == BooleanType:
- arrow_type = pa.bool_()
- elif type(dt) == ByteType:
- arrow_type = pa.int8()
- elif type(dt) == ShortType:
- arrow_type = pa.int16()
- elif type(dt) == IntegerType:
- arrow_type = pa.int32()
- elif type(dt) == LongType:
- arrow_type = pa.int64()
- elif type(dt) == FloatType:
- arrow_type = pa.float32()
- elif type(dt) == DoubleType:
- arrow_type = pa.float64()
- elif type(dt) == DecimalType:
- arrow_type = pa.decimal128(dt.precision, dt.scale)
- elif type(dt) == StringType:
- arrow_type = pa.string()
- elif type(dt) == BinaryType:
- arrow_type = pa.binary()
- elif type(dt) == DateType:
- arrow_type = pa.date32()
- elif type(dt) == TimestampType:
- # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
- arrow_type = pa.timestamp("us", tz="UTC")
- elif type(dt) == TimestampNTZType:
- arrow_type = pa.timestamp("us", tz=None)
- elif type(dt) == DayTimeIntervalType:
- arrow_type = pa.duration("us")
- elif type(dt) == ArrayType:
- field = pa.field("element", to_arrow_type(dt.elementType),
nullable=dt.containsNull)
- arrow_type = pa.list_(field)
- elif type(dt) == MapType:
- key_field = pa.field("key", to_arrow_type(dt.keyType), nullable=False)
- value_field = pa.field("value", to_arrow_type(dt.valueType), nullable=dt.valueContainsNull)
- arrow_type = pa.map_(key_field, value_field)
- elif type(dt) == StructType:
- fields = [
- pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
- for field in dt
- ]
- arrow_type = pa.struct(fields)
- elif type(dt) == NullType:
- arrow_type = pa.null()
- elif isinstance(dt, UserDefinedType):
- arrow_type = to_arrow_type(dt.sqlType())
- else:
- raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
- return arrow_type
-
-
-def to_arrow_schema(schema: StructType) -> "pa.Schema":
- """Convert a schema from Spark to Arrow"""
- fields = [
- pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
- for field in schema
- ]
- return pa.schema(fields)
-
-
-def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
- """Convert pyarrow type to Spark data type.
-
- This function refers to 'pyspark.sql.pandas.types.from_arrow_type' but relax the restriction,
- e.g. it supports nested StructType, Array of TimestampType. However, Arrow DictionaryType is
- not allowed.
- """
- import pyarrow.types as types
-
- spark_type: DataType
- if types.is_boolean(at):
- spark_type = BooleanType()
- elif types.is_int8(at):
- spark_type = ByteType()
- elif types.is_int16(at):
- spark_type = ShortType()
- elif types.is_int32(at):
- spark_type = IntegerType()
- elif types.is_int64(at):
- spark_type = LongType()
- elif types.is_float32(at):
- spark_type = FloatType()
- elif types.is_float64(at):
- spark_type = DoubleType()
- elif types.is_decimal(at):
- spark_type = DecimalType(precision=at.precision, scale=at.scale)
- elif types.is_string(at):
- spark_type = StringType()
- elif types.is_binary(at):
- spark_type = BinaryType()
- elif types.is_date32(at):
- spark_type = DateType()
- elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
- spark_type = TimestampNTZType()
- elif types.is_timestamp(at):
- spark_type = TimestampType()
- elif types.is_duration(at):
- spark_type = DayTimeIntervalType()
- elif types.is_list(at):
- spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
- elif types.is_map(at):
- spark_type = MapType(
- from_arrow_type(at.key_type, prefer_timestamp_ntz),
- from_arrow_type(at.item_type, prefer_timestamp_ntz),
- )
- elif types.is_struct(at):
- return StructType(
- [
- StructField(
- field.name,
- from_arrow_type(field.type, prefer_timestamp_ntz),
- nullable=field.nullable,
- )
- for field in at
- ]
- )
- elif types.is_null(at):
- spark_type = NullType()
- else:
- raise TypeError("Unsupported type in conversion from Arrow: " +
str(at))
- return spark_type
-
-
-def from_arrow_schema(arrow_schema: "pa.Schema", prefer_timestamp_ntz: bool = False) -> StructType:
- """Convert schema from Arrow to Spark."""
- return StructType(
- [
- StructField(
- field.name,
- from_arrow_type(field.type, prefer_timestamp_ntz),
- nullable=field.nullable,
- )
- for field in arrow_schema
- ]
- )
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 5e147aff48d..8664c4df73e 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -598,9 +598,7 @@ class SparkConversionMixin:
# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
- arrow_types = [
- to_arrow_type(_deduplicate_field_names(f.dataType)) for f in schema.fields
- ]
+ spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
elif isinstance(schema, DataType):
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
@@ -608,10 +606,8 @@ class SparkConversionMixin:
)
else:
# Any timestamps must be coerced to be compatible with Spark
- arrow_types = [
- to_arrow_type(TimestampType())
- if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
- else None
+ spark_types = [
+ TimestampType() if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
for t in pdf.dtypes
]
@@ -619,9 +615,12 @@ class SparkConversionMixin:
step = self._jconf.arrowMaxRecordsPerBatch()
pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))
- # Create list of Arrow (columns, type) for serializer dump_stream
+ # Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream
arrow_data = [
- [(c, t) for (_, c), t in zip(pdf_slice.items(), arrow_types)]
+ [
+ (c, to_arrow_type(t) if t is not None else None, t)
+ for (_, c), t in zip(pdf_slice.items(), spark_types)
+ ]
for pdf_slice in pdf_slices
]
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index e81d90fc23e..84471143367 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -27,7 +27,7 @@ from pyspark.sql.pandas.types import (
_create_converter_from_pandas,
_create_converter_to_pandas,
)
-from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType
+from pyspark.sql.types import DataType, StringType, StructType, BinaryType, StructField, LongType
class SpecialLengths:
@@ -189,7 +189,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
)
return converter(s)
- def _create_array(self, series, arrow_type):
+ def _create_array(self, series, arrow_type, spark_type=None):
"""
Create an Arrow Array from the given pandas.Series and optional type.
@@ -199,6 +199,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
A single series
arrow_type : pyarrow.DataType, optional
If None, pyarrow's inferred type will be used
+ spark_type : DataType, optional
+ If None, spark type converted from arrow_type will be used
Returns
-------
@@ -211,10 +213,10 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
series = series.astype(series.dtypes.categories.dtype)
if arrow_type is not None:
- spark_type = from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
+ dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
# TODO(SPARK-43579): cache the converter for reuse
conv = _create_converter_from_pandas(
- spark_type, timezone=self._timezone, error_on_duplicated_field_names=False
+ dt, timezone=self._timezone, error_on_duplicated_field_names=False
)
series = conv(series)
@@ -261,14 +263,24 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
"""
import pyarrow as pa
- # Make input conform to [(series1, type1), (series2, type2), ...]
- if not isinstance(series, (list, tuple)) or (
- len(series) == 2 and isinstance(series[1], pa.DataType)
+ # Make input conform to
+ # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...]
+ if (
+ not isinstance(series, (list, tuple))
+ or (len(series) == 2 and isinstance(series[1], pa.DataType))
+ or (
+ len(series) == 3
+ and isinstance(series[1], pa.DataType)
+ and isinstance(series[2], DataType)
+ )
):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+ series = ((s[0], s[1], None) if len(s) == 2 else s for s in series)
- arrs = [self._create_array(s, t) for s, t in series]
+ arrs = [
+ self._create_array(s, arrow_type, spark_type) for s, arrow_type, spark_type in series
+ ]
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
def dump_stream(self, iterator, stream):
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index adf497bbd73..ae7c25e0828 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -46,6 +46,7 @@ from pyspark.sql.types import (
StructField,
NullType,
DataType,
+ UserDefinedType,
Row,
_create_row,
)
@@ -119,6 +120,8 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
arrow_type = pa.struct(fields)
elif type(dt) == NullType:
arrow_type = pa.null()
+ elif isinstance(dt, UserDefinedType):
+ arrow_type = to_arrow_type(dt.sqlType())
else:
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
@@ -561,10 +564,12 @@ def _create_converter_to_pandas(
return correct_dtype
- def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]:
+ def _converter(
+ dt: DataType, _struct_in_pandas: Optional[str]
+ ) -> Optional[Callable[[Any], Any]]:
if isinstance(dt, ArrayType):
- _element_conv = _converter(dt.elementType)
+ _element_conv = _converter(dt.elementType, _struct_in_pandas)
if _element_conv is None:
return None
@@ -582,8 +587,8 @@ def _create_converter_to_pandas(
return convert_array
elif isinstance(dt, MapType):
- _key_conv = _converter(dt.keyType) or (lambda x: x)
- _value_conv = _converter(dt.valueType) or (lambda x: x)
+ _key_conv = _converter(dt.keyType, _struct_in_pandas) or (lambda x: x)
+ _value_conv = _converter(dt.valueType, _struct_in_pandas) or (lambda x: x)
def convert_map(value: Any) -> Any:
if value is None:
@@ -599,7 +604,7 @@ def _create_converter_to_pandas(
return convert_map
elif isinstance(dt, StructType):
- assert struct_in_pandas is not None
+ assert _struct_in_pandas is not None
field_names = dt.names
@@ -611,9 +616,11 @@ def _create_converter_to_pandas(
dedup_field_names = _dedup_names(field_names)
- field_convs = [_converter(f.dataType) or (lambda x: x) for f in dt.fields]
+ field_convs = [
+ _converter(f.dataType, _struct_in_pandas) or (lambda x: x) for f in dt.fields
+ ]
- if struct_in_pandas == "row":
+ if _struct_in_pandas == "row":
def convert_struct_as_row(value: Any) -> Any:
if value is None:
@@ -633,7 +640,7 @@ def _create_converter_to_pandas(
return convert_struct_as_row
- elif struct_in_pandas == "dict":
+ elif _struct_in_pandas == "dict":
def convert_struct_as_dict(value: Any) -> Any:
if value is None:
@@ -654,7 +661,7 @@ def _create_converter_to_pandas(
return convert_struct_as_dict
else:
- raise ValueError(f"Unknown value for `struct_in_pandas`:
{struct_in_pandas}")
+ raise ValueError(f"Unknown value for `struct_in_pandas`:
{_struct_in_pandas}")
elif isinstance(dt, TimestampType):
assert timezone is not None
@@ -685,10 +692,26 @@ def _create_converter_to_pandas(
return convert_timestamp_ntz
+ elif isinstance(dt, UserDefinedType):
+ udt: UserDefinedType = dt
+
+ conv = _converter(udt.sqlType(), _struct_in_pandas="row") or (lambda x: x)
+
+ def convert_udt(value: Any) -> Any:
+ if value is None:
+ return None
+ elif hasattr(value, "__UDT__"):
+ assert isinstance(value.__UDT__, type(udt))
+ return value
+ else:
+ return udt.deserialize(conv(value))
+
+ return convert_udt
+
else:
return None
- conv = _converter(data_type)
+ conv = _converter(data_type, struct_in_pandas)
if conv is not None:
return lambda pser: pser.apply(conv) # type: ignore[return-value]
else:
@@ -779,7 +802,7 @@ def _create_converter_from_pandas(
for i, key in enumerate(field_names)
}
else:
- assert isinstance(value, Row)
+ assert isinstance(value, tuple)
return {dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)}
return convert_struct
@@ -799,6 +822,19 @@ def _create_converter_from_pandas(
return convert_timestamp
+ elif isinstance(dt, UserDefinedType):
+ udt: UserDefinedType = dt
+
+ conv = _converter(udt.sqlType()) or (lambda x: x)
+
+ def convert_udt(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ return conv(udt.serialize(value))
+
+ return convert_udt
+
return None
conv = _converter(data_type)
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index d1c8a1a55a0..60f1ef257c5 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -121,6 +121,12 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
def test_toPandas_nested_timestamp(self):
self.check_toPandas_nested_timestamp(True)
+ def test_createDataFrame_udt(self):
+ self.check_createDataFrame_udt(True)
+
+ def test_toPandas_udt(self):
+ self.check_toPandas_udt(True)
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_arrow import * # noqa: F401
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index dfde747c265..e26aabbea27 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -53,6 +53,8 @@ from pyspark.testing.sqlutils import (
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
+ ExamplePoint,
+ ExamplePointUDT,
)
from pyspark.testing.utils import QuietTest
from pyspark.errors import ArithmeticException, PySparkTypeError, UnsupportedOperationException
@@ -1022,7 +1024,7 @@ class ArrowTestsMixin:
df = self.spark.range(2).select([])
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
- assert_frame_equal(df.toPandas(), pd.DataFrame(index=range(2)))
+ assert_frame_equal(df.toPandas(), pd.DataFrame(columns=[], index=range(2)))
def test_createDataFrame_nested_timestamp(self):
for arrow_enabled in [True, False]:
@@ -1143,6 +1145,69 @@ class ArrowTestsMixin:
assert_frame_equal(pdf, expected)
+ def test_createDataFrame_udt(self):
+ for arrow_enabled in [True, False]:
+ with self.subTest(arrow_enabled=arrow_enabled):
+ self.check_createDataFrame_udt(arrow_enabled)
+
+ def check_createDataFrame_udt(self, arrow_enabled):
+ schema = (
+ StructType()
+ .add("point", ExamplePointUDT())
+ .add("struct", StructType().add("point", ExamplePointUDT()))
+ .add("array", ArrayType(ExamplePointUDT()))
+ .add("map", MapType(StringType(), ExamplePointUDT()))
+ )
+ data = [
+ Row(
+ ExamplePoint(1.0, 2.0),
+ Row(ExamplePoint(3.0, 4.0)),
+ [ExamplePoint(5.0, 6.0)],
+ dict(point=ExamplePoint(7.0, 8.0)),
+ )
+ ]
+ pdf = pd.DataFrame.from_records(data, columns=schema.names)
+
+ with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+ df = self.spark.createDataFrame(pdf, schema)
+
+ self.assertEqual(df.collect(), data)
+
+ def test_toPandas_udt(self):
+ for arrow_enabled in [True, False]:
+ with self.subTest(arrow_enabled=arrow_enabled):
+ self.check_toPandas_udt(arrow_enabled)
+
+ def check_toPandas_udt(self, arrow_enabled):
+ schema = (
+ StructType()
+ .add("point", ExamplePointUDT())
+ .add("struct", StructType().add("point", ExamplePointUDT()))
+ .add("array", ArrayType(ExamplePointUDT()))
+ .add("map", MapType(StringType(), ExamplePointUDT()))
+ )
+ data = [
+ Row(
+ ExamplePoint(1.0, 2.0),
+ Row(ExamplePoint(3.0, 4.0)),
+ [ExamplePoint(5.0, 6.0)],
+ dict(point=ExamplePoint(7.0, 8.0)),
+ )
+ ]
+ df = self.spark.createDataFrame(data, schema)
+
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.pyspark.enabled": arrow_enabled,
+ "spark.sql.execution.pandas.structHandlingMode": "row",
+ }
+ ):
+ pdf = df.toPandas()
+
+ expected = pd.DataFrame.from_records(data, columns=schema.names)
+
+ assert_frame_equal(pdf, expected)
+
@unittest.skipIf(
not have_pandas or not have_pyarrow,