This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 347f9c64645b [SPARK-48302][PYTHON] Preserve nulls in map columns in PyArrow Tables
347f9c64645b is described below
commit 347f9c64645b2f25816442f3e1191c97e0940537
Author: Ian Cook <[email protected]>
AuthorDate: Sun Jun 16 14:12:03 2024 +0900
[SPARK-48302][PYTHON] Preserve nulls in map columns in PyArrow Tables
### What changes were proposed in this pull request?
This is a small follow-up to #46529. It fixes a known issue affecting
PyArrow Tables passed to `spark.createDataFrame()`. After this PR, if the user
is running PyArrow 17.0.0 or higher, null values in MapArray columns containing
nested fields or timestamps will be preserved.
### Why are the changes needed?
Before this PR, null values in MapArray columns containing nested fields or
timestamps were replaced by empty lists when a PyArrow Table was passed to
`spark.createDataFrame()`.
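For illustration, a minimal round-trip sketch of the affected case (assuming a
Spark version where `DataFrame.toArrow()` from #46529 is available, and an
active SparkSession):

```python
import datetime

import pyarrow as pa
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# A map<string, timestamp> column with one non-null value and one null.
schema = pa.schema([("map", pa.map_(pa.string(), pa.timestamp("us", tz="UTC")))])
table = pa.table([[dict(ts=datetime.datetime(2023, 1, 1, 8, 0, 0)), None]], schema=schema)

df = spark.createDataFrame(table)

# Before this PR (or with PyArrow < 17.0.0), the null row comes back as an
# empty map; with this PR and PyArrow >= 17.0.0, the null is preserved.
print(df.toArrow().column("map").to_pylist())
```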
### Does this PR introduce _any_ user-facing change?
Yes. It prevents the loss of nulls in the case described above. There are no
other user-facing changes.
### How was this patch tested?
A test is included.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #46837 from ianmcook/SPARK-48302.
Authored-by: Ian Cook <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/pandas/types.py     | 20 ++++++++++++++------
python/pyspark/sql/tests/test_arrow.py | 18 +++++++++++++++---
2 files changed, 29 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 2475984fbc39..27c77c9d2d7f 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -397,12 +397,20 @@ def _check_arrow_array_timestamps_localize(
         ):
             return a
         mt: MapType = cast(MapType, dt)
-        # TODO(SPARK-48302): Do not replace nulls in MapArray with empty lists
-        return pa.MapArray.from_arrays(
-            a.offsets,
-            _check_arrow_array_timestamps_localize(a.keys, mt.keyType, truncate, timezone),
-            _check_arrow_array_timestamps_localize(a.items, mt.valueType, truncate, timezone),
-        )
+
+        params = {
+            "offsets": a.offsets,
+            "keys": _check_arrow_array_timestamps_localize(a.keys, mt.keyType, truncate, timezone),
+            "items": _check_arrow_array_timestamps_localize(
+                a.items, mt.valueType, truncate, timezone
+            ),
+        }
+        # SPARK-48302: PyArrow added support for mask argument to pa.MapArray.from_arrays in
+        # version 17.0.0
+        if a.null_count and LooseVersion(pa.__version__) >= LooseVersion("17.0.0"):
+            params["mask"] = a.is_null()
+
+        return pa.MapArray.from_arrays(**params)
     if types.is_struct(a.type):
         # Return the StructArray as-is if it contains no nested fields or timestamps
         if all(
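For context, the hunk above builds on the `mask` argument that PyArrow 17.0.0
added to `pa.MapArray.from_arrays`. A standalone sketch of why the mask
matters (pure PyArrow, no Spark involved):

```python
import pyarrow as pa

# A map<string, int64> array with a null entry in the middle.
arr = pa.array([[("a", 1)], None, []], type=pa.map_(pa.string(), pa.int64()))

# Rebuilding from the raw parts drops the validity bitmap: the null slot has
# equal consecutive offsets, so it reappears as an empty map.
rebuilt = pa.MapArray.from_arrays(arr.offsets, arr.keys, arr.items)
print(rebuilt.to_pylist())  # [[('a', 1)], [], []]

# With PyArrow >= 17.0.0, a mask marks which slots are null.
rebuilt = pa.MapArray.from_arrays(arr.offsets, arr.keys, arr.items, mask=arr.is_null())
print(rebuilt.to_pylist())  # [[('a', 1)], None, []]
```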
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 4e2e2c51b4db..c1a69c404086 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -55,6 +55,7 @@ from pyspark.testing.sqlutils import (
     ExamplePointUDT,
 )
 from pyspark.errors import ArithmeticException, PySparkTypeError, UnsupportedOperationException
+from pyspark.loose_version import LooseVersion
 from pyspark.util import is_remote_only
@@ -1556,15 +1557,26 @@ class ArrowTestsMixin:
         self.assertTrue(t.equals(expected))
 
-    @unittest.skip("SPARK-48302: Nulls are replaced with empty lists")
     def test_arrow_map_timestamp_nulls_round_trip(self):
+        origin_schema = pa.schema([("map", pa.map_(pa.string(), pa.timestamp("us", tz="UTC")))])
         origin = pa.table(
             [[dict(ts=datetime.datetime(2023, 1, 1, 8, 0, 0)), None]],
-            schema=pa.schema([("map", pa.map_(pa.string(), pa.timestamp("us", tz="UTC")))]),
+            schema=origin_schema,
         )
         df = self.spark.createDataFrame(origin)
         t = df.toArrow()
-        self.assertTrue(origin.equals(t))
+
+        # SPARK-48302: PyArrow versions before 17.0.0 replaced nulls with empty lists when
+        # reconstructing MapArray columns to localize timestamps
+        if LooseVersion(pa.__version__) >= LooseVersion("17.0.0"):
+            expected = origin
+        else:
+            expected = pa.table(
+                [[dict(ts=datetime.datetime(2023, 1, 1, 8, 0, 0)), []]],
+                schema=origin_schema,
+            )
+
+        self.assertTrue(t.equals(expected))
 
     def test_createDataFrame_udt(self):
         for arrow_enabled in [True, False]:
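As a usage note, code that depends on this round-trip behavior can gate on the
installed PyArrow version the same way the patch does. A small, hypothetical
helper (the function name is illustrative, not part of the PySpark API):

```python
import pyarrow as pa
from pyspark.loose_version import LooseVersion

def map_nulls_preserved() -> bool:
    # Hypothetical helper: nulls in map columns survive the
    # createDataFrame()/toArrow() round trip only when PyArrow supports
    # the mask argument to pa.MapArray.from_arrays, i.e. 17.0.0 or higher.
    return LooseVersion(pa.__version__) >= LooseVersion("17.0.0")
```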
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]