This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8549e725292e [SPARK-55088][PYTHON] Keep the metadata in to/from_arrow_type/schema
8549e725292e is described below

commit 8549e725292ea02833e0d24f0cd97af576ee1b4f
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 19 16:52:09 2026 +0800

    [SPARK-55088][PYTHON] Keep the metadata in to/from_arrow_type/schema
    
    ### What changes were proposed in this pull request?
    Keep the metadata in to/from_arrow_type/schema
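
    For illustration only, a minimal sketch of the round trip this enables
    (a hypothetical example assuming a build that includes this patch; the
    "source" metadata entry is made up):

    ```python
    from pyspark.sql.types import StructType, StructField, StringType
    from pyspark.sql.pandas.types import to_arrow_schema, from_arrow_schema

    schema = StructType([StructField("s", StringType(), True, {"source": "raw"})])
    arrow_schema = to_arrow_schema(schema, timezone="UTC")
    # The field metadata now survives the Arrow round trip:
    assert from_arrow_schema(arrow_schema) == schema
    ```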
    
    ### Why are the changes needed?
    To make it possible to use the metadata in the future.
    
    ### Does this PR introduce _any_ user-facing change?
    No, it is an internal improvement.
    
    ### How was this patch tested?
    Updated tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #53848 from zhengruifeng/py_arrow_meta.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/pandas/types.py           | 25 +++++++++++++++++-
 python/pyspark/sql/tests/arrow/test_arrow.py | 39 +++++++++++++++++++++-------
 2 files changed, 53 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index e7a208f200ac..5ccbd37f0d24 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -22,8 +22,9 @@ pandas instances during the type conversion.
 import datetime
 import itertools
 import functools
+import json
 from decimal import Decimal
-from typing import Any, Callable, Iterable, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING
 
 from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
 from pyspark.sql.types import (
@@ -69,6 +70,24 @@ if TYPE_CHECKING:
     from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
 
 
+# Should keep in line with org.apache.spark.sql.util.ArrowUtils.metadataKey
+metadata_key = b"SPARK::metadata::json"
+
+
+def to_arrow_metadata(metadata: Optional[Dict[str, Any]] = None) -> Optional[Dict[bytes, bytes]]:
+    if metadata is not None and len(metadata) > 0:
+        return {metadata_key: json.dumps(metadata).encode("utf-8")}
+    else:
+        return None
+
+
+def from_arrow_metadata(metadata: Optional[Dict[bytes, bytes]]) -> Optional[Dict[str, Any]]:
+    if metadata is not None and metadata_key in metadata:
+        return json.loads(metadata[metadata_key].decode("utf-8"))
+    else:
+        return None
+
+
 def to_arrow_type(
     dt: DataType,
     *,
@@ -177,6 +196,7 @@ def to_arrow_type(
                     prefers_large_types=prefers_large_types,
                 ),
                 nullable=field.nullable,
+                metadata=to_arrow_metadata(field.metadata),
             )
             for field in dt
         ]
@@ -264,6 +284,7 @@ def to_arrow_schema(
                 prefers_large_types=prefers_large_types,
             ),
             nullable=field.nullable,
+            metadata=to_arrow_metadata(field.metadata),
         )
         for field in schema
     ]
@@ -427,6 +448,7 @@ def from_arrow_type(
                     field.name,
                     from_arrow_type(field.type, prefer_timestamp_ntz),
                     nullable=field.nullable,
+                    metadata=from_arrow_metadata(field.metadata),
                 )
                 for field in at
             ]
@@ -454,6 +476,7 @@ def from_arrow_schema(
                 field.name,
                 from_arrow_type(field.type, prefer_timestamp_ntz),
                 nullable=field.nullable,
+                metadata=from_arrow_metadata(field.metadata),
             )
             for field in arrow_schema
         ]
diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py b/python/pyspark/sql/tests/arrow/test_arrow.py
index 922fbef96215..3a1aa9a883b7 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow.py
@@ -49,6 +49,12 @@ from pyspark.sql.types import (
     DayTimeIntervalType,
     VariantType,
 )
+from pyspark.sql.pandas.types import (
+    from_arrow_type,
+    to_arrow_type,
+    from_arrow_schema,
+    to_arrow_schema,
+)
 from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
@@ -754,8 +760,6 @@ class ArrowTestsMixin:
         self.assertTrue(t_out.equals(expected))
 
     def test_schema_conversion_roundtrip(self):
-        from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
-
         arrow_schema = to_arrow_schema(self.schema, timezone="UTC", prefers_large_types=False)
         schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
         self.assertEqual(self.schema, schema_rt)
@@ -765,8 +769,6 @@ class ArrowTestsMixin:
         self.assertEqual(self.schema, schema_rt)
 
     def test_type_conversion_round_trip(self):
-        from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
-
         for t in [
             NullType(),
             BinaryType(),
@@ -795,24 +797,32 @@ class ArrowTestsMixin:
             VariantType(),
             StructType(
                 [
-                    StructField("1_str_t", StringType(), True),
+                    StructField(
+                        "1_str_t", StringType(), True, {"is_int": False, "is_float": "false"}
+                    ),
                     StructField("2_int_t", IntegerType(), True),
                     StructField("3_long_t", LongType(), True),
                     StructField("4_float_t", FloatType(), True),
-                    StructField("5_double_t", DoubleType(), True),
+                    StructField(
+                        "5_double_t", DoubleType(), True, {"is_int": False, "is_float": "true"}
+                    ),
                     StructField("6_decimal_t", DecimalType(38, 18), True),
                     StructField("7_date_t", DateType(), True),
-                    StructField("8_timestamp_t", TimestampType(), True),
+                    StructField(
+                        "8_timestamp_t", TimestampType(), True, {"is_ts": True, "ntz": "false"}
+                    ),
                     StructField("9_binary_t", BinaryType(), True),
-                    StructField("10_var", VariantType(), True),
+                    StructField("10_var", VariantType(), True, {"is_ts": False, "is_var": "true"}),
                     StructField("11_arr", ArrayType(ArrayType(StringType(), 
True), False), True),
                     StructField("12_map", MapType(StringType(), IntegerType(), 
True), True),
                     StructField(
                         "13_struct",
                         StructType(
                             [
-                                StructField("13_1_str_t", StringType(), True),
-                                StructField("13_2_int_t", IntegerType(), True),
+                                StructField(
+                                    "13_1_str_t", StringType(), True, {"in_nested": "true"}
+                                ),
+                                StructField("13_2_int_t", IntegerType(), True, {"in_nested": None}),
                                 StructField("13_3_long_t", LongType(), True),
                             ]
                         ),
@@ -830,6 +840,15 @@ class ArrowTestsMixin:
                 t3 = from_arrow_type(at2)
                 self.assertEqual(t, t3)
 
+                if isinstance(t, StructType):
+                    pa_schema = to_arrow_schema(t, timezone="UTC")
+                    schema2 = from_arrow_schema(pa_schema)
+                    self.assertEqual(t, schema2)
+
+                    pa_schema2 = to_arrow_schema(t, timezone="UTC", prefers_large_types=True)
+                    schema3 = from_arrow_schema(pa_schema2)
+                    self.assertEqual(t, schema3)
+
     def test_createDataFrame_with_ndarray(self):
         for arrow_enabled in [True, False]:
             with self.subTest(arrow_enabled=arrow_enabled):

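For reference, a standalone sketch of the JSON encoding the new
to_arrow_metadata/from_arrow_metadata helpers perform (plain Python, no Spark
required; the sample dict mirrors the test fixtures above):

    import json

    # Key must stay in line with org.apache.spark.sql.util.ArrowUtils.metadataKey.
    metadata_key = b"SPARK::metadata::json"

    field_metadata = {"is_int": False, "is_float": "false"}
    encoded = {metadata_key: json.dumps(field_metadata).encode("utf-8")}
    decoded = json.loads(encoded[metadata_key].decode("utf-8"))
    assert decoded == field_metadata  # lossless JSON round trip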
