This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d7f69e7003a3 [SPARK-48190][PYTHON][PS][TESTS] Introduce a helper function to drop metadata
d7f69e7003a3 is described below

commit d7f69e7003a3c7e7ad22a39e6aaacd183d26d326
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 8 18:48:21 2024 +0800

    [SPARK-48190][PYTHON][PS][TESTS] Introduce a helper function to drop metadata

    ### What changes were proposed in this pull request?
    Introduce a helper function to drop metadata.

    ### Why are the changes needed?
    The existing helper function `remove_metadata` in PS doesn't support nested types,
    so it cannot be reused in other places.

    ### Does this PR introduce _any_ user-facing change?
    No, test only.

    ### How was this patch tested?
    CI.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #46466 from zhengruifeng/py_drop_meta.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/pandas/internal.py                      | 17 +++--------------
 .../pyspark/sql/tests/connect/test_connect_function.py | 11 +++++++++--
 python/pyspark/sql/types.py                            | 13 +++++++++++++
 3 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 767ec9a57f9b..8ab8d79d5686 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -33,6 +33,7 @@ from pyspark.sql import (
     Window,
 )
 from pyspark.sql.types import (  # noqa: F401
+    _drop_metadata,
     BooleanType,
     DataType,
     LongType,
@@ -761,14 +762,8 @@ class InternalFrame:
             # in a few tests when using Spark Connect. However, the function works properly.
             # Therefore, we temporarily perform Spark Connect tests by excluding metadata
             # until the issue is resolved.
-            def remove_metadata(struct_field: StructField) -> StructField:
-                new_struct_field = StructField(
-                    struct_field.name, struct_field.dataType, struct_field.nullable
-                )
-                return new_struct_field
-
             assert all(
-                remove_metadata(index_field.struct_field) == remove_metadata(struct_field)
+                _drop_metadata(index_field.struct_field) == _drop_metadata(struct_field)
                 for index_field, struct_field in zip(index_fields, struct_fields)
             ), (index_fields, struct_fields)
         else:
@@ -795,14 +790,8 @@ class InternalFrame:
             # in a few tests when using Spark Connect. However, the function works properly.
             # Therefore, we temporarily perform Spark Connect tests by excluding metadata
             # until the issue is resolved.
-            def remove_metadata(struct_field: StructField) -> StructField:
-                new_struct_field = StructField(
-                    struct_field.name, struct_field.dataType, struct_field.nullable
-                )
-                return new_struct_field
-
             assert all(
-                remove_metadata(data_field.struct_field) == remove_metadata(struct_field)
+                _drop_metadata(data_field.struct_field) == _drop_metadata(struct_field)
                 for data_field, struct_field in zip(data_fields, struct_fields)
             ), (data_fields, struct_fields)
         else:
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 9d4db8cf7d15..0f0abfd4b856 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -21,7 +21,14 @@ from inspect import getmembers, isfunction
 from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql import SparkSession as PySparkSession
-from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType
+from pyspark.sql.types import (
+    _drop_metadata,
+    StringType,
+    StructType,
+    StructField,
+    ArrayType,
+    IntegerType,
+)
 from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
@@ -1668,7 +1675,7 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
         )

         # TODO: 'cdf.schema' has an extra metadata '{'__autoGeneratedAlias': 'true'}'
-        # self.assertEqual(cdf.schema, sdf.schema)
+        self.assertEqual(_drop_metadata(cdf.schema), _drop_metadata(sdf.schema))
         self.assertEqual(cdf.collect(), sdf.collect())

     def test_csv_functions(self):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 48aa3e8e4fab..41be12620fd5 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1569,6 +1569,19 @@ _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour
 _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")


+def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]:
+    assert isinstance(d, (DataType, StructField))
+    if isinstance(d, StructField):
+        return StructField(d.name, _drop_metadata(d.dataType), d.nullable, None)
+    elif isinstance(d, StructType):
+        return StructType([cast(StructField, _drop_metadata(f)) for f in d.fields])
+    elif isinstance(d, ArrayType):
+        return ArrayType(_drop_metadata(d.elementType), d.containsNull)
+    elif isinstance(d, MapType):
+        return MapType(_drop_metadata(d.keyType), _drop_metadata(d.valueType), d.valueContainsNull)
+    return d
+
+
 def _parse_datatype_string(s: str) -> DataType:
     """
     Parses the given data type string to a :class:`DataType`. The data type string format equals
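For illustration only (not part of the patch): a minimal sketch of how the new
helper behaves. `_drop_metadata` is a private helper in `pyspark.sql.types`, and
the schema below is a made-up example; the point is that two schemas differing
only in field metadata, including metadata on nested types, compare equal once
the metadata is dropped.

    from pyspark.sql.types import (
        _drop_metadata,
        ArrayType,
        IntegerType,
        StructField,
        StructType,
    )

    # Two schemas that differ only in the metadata attached to a field.
    with_meta = StructType(
        [
            StructField(
                "values",
                ArrayType(IntegerType()),
                nullable=True,
                metadata={"__autoGeneratedAlias": "true"},
            )
        ]
    )
    without_meta = StructType(
        [StructField("values", ArrayType(IntegerType()), nullable=True)]
    )

    assert with_meta != without_meta  # direct comparison fails on metadata
    assert _drop_metadata(with_meta) == _drop_metadata(without_meta)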