This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 347f9c64645b [SPARK-48302][PYTHON] Preserve nulls in map columns in PyArrow Tables
347f9c64645b is described below

commit 347f9c64645b2f25816442f3e1191c97e0940537
Author: Ian Cook <[email protected]>
AuthorDate: Sun Jun 16 14:12:03 2024 +0900

    [SPARK-48302][PYTHON] Preserve nulls in map columns in PyArrow Tables
    
    ### What changes were proposed in this pull request?
    This is a small follow-up to #46529. It fixes a known issue affecting PyArrow Tables passed to `spark.createDataFrame()`. After this PR, if the user is running PyArrow 17.0.0 or higher, null values in MapArray columns containing nested fields or timestamps will be preserved.
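
    As an illustrative sketch (not part of this patch), the `mask` argument that PyArrow 17.0.0 added to `pa.MapArray.from_arrays` is what makes this possible; `a` below is a hypothetical MapArray with one null slot:

    ```python
    import pyarrow as pa

    # Hypothetical MapArray with one null slot.
    a = pa.array([[("k", 1)], None], type=pa.map_(pa.string(), pa.int64()))

    # Rebuild the array from its parts. Without mask= the null slot would
    # come back as an empty map; with mask= (PyArrow >= 17.0.0) it stays null.
    rebuilt = pa.MapArray.from_arrays(a.offsets, a.keys, a.items, mask=a.is_null())
    assert rebuilt.null_count == 1
    ```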
    
    ### Why are the changes needed?
    Before this PR, null values in MapArray columns containing nested fields or timestamps were replaced by empty lists when a PyArrow Table was passed to `spark.createDataFrame()`.
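
    For illustration only (not part of this patch), a minimal repro sketch mirroring the included test, assuming a local SparkSession:

    ```python
    import datetime
    import pyarrow as pa
    from pyspark.sql import SparkSession

    # Hypothetical local session for the repro.
    spark = SparkSession.builder.master("local[1]").getOrCreate()

    tbl = pa.table(
        [[{"ts": datetime.datetime(2023, 1, 1, 8, 0, 0)}, None]],
        schema=pa.schema([("map", pa.map_(pa.string(), pa.timestamp("us", tz="UTC")))]),
    )
    df = spark.createDataFrame(tbl)

    # Before this PR (or on PyArrow < 17.0.0), the null map row comes back
    # as an empty map instead of null.
    print(df.toArrow().column("map").to_pylist())
    ```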
    
    ### Does this PR introduce _any_ user-facing change?
    Yes: it prevents the loss of nulls in the case described above. There are no other user-facing changes.
    
    ### How was this patch tested?
    A test is included.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46837 from ianmcook/SPARK-48302.
    
    Authored-by: Ian Cook <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/pandas/types.py     | 20 ++++++++++++++------
 python/pyspark/sql/tests/test_arrow.py | 18 +++++++++++++++---
 2 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 2475984fbc39..27c77c9d2d7f 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -397,12 +397,20 @@ def _check_arrow_array_timestamps_localize(
             return a
 
         mt: MapType = cast(MapType, dt)
-        # TODO(SPARK-48302): Do not replace nulls in MapArray with empty lists
-        return pa.MapArray.from_arrays(
-            a.offsets,
-            _check_arrow_array_timestamps_localize(a.keys, mt.keyType, truncate, timezone),
-            _check_arrow_array_timestamps_localize(a.items, mt.valueType, truncate, timezone),
-        )
+
+        params = {
+            "offsets": a.offsets,
+            "keys": _check_arrow_array_timestamps_localize(a.keys, mt.keyType, 
truncate, timezone),
+            "items": _check_arrow_array_timestamps_localize(
+                a.items, mt.valueType, truncate, timezone
+            ),
+        }
+        # SPARK-48302: PyArrow added support for mask argument to pa.MapArray.from_arrays in
+        # version 17.0.0
+        if a.null_count and LooseVersion(pa.__version__) >= LooseVersion("17.0.0"):
+            params["mask"] = a.is_null()
+
+        return pa.MapArray.from_arrays(**params)
     if types.is_struct(a.type):
        # Return the StructArray as-is if it contains no nested fields or timestamps
         if all(
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 4e2e2c51b4db..c1a69c404086 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -55,6 +55,7 @@ from pyspark.testing.sqlutils import (
     ExamplePointUDT,
 )
from pyspark.errors import ArithmeticException, PySparkTypeError, UnsupportedOperationException
+from pyspark.loose_version import LooseVersion
 from pyspark.util import is_remote_only
 
@@ -1556,15 +1557,26 @@ class ArrowTestsMixin:
 
         self.assertTrue(t.equals(expected))
 
-    @unittest.skip("SPARK-48302: Nulls are replaced with empty lists")
     def test_arrow_map_timestamp_nulls_round_trip(self):
+        origin_schema = pa.schema([("map", pa.map_(pa.string(), pa.timestamp("us", tz="UTC")))])
         origin = pa.table(
             [[dict(ts=datetime.datetime(2023, 1, 1, 8, 0, 0)), None]],
-            schema=pa.schema([("map", pa.map_(pa.string(), pa.timestamp("us", tz="UTC")))]),
+            schema=origin_schema,
         )
         df = self.spark.createDataFrame(origin)
         t = df.toArrow()
-        self.assertTrue(origin.equals(t))
+
+        # SPARK-48302: PyArrow versions before 17.0.0 replaced nulls with empty lists when
+        # reconstructing MapArray columns to localize timestamps
+        if LooseVersion(pa.__version__) >= LooseVersion("17.0.0"):
+            expected = origin
+        else:
+            expected = pa.table(
+                [[dict(ts=datetime.datetime(2023, 1, 1, 8, 0, 0)), []]],
+                schema=origin_schema,
+            )
+
+        self.assertTrue(t.equals(expected))
 
     def test_createDataFrame_udt(self):
         for arrow_enabled in [True, False]:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
