This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d7f69e7003a3 [SPARK-48190][PYTHON][PS][TESTS] Introduce a helper function to drop metadata
d7f69e7003a3 is described below

commit d7f69e7003a3c7e7ad22a39e6aaacd183d26d326
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 8 18:48:21 2024 +0800

    [SPARK-48190][PYTHON][PS][TESTS] Introduce a helper function to drop metadata
    
    ### What changes were proposed in this pull request?
    Introduce a helper function to drop metadata
    
    ### Why are the changes needed?
    The existing helper function `remove_metadata` in PS doesn't support nested types, so it cannot be reused in other places.
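    
    For illustration only (hypothetical snippet, not part of this patch): the
    old helper rebuilt just the top-level StructField, so metadata attached to
    nested fields survived:
    
        from pyspark.sql.types import StringType, StructField, StructType
    
        inner = StructField("v", StringType(), True, {"__autoGeneratedAlias": "true"})
        outer = StructField("s", StructType([inner]), True)
    
        # Old-style strip: only the outer field is rebuilt, so the nested
        # field's metadata is untouched and schema comparisons can still fail.
        stripped = StructField(outer.name, outer.dataType, outer.nullable)
        assert stripped.dataType.fields[0].metadata == {"__autoGeneratedAlias": "true"}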
    
    ### Does this PR introduce _any_ user-facing change?
    no, test only
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #46466 from zhengruifeng/py_drop_meta.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/pandas/internal.py                       | 17 +++--------------
 .../pyspark/sql/tests/connect/test_connect_function.py  | 11 +++++++++--
 python/pyspark/sql/types.py                             | 13 +++++++++++++
 3 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 767ec9a57f9b..8ab8d79d5686 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -33,6 +33,7 @@ from pyspark.sql import (
     Window,
 )
 from pyspark.sql.types import (  # noqa: F401
+    _drop_metadata,
     BooleanType,
     DataType,
     LongType,
@@ -761,14 +762,8 @@ class InternalFrame:
                 # in a few tests when using Spark Connect. However, the function works properly.
                 # Therefore, we temporarily perform Spark Connect tests by excluding metadata
                 # until the issue is resolved.
-                def remove_metadata(struct_field: StructField) -> StructField:
-                    new_struct_field = StructField(
-                        struct_field.name, struct_field.dataType, struct_field.nullable
-                    )
-                    return new_struct_field
-
                 assert all(
-                    remove_metadata(index_field.struct_field) == remove_metadata(struct_field)
+                    _drop_metadata(index_field.struct_field) == _drop_metadata(struct_field)
                     for index_field, struct_field in zip(index_fields, struct_fields)
                 ), (index_fields, struct_fields)
             else:
@@ -795,14 +790,8 @@ class InternalFrame:
                 # in a few tests when using Spark Connect. However, the function works properly.
                 # Therefore, we temporarily perform Spark Connect tests by excluding metadata
                 # until the issue is resolved.
-                def remove_metadata(struct_field: StructField) -> StructField:
-                    new_struct_field = StructField(
-                        struct_field.name, struct_field.dataType, struct_field.nullable
-                    )
-                    return new_struct_field
-
                 assert all(
-                    remove_metadata(data_field.struct_field) == remove_metadata(struct_field)
+                    _drop_metadata(data_field.struct_field) == _drop_metadata(struct_field)
                     for data_field, struct_field in zip(data_fields, struct_fields)
                 ), (data_fields, struct_fields)
             else:
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 9d4db8cf7d15..0f0abfd4b856 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -21,7 +21,14 @@ from inspect import getmembers, isfunction
 from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql import SparkSession as PySparkSession
-from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType
+from pyspark.sql.types import (
+    _drop_metadata,
+    StringType,
+    StructType,
+    StructField,
+    ArrayType,
+    IntegerType,
+)
 from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
@@ -1668,7 +1675,7 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
         )
 
         # TODO: 'cdf.schema' has an extra metadata '{'__autoGeneratedAlias': 'true'}'
-        # self.assertEqual(cdf.schema, sdf.schema)
+        self.assertEqual(_drop_metadata(cdf.schema), _drop_metadata(sdf.schema))
         self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_csv_functions(self):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 48aa3e8e4fab..41be12620fd5 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1569,6 +1569,19 @@ _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour
 _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
 
 
+def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]:
+    assert isinstance(d, (DataType, StructField))
+    if isinstance(d, StructField):
+        return StructField(d.name, _drop_metadata(d.dataType), d.nullable, None)
+    elif isinstance(d, StructType):
+        return StructType([cast(StructField, _drop_metadata(f)) for f in d.fields])
+    elif isinstance(d, ArrayType):
+        return ArrayType(_drop_metadata(d.elementType), d.containsNull)
+    elif isinstance(d, MapType):
+    return MapType(_drop_metadata(d.keyType), _drop_metadata(d.valueType), d.valueContainsNull)
+    return d
+
+
 def _parse_datatype_string(s: str) -> DataType:
     """
     Parses the given data type string to a :class:`DataType`. The data type string format equals

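For reference, a minimal usage sketch of the new helper on a hypothetical
nested schema (illustrative only, not part of the commit): `_drop_metadata`
rebuilds the type tree and strips metadata at every level, so two schemas
that differ only in metadata compare equal.

    from pyspark.sql.types import (
        ArrayType,
        StringType,
        StructField,
        StructType,
        _drop_metadata,
    )

    schema = StructType(
        [
            StructField(
                "tags",
                ArrayType(
                    StructType(
                        [StructField("k", StringType(), True, {"__autoGeneratedAlias": "true"})]
                    )
                ),
                True,
            )
        ]
    )

    # The helper recurses through struct, array and map types, so the
    # metadata on the deeply nested field is gone after the round trip.
    cleaned = _drop_metadata(schema)
    assert cleaned.fields[0].dataType.elementType.fields[0].metadata == {}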

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
