This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8549e725292e [SPARK-55088][PYTHON] Keep the metadata in to/from_arrow_type/schema
8549e725292e is described below
commit 8549e725292ea02833e0d24f0cd97af576ee1b4f
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 19 16:52:09 2026 +0800
[SPARK-55088][PYTHON] Keep the metadata in to/from_arrow_type/schema
### What changes were proposed in this pull request?
Keep the metadata in to/from_arrow_type/schema
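Concretely, the change adds two small helpers that map a Spark field's metadata dict to and from a single Arrow field-metadata entry. A minimal sketch of their behavior (the dict contents are illustrative, and the imports assume a build that includes this change):

```python
from pyspark.sql.pandas.types import from_arrow_metadata, to_arrow_metadata

# A Spark field-metadata dict is JSON-encoded under a single Arrow
# field-metadata key.
encoded = to_arrow_metadata({"comment": "user name"})
assert encoded == {b"SPARK::metadata::json": b'{"comment": "user name"}'}

# Decoding reverses the mapping; empty dicts and missing keys yield None.
assert from_arrow_metadata(encoded) == {"comment": "user name"}
assert to_arrow_metadata({}) is None
assert from_arrow_metadata(None) is None
```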
### Why are the changes needed?
To make it possible to use the metadata in the future
### Does this PR introduce _any_ user-facing change?
No, it is an internal improvement
### How was this patch tested?
Updated tests
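The updated tests exercise round trips like the following condensed sketch (field name and metadata values are illustrative):

```python
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema

schema = StructType(
    [StructField("name", StringType(), True, metadata={"is_name": True})]
)

# The field metadata now survives the conversion to Arrow and back.
arrow_schema = to_arrow_schema(schema, timezone="UTC", prefers_large_types=False)
assert from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True) == schema
```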
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53848 from zhengruifeng/py_arrow_meta.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/pandas/types.py | 25 +++++++++++++++++-
python/pyspark/sql/tests/arrow/test_arrow.py | 39 +++++++++++++++++++++-------
2 files changed, 53 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index e7a208f200ac..5ccbd37f0d24 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -22,8 +22,9 @@ pandas instances during the type conversion.
import datetime
import itertools
import functools
+import json
from decimal import Decimal
-from typing import Any, Callable, Iterable, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING
from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
from pyspark.sql.types import (
@@ -69,6 +70,24 @@ if TYPE_CHECKING:
from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+# Should be kept in sync with org.apache.spark.sql.util.ArrowUtils.metadataKey
+metadata_key = b"SPARK::metadata::json"
+
+
+def to_arrow_metadata(metadata: Optional[Dict[str, Any]] = None) -> Optional[Dict[bytes, bytes]]:
+ if metadata is not None and len(metadata) > 0:
+ return {metadata_key: json.dumps(metadata).encode("utf-8")}
+ else:
+ return None
+
+
+def from_arrow_metadata(metadata: Optional[Dict[bytes, bytes]]) -> Optional[Dict[str, Any]]:
+ if metadata is not None and metadata_key in metadata:
+ return json.loads(metadata[metadata_key].decode("utf-8"))
+ else:
+ return None
+
+
def to_arrow_type(
dt: DataType,
*,
@@ -177,6 +196,7 @@ def to_arrow_type(
prefers_large_types=prefers_large_types,
),
nullable=field.nullable,
+ metadata=to_arrow_metadata(field.metadata),
)
for field in dt
]
@@ -264,6 +284,7 @@ def to_arrow_schema(
prefers_large_types=prefers_large_types,
),
nullable=field.nullable,
+ metadata=to_arrow_metadata(field.metadata),
)
for field in schema
]
@@ -427,6 +448,7 @@ def from_arrow_type(
field.name,
from_arrow_type(field.type, prefer_timestamp_ntz),
nullable=field.nullable,
+ metadata=from_arrow_metadata(field.metadata),
)
for field in at
]
@@ -454,6 +476,7 @@ def from_arrow_schema(
field.name,
from_arrow_type(field.type, prefer_timestamp_ntz),
nullable=field.nullable,
+ metadata=from_arrow_metadata(field.metadata),
)
for field in arrow_schema
]
diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py b/python/pyspark/sql/tests/arrow/test_arrow.py
index 922fbef96215..3a1aa9a883b7 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow.py
@@ -49,6 +49,12 @@ from pyspark.sql.types import (
DayTimeIntervalType,
VariantType,
)
+from pyspark.sql.pandas.types import (
+ from_arrow_type,
+ to_arrow_type,
+ from_arrow_schema,
+ to_arrow_schema,
+)
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -754,8 +760,6 @@ class ArrowTestsMixin:
self.assertTrue(t_out.equals(expected))
def test_schema_conversion_roundtrip(self):
- from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
-
arrow_schema = to_arrow_schema(self.schema, timezone="UTC", prefers_large_types=False)
schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
self.assertEqual(self.schema, schema_rt)
@@ -765,8 +769,6 @@ class ArrowTestsMixin:
self.assertEqual(self.schema, schema_rt)
def test_type_conversion_round_trip(self):
- from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
-
for t in [
NullType(),
BinaryType(),
@@ -795,24 +797,32 @@ class ArrowTestsMixin:
VariantType(),
StructType(
[
- StructField("1_str_t", StringType(), True),
+ StructField(
+ "1_str_t", StringType(), True, {"is_int": False,
"is_float": "false"}
+ ),
StructField("2_int_t", IntegerType(), True),
StructField("3_long_t", LongType(), True),
StructField("4_float_t", FloatType(), True),
- StructField("5_double_t", DoubleType(), True),
+ StructField(
+ "5_double_t", DoubleType(), True, {"is_int": False,
"is_float": "true"}
+ ),
StructField("6_decimal_t", DecimalType(38, 18), True),
StructField("7_date_t", DateType(), True),
- StructField("8_timestamp_t", TimestampType(), True),
+ StructField(
+ "8_timestamp_t", TimestampType(), True, {"is_ts":
True, "ntz": "false"}
+ ),
StructField("9_binary_t", BinaryType(), True),
- StructField("10_var", VariantType(), True),
+ StructField("10_var", VariantType(), True, {"is_ts":
False, "is_var": "true"}),
StructField("11_arr", ArrayType(ArrayType(StringType(),
True), False), True),
StructField("12_map", MapType(StringType(), IntegerType(),
True), True),
StructField(
"13_struct",
StructType(
[
- StructField("13_1_str_t", StringType(), True),
- StructField("13_2_int_t", IntegerType(), True),
+ StructField(
+ "13_1_str_t", StringType(), True,
{"in_nested": "true"}
+ ),
+ StructField("13_2_int_t", IntegerType(), True,
{"in_nested": None}),
StructField("13_3_long_t", LongType(), True),
]
),
@@ -830,6 +840,15 @@ class ArrowTestsMixin:
t3 = from_arrow_type(at2)
self.assertEqual(t, t3)
+ if isinstance(t, StructType):
+ pa_schema = to_arrow_schema(t, timezone="UTC")
+ schema2 = from_arrow_schema(pa_schema)
+ self.assertEqual(t, schema2)
+
+ pa_schema2 = to_arrow_schema(t, timezone="UTC", prefers_large_types=True)
+ schema3 = from_arrow_schema(pa_schema2)
+ self.assertEqual(t, schema3)
+
def test_createDataFrame_with_ndarray(self):
for arrow_enabled in [True, False]:
with self.subTest(arrow_enabled=arrow_enabled):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]