This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new f33a13c4b16 [SPARK-44980][PYTHON][CONNECT] Fix inherited namedtuples to work in createDataFrame f33a13c4b16 is described below commit f33a13c4b165e4ae5099703c308a2715463a479a Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Mon Aug 28 15:46:57 2023 +0900 [SPARK-44980][PYTHON][CONNECT] Fix inherited namedtuples to work in createDataFrame ### What changes were proposed in this pull request? This PR fixes the bug in createDataFrame with Python Spark Connect client. Now it respects inherited namedtuples as below: ```python from collections import namedtuple MyTuple = namedtuple("MyTuple", ["zz", "b", "a"]) class MyInheritedTuple(MyTuple): pass df = spark.createDataFrame([MyInheritedTuple(1, 2, 3), MyInheritedTuple(11, 22, 33)]) df.collect() ``` Before: ``` [Row(zz=None, b=None, a=None), Row(zz=None, b=None, a=None)] ``` After: ``` [Row(zz=1, b=2, a=3), Row(zz=11, b=22, a=33)] ``` ### Why are the changes needed? This is already supported without Spark Connect. We should match the behaviour for consistent API support. ### Does this PR introduce _any_ user-facing change? Yes, as described above. It fixes a bug. ### How was this patch tested? Manually tested as described above, and unit tests were added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42693 from HyukjinKwon/SPARK-44980. 
Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit 5291c6c9274aaabd4851d70e4c1baad629e12cca) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/connect/conversion.py | 12 +++++++++-- .../pyspark/sql/tests/connect/test_parity_arrow.py | 3 +++ python/pyspark/sql/tests/test_arrow.py | 24 ++++++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py index cdbc3a1e39c..1afeb3dfd44 100644 --- a/python/pyspark/sql/connect/conversion.py +++ b/python/pyspark/sql/connect/conversion.py @@ -117,7 +117,11 @@ class LocalDataToArrowConversion: ), f"{type(value)} {value}" _dict = {} - if not isinstance(value, Row) and hasattr(value, "__dict__"): + if ( + not isinstance(value, Row) + and not isinstance(value, tuple) # inherited namedtuple + and hasattr(value, "__dict__") + ): value = value.__dict__ if isinstance(value, dict): for i, field in enumerate(field_names): @@ -274,7 +278,11 @@ class LocalDataToArrowConversion: pylist: List[List] = [[] for _ in range(len(column_names))] for item in data: - if not isinstance(item, Row) and hasattr(item, "__dict__"): + if ( + not isinstance(item, Row) + and not isinstance(item, tuple) # inherited namedtuple + and hasattr(item, "__dict__") + ): item = item.__dict__ if isinstance(item, dict): for i, col in enumerate(column_names): diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py index 5f76cafb192..a92ef971cd2 100644 --- a/python/pyspark/sql/tests/connect/test_parity_arrow.py +++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py @@ -142,6 +142,9 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase, PandasOnSparkTest def test_toPandas_udt(self): self.check_toPandas_udt(True) + def test_create_dataframe_namedtuples(self): + 
self.check_create_dataframe_namedtuples(True) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_parity_arrow import * # noqa: F401 diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 1b81ed72b22..73b6067373b 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -23,6 +23,7 @@ import unittest import warnings from distutils.version import LooseVersion from typing import cast +from collections import namedtuple from pyspark import SparkContext, SparkConf from pyspark.sql import Row, SparkSession @@ -1214,6 +1215,29 @@ class ArrowTestsMixin: assert_frame_equal(pdf, expected) + def test_create_dataframe_namedtuples(self): + # SPARK-44980: Inherited namedtuples in createDataFrame + for arrow_enabled in [True, False]: + with self.subTest(arrow_enabled=arrow_enabled): + self.check_create_dataframe_namedtuples(arrow_enabled) + + def check_create_dataframe_namedtuples(self, arrow_enabled): + MyTuple = namedtuple("MyTuple", ["a", "b", "c"]) + + class MyInheritedTuple(MyTuple): + pass + + with self.sql_conf( + { + "spark.sql.execution.arrow.pyspark.enabled": arrow_enabled, + } + ): + df = self.spark.createDataFrame([MyInheritedTuple(1, 2, 3)]) + self.assertEqual(df.first(), Row(a=1, b=2, c=3)) + + df = self.spark.createDataFrame([MyInheritedTuple(1, 2, MyInheritedTuple(1, 2, 3))]) + self.assertEqual(df.first(), Row(a=1, b=2, c=Row(a=1, b=2, c=3))) + @unittest.skipIf( not have_pandas or not have_pyarrow, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org