This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 50d9f94225b [SPARK-43055][CONNECT][PYTHON][FOLLOWUP] Fix deduplicate field names and refactor
50d9f94225b is described below
commit 50d9f94225ba0f127ceaebfea465ac450b017f86
Author: Takuya UESHIN <[email protected]>
AuthorDate: Fri Apr 21 14:57:58 2023 +0900
[SPARK-43055][CONNECT][PYTHON][FOLLOWUP] Fix deduplicate field names and refactor
### What changes were proposed in this pull request?
Fixes deduplication of field names, and refactors `ArrowTableToRowsConversion.convert` and `LocalDataToArrowConversion.convert` to share the same renaming rule.
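The rule renames a field only when its name occurs more than once, appending a positional suffix in order of appearance. Here is a minimal sketch of that behavior; `dedup` is a simplified stand-in for the `_dedup_names` helper added by this patch, not the actual implementation:
```py
>>> import itertools
>>> def dedup(names):
...     # No duplicates: keep all names untouched.
...     if len(set(names)) == len(names):
...         return names
...     # One counter per duplicated name; groupby needs sorted input
...     # to see non-adjacent duplicates as a single group.
...     counters = {
...         name: itertools.count()
...         for name, group in itertools.groupby(sorted(names))
...         if len(list(group)) > 1
...     }
...     return [f"{n}_{next(counters[n])}" if n in counters else n for n in names]
...
>>> dedup(["a", "x", "x", "y", "y", "x"])
['a', 'x_0', 'x_1', 'y_0', 'y_1', 'x_2']
```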
### Why are the changes needed?
If a field name is duplicated at non-adjacent positions, deduplication fails and a wrong result is returned.
```py
>>> from pyspark.sql.types import *
>>> data = [
...     Row(Row("a", 1), Row(2, 3, "b", 4, "c", "d")),
...     Row(Row("w", 6), Row(7, 8, "x", 9, "y", "z")),
... ]
>>> schema = (
...     StructType()
...     .add("struct", StructType().add("x", StringType()).add("x", IntegerType()))
...     .add(
...         "struct",
...         StructType()
...         .add("a", IntegerType())
...         .add("x", IntegerType())
...         .add("x", StringType())
...         .add("y", IntegerType())
...         .add("y", StringType())
...         .add("x", StringType()),
...     )
... )
>>> df = spark.createDataFrame(data, schema=schema)
>>>
>>> df.collect()
[Row(struct=Row(x='a', x=1), struct=Row(a=2, x=None, x=None, y=4, y='c', x=None)), Row(struct=Row(x='w', x=6), struct=Row(a=7, x=None, x=None, y=9, y='y', x=None))]
```
It should be:
```py
>>> df.collect()
[Row(struct=Row(x='a', x=1), struct=Row(a=2, x=3, x='b', y=4, y='c', x='d')), Row(struct=Row(x='w', x=6), struct=Row(a=7, x=8, x='x', y=9, y='y', x='z'))]
```
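The root cause is visible with plain `itertools` (nothing Spark-specific): the previous rule grouped the names without sorting, so the trailing duplicated `x` formed a second, size-1 group whose identity renamer overwrote the dedup renamer for `x`, leaving the duplicated columns un-renamed.
```py
>>> import itertools
>>> names = ["a", "x", "x", "y", "y", "x"]
>>> [(k, len(list(g))) for k, g in itertools.groupby(names)]
[('a', 1), ('x', 2), ('y', 2), ('x', 1)]
>>> [(k, len(list(g))) for k, g in itertools.groupby(sorted(names))]
[('a', 1), ('x', 3), ('y', 2)]
```
`_dedup_names` now sorts the names before grouping, so every duplicated name gets a unique positional suffix.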
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Updated the related test.
Closes #40888 from ueshin/issues/SPARK-43055/fix_dedup.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/conversion.py | 102 +++++++++++++++++------------
python/pyspark/sql/tests/test_dataframe.py | 8 ++-
2 files changed, 65 insertions(+), 45 deletions(-)
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index a6fe0c00e09..16679e80205 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -40,7 +40,6 @@ from pyspark.sql.types import (
     DecimalType,
     StringType,
     UserDefinedType,
-    cast,
 )
 
 from pyspark.storagelevel import StorageLevel
@@ -50,7 +49,6 @@ import pyspark.sql.connect.proto as pb2
 from typing import (
     Any,
     Callable,
-    Dict,
     Sequence,
     List,
 )
@@ -104,6 +102,7 @@ class LocalDataToArrowConversion:
 
         elif isinstance(dataType, StructType):
             field_names = dataType.fieldNames()
+            dedup_field_names = _dedup_names(dataType.names)
 
             field_convs = [
                 LocalDataToArrowConversion._create_converter(field.dataType)
@@ -123,7 +122,7 @@ class LocalDataToArrowConversion:
                         value = value.__dict__
                     if isinstance(value, dict):
                         for i, field in enumerate(field_names):
-                            _dict[f"col_{i}"] = field_convs[i](value.get(field))
+                            _dict[dedup_field_names[i]] = field_convs[i](value.get(field))
                     else:
                         if len(value) != len(field_names):
                             raise ValueError(
@@ -131,7 +130,7 @@ class LocalDataToArrowConversion:
                                 f"new values have {len(value)} elements"
                             )
                         for i in range(len(field_names)):
-                            _dict[f"col_{i}"] = field_convs[i](value[i])
+                            _dict[dedup_field_names[i]] = field_convs[i](value[i])
 
                     return _dict
@@ -290,26 +289,16 @@ class LocalDataToArrowConversion:
             for i in range(len(column_names)):
                 pylist[i].append(column_convs[i](item[i]))
 
-        def normalize(dt: DataType) -> DataType:
-            if isinstance(dt, StructType):
-                return StructType(
-                    [
-                        StructField(f"col_{i}", normalize(field.dataType), nullable=field.nullable)
-                        for i, field in enumerate(dt.fields)
-                    ]
-                )
-            elif isinstance(dt, ArrayType):
-                return ArrayType(normalize(dt.elementType), containsNull=dt.containsNull)
-            elif isinstance(dt, MapType):
-                return MapType(
-                    normalize(dt.keyType),
-                    normalize(dt.valueType),
-                    valueContainsNull=dt.valueContainsNull,
-                )
-            else:
-                return dt
-
-        pa_schema = to_arrow_schema(cast(StructType, normalize(schema)))
+        pa_schema = to_arrow_schema(
+            StructType(
+                [
+                    StructField(
+                        field.name, _deduplicate_field_names(field.dataType), field.nullable
+                    )
+                    for field in schema.fields
+                ]
+            )
+        )
 
         return pa.Table.from_arrays(pylist, schema=pa_schema)
@@ -355,25 +344,7 @@ class ArrowTableToRowsConversion:
 
         elif isinstance(dataType, StructType):
             field_names = dataType.names
-
-            if len(set(field_names)) == len(field_names):
-                dedup_field_names = field_names
-            else:
-                gen_new_name: Dict[str, Callable[[], str]] = {}
-                for name, group in itertools.groupby(dataType.names):
-                    if len(list(group)) > 1:
-
-                        def _gen(_name: str) -> Callable[[], str]:
-                            _i = itertools.count()
-                            return lambda: f"{_name}_{next(_i)}"
-
-                    else:
-
-                        def _gen(_name: str) -> Callable[[], str]:
-                            return lambda: _name
-
-                    gen_new_name[name] = _gen(name)
-                dedup_field_names = [gen_new_name[name]() for name in dataType.names]
+            dedup_field_names = _dedup_names(field_names)
 
             field_convs = [
                 ArrowTableToRowsConversion._create_converter(f.dataType) for f in dataType.fields
@@ -510,3 +481,48 @@ def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel:
         deserialized=storage_level.deserialized,
         replication=storage_level.replication,
     )
+
+
+def _deduplicate_field_names(dt: DataType) -> DataType:
+    if isinstance(dt, StructType):
+        dedup_field_names = _dedup_names(dt.names)
+
+        return StructType(
+            [
+                StructField(
+                    dedup_field_names[i],
+                    _deduplicate_field_names(field.dataType),
+                    nullable=field.nullable,
+                )
+                for i, field in enumerate(dt.fields)
+            ]
+        )
+    elif isinstance(dt, ArrayType):
+        return ArrayType(_deduplicate_field_names(dt.elementType), containsNull=dt.containsNull)
+    elif isinstance(dt, MapType):
+        return MapType(
+            _deduplicate_field_names(dt.keyType),
+            _deduplicate_field_names(dt.valueType),
+            valueContainsNull=dt.valueContainsNull,
+        )
+    else:
+        return dt
+
+
+def _dedup_names(names: List[str]) -> List[str]:
+    if len(set(names)) == len(names):
+        return names
+    else:
+
+        def _gen_dedup(_name: str) -> Callable[[], str]:
+            _i = itertools.count()
+            return lambda: f"{_name}_{next(_i)}"
+
+        def _gen_identity(_name: str) -> Callable[[], str]:
+            return lambda: _name
+
+        gen_new_name = {
+            name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name)
+            for name, group in itertools.groupby(sorted(names))
+        }
+        return [gen_new_name[name]() for name in names]
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 5716acbaabc..164b6a22a69 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1710,7 +1710,10 @@ class DataFrameTestsMixin:
         )
 
     def test_duplicate_field_names(self):
-        data = [Row(Row("a", 1), Row(2, 3, "b", 4, "c")), Row(Row("x", 6), Row(7, 8, "y", 9, "z"))]
+        data = [
+            Row(Row("a", 1), Row(2, 3, "b", 4, "c", "d")),
+            Row(Row("w", 6), Row(7, 8, "x", 9, "y", "z")),
+        ]
         schema = (
             StructType()
             .add("struct", StructType().add("x", StringType()).add("x", IntegerType()))
@@ -1721,7 +1724,8 @@ class DataFrameTestsMixin:
                 .add("x", IntegerType())
                 .add("x", StringType())
                 .add("y", IntegerType())
-                .add("y", StringType()),
+                .add("y", StringType())
+                .add("x", StringType()),
             )
         )
         df = self.spark.createDataFrame(data, schema=schema)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]