This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new f33a13c4b16 [SPARK-44980][PYTHON][CONNECT] Fix inherited namedtuples
to work in createDataFrame
f33a13c4b16 is described below
commit f33a13c4b165e4ae5099703c308a2715463a479a
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Mon Aug 28 15:46:57 2023 +0900
[SPARK-44980][PYTHON][CONNECT] Fix inherited namedtuples to work in
createDataFrame
### What changes were proposed in this pull request?
This PR fixes the bug in createDataFrame with the Python Spark Connect client.
Now it respects inherited namedtuples as below:
```python
from collections import namedtuple
MyTuple = namedtuple("MyTuple", ["zz", "b", "a"])
class MyInheritedTuple(MyTuple):
pass
df = spark.createDataFrame([MyInheritedTuple(1, 2, 3), MyInheritedTuple(11,
22, 33)])
df.collect()
```
Before:
```
[Row(zz=None, b=None, a=None), Row(zz=None, b=None, a=None)]
```
After:
```
[Row(zz=1, b=2, a=3), Row(zz=11, b=22, a=33)]
```
### Why are the changes needed?
This is already supported without Spark Connect. We should match the
behaviour for consistent API support.
### Does this PR introduce _any_ user-facing change?
Yes, as described above. It fixes a bug.
### How was this patch tested?
Manually tested as described above, and unittests were added.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #42693 from HyukjinKwon/SPARK-44980.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 5291c6c9274aaabd4851d70e4c1baad629e12cca)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/conversion.py | 12 +++++++++--
.../pyspark/sql/tests/connect/test_parity_arrow.py | 3 +++
python/pyspark/sql/tests/test_arrow.py | 24 ++++++++++++++++++++++
3 files changed, 37 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/connect/conversion.py
b/python/pyspark/sql/connect/conversion.py
index cdbc3a1e39c..1afeb3dfd44 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -117,7 +117,11 @@ class LocalDataToArrowConversion:
), f"{type(value)} {value}"
_dict = {}
- if not isinstance(value, Row) and hasattr(value,
"__dict__"):
+ if (
+ not isinstance(value, Row)
+ and not isinstance(value, tuple) # inherited
namedtuple
+ and hasattr(value, "__dict__")
+ ):
value = value.__dict__
if isinstance(value, dict):
for i, field in enumerate(field_names):
@@ -274,7 +278,11 @@ class LocalDataToArrowConversion:
pylist: List[List] = [[] for _ in range(len(column_names))]
for item in data:
- if not isinstance(item, Row) and hasattr(item, "__dict__"):
+ if (
+ not isinstance(item, Row)
+ and not isinstance(item, tuple) # inherited namedtuple
+ and hasattr(item, "__dict__")
+ ):
item = item.__dict__
if isinstance(item, dict):
for i, col in enumerate(column_names):
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py
b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index 5f76cafb192..a92ef971cd2 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -142,6 +142,9 @@ class ArrowParityTests(ArrowTestsMixin,
ReusedConnectTestCase, PandasOnSparkTest
def test_toPandas_udt(self):
self.check_toPandas_udt(True)
+ def test_create_dataframe_namedtuples(self):
+ self.check_create_dataframe_namedtuples(True)
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_arrow import * # noqa: F401
diff --git a/python/pyspark/sql/tests/test_arrow.py
b/python/pyspark/sql/tests/test_arrow.py
index 1b81ed72b22..73b6067373b 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -23,6 +23,7 @@ import unittest
import warnings
from distutils.version import LooseVersion
from typing import cast
+from collections import namedtuple
from pyspark import SparkContext, SparkConf
from pyspark.sql import Row, SparkSession
@@ -1214,6 +1215,29 @@ class ArrowTestsMixin:
assert_frame_equal(pdf, expected)
+ def test_create_dataframe_namedtuples(self):
+ # SPARK-44980: Inherited namedtuples in createDataFrame
+ for arrow_enabled in [True, False]:
+ with self.subTest(arrow_enabled=arrow_enabled):
+ self.check_create_dataframe_namedtuples(arrow_enabled)
+
+ def check_create_dataframe_namedtuples(self, arrow_enabled):
+ MyTuple = namedtuple("MyTuple", ["a", "b", "c"])
+
+ class MyInheritedTuple(MyTuple):
+ pass
+
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.pyspark.enabled": arrow_enabled,
+ }
+ ):
+ df = self.spark.createDataFrame([MyInheritedTuple(1, 2, 3)])
+ self.assertEqual(df.first(), Row(a=1, b=2, c=3))
+
+ df = self.spark.createDataFrame([MyInheritedTuple(1, 2,
MyInheritedTuple(1, 2, 3))])
+ self.assertEqual(df.first(), Row(a=1, b=2, c=Row(a=1, b=2, c=3)))
+
@unittest.skipIf(
not have_pandas or not have_pyarrow,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]