This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new f33a13c4b16 [SPARK-44980][PYTHON][CONNECT] Fix inherited namedtuples to work in createDataFrame
f33a13c4b16 is described below

commit f33a13c4b165e4ae5099703c308a2715463a479a
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Mon Aug 28 15:46:57 2023 +0900

    [SPARK-44980][PYTHON][CONNECT] Fix inherited namedtuples to work in createDataFrame
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a bug in createDataFrame with the Python Spark Connect client. It now respects inherited namedtuples, as shown below:
    
    ```python
    from collections import namedtuple
    MyTuple = namedtuple("MyTuple", ["zz", "b", "a"])
    
    class MyInheritedTuple(MyTuple):
        pass
    
    df = spark.createDataFrame([MyInheritedTuple(1, 2, 3), MyInheritedTuple(11, 22, 33)])
    df.collect()
    ```
    
    Before:
    
    ```
    [Row(zz=None, b=None, a=None), Row(zz=None, b=None, a=None)]
    ```
    
    After:
    
    ```
    [Row(zz=1, b=2, a=3), Row(zz=11, b=22, a=33)]
    ```
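    
    For context, a minimal sketch of the root cause, assuming CPython's namedtuple semantics: a plain namedtuple declares `__slots__ = ()`, so its instances carry no `__dict__`, but a subclass that omits `__slots__` gives each instance an empty `__dict__`. The old conversion path saw that attribute, swapped the value for the empty dict, and every field came back None:
    
    ```python
    from collections import namedtuple
    
    MyTuple = namedtuple("MyTuple", ["zz", "b", "a"])
    
    class MyInheritedTuple(MyTuple):
        pass
    
    t = MyInheritedTuple(1, 2, 3)
    hasattr(MyTuple(1, 2, 3), "__dict__")  # False: namedtuple defines __slots__ = ()
    hasattr(t, "__dict__")                 # True: the subclass adds an instance dict
    t.__dict__                             # {}, so the old path found no field values
    isinstance(t, tuple)                   # True: the check this PR adds to skip __dict__
    ```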
    
    ### Why are the changes needed?
    
    This is already supported without Spark Connect. We should match the behaviour for consistent API support.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, as described above. It fixes a bug.
    
    ### How was this patch tested?
    
    Manually tested as described above, and unit tests were added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #42693 from HyukjinKwon/SPARK-44980.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 5291c6c9274aaabd4851d70e4c1baad629e12cca)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/conversion.py           | 12 +++++++++--
 .../pyspark/sql/tests/connect/test_parity_arrow.py |  3 +++
 python/pyspark/sql/tests/test_arrow.py             | 24 ++++++++++++++++++++++
 3 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index cdbc3a1e39c..1afeb3dfd44 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -117,7 +117,11 @@ class LocalDataToArrowConversion:
                     ), f"{type(value)} {value}"
 
                     _dict = {}
-                    if not isinstance(value, Row) and hasattr(value, "__dict__"):
+                    if (
+                        not isinstance(value, Row)
+                        and not isinstance(value, tuple)  # inherited namedtuple
+                        and hasattr(value, "__dict__")
+                    ):
                         value = value.__dict__
                     if isinstance(value, dict):
                         for i, field in enumerate(field_names):
@@ -274,7 +278,11 @@ class LocalDataToArrowConversion:
         pylist: List[List] = [[] for _ in range(len(column_names))]
 
         for item in data:
-            if not isinstance(item, Row) and hasattr(item, "__dict__"):
+            if (
+                not isinstance(item, Row)
+                and not isinstance(item, tuple)  # inherited namedtuple
+                and hasattr(item, "__dict__")
+            ):
                 item = item.__dict__
             if isinstance(item, dict):
                 for i, col in enumerate(column_names):
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index 5f76cafb192..a92ef971cd2 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -142,6 +142,9 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase, PandasOnSparkTest
     def test_toPandas_udt(self):
         self.check_toPandas_udt(True)
 
+    def test_create_dataframe_namedtuples(self):
+        self.check_create_dataframe_namedtuples(True)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_parity_arrow import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 1b81ed72b22..73b6067373b 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -23,6 +23,7 @@ import unittest
 import warnings
 from distutils.version import LooseVersion
 from typing import cast
+from collections import namedtuple
 
 from pyspark import SparkContext, SparkConf
 from pyspark.sql import Row, SparkSession
@@ -1214,6 +1215,29 @@ class ArrowTestsMixin:
 
         assert_frame_equal(pdf, expected)
 
+    def test_create_dataframe_namedtuples(self):
+        # SPARK-44980: Inherited namedtuples in createDataFrame
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_create_dataframe_namedtuples(arrow_enabled)
+
+    def check_create_dataframe_namedtuples(self, arrow_enabled):
+        MyTuple = namedtuple("MyTuple", ["a", "b", "c"])
+
+        class MyInheritedTuple(MyTuple):
+            pass
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.pyspark.enabled": arrow_enabled,
+            }
+        ):
+            df = self.spark.createDataFrame([MyInheritedTuple(1, 2, 3)])
+            self.assertEqual(df.first(), Row(a=1, b=2, c=3))
+
+            df = self.spark.createDataFrame([MyInheritedTuple(1, 2, MyInheritedTuple(1, 2, 3))])
+            self.assertEqual(df.first(), Row(a=1, b=2, c=Row(a=1, b=2, c=3)))
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,

