This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 279f5bece4cb [SPARK-48322][SPARK-42965][SQL][CONNECT][PYTHON] Drop internal metadata in `DataFrame.schema`
279f5bece4cb is described below
commit 279f5bece4cb86ae592194d0e25bc2b9319d2267
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed May 29 09:48:03 2024 +0900
    [SPARK-48322][SPARK-42965][SQL][CONNECT][PYTHON] Drop internal metadata in `DataFrame.schema`
### What changes were proposed in this pull request?
Drop internal metadata in `DataFrame.schema`
### Why are the changes needed?
Internal metadata might be leaked in both Spark Connect and Spark Classic, e.g. in Spark Classic:
```
In [9]: spark.range(10).select(sf.lit(1).alias("key"), "id").groupBy("key").agg(sf.max("id")).schema.json()
Out[9]:
'{"fields":[{"metadata":{},"name":"key","nullable":false,"type":"integer"},{"metadata":{"__autoGeneratedAlias":"true"},"name":"max(id)","nullable":true,"type":"long"}],"type":"struct"}'
```
What makes it worse is that internal metadata may leak in different cases between Spark Connect and Spark Classic, so an additional `_drop_metadata` helper had to be added in the Pandas APIs to make assertions work; this patch removes that workaround.
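For reference, "dropping internal metadata" amounts to recursively clearing keys like `__autoGeneratedAlias` from every `StructField`. A minimal client-side sketch in PySpark, for illustration only: the helper name and the exact key set are assumptions of this sketch, and the actual fix instead reuses the server-side Scala util `removeInternalMetadata` (see the `Dataset.scala` hunk below):
```
from pyspark.sql.types import ArrayType, DataType, MapType, StructField, StructType

# Assumption of this sketch: '__autoGeneratedAlias' (seen in the example above)
# is the internal key to drop; the real util may cover more keys.
_INTERNAL_METADATA_KEYS = {"__autoGeneratedAlias"}


def drop_internal_metadata(dt: DataType) -> DataType:
    """Recursively rebuild `dt` with internal metadata keys removed."""
    if isinstance(dt, StructType):
        return StructType(
            [
                StructField(
                    f.name,
                    drop_internal_metadata(f.dataType),
                    f.nullable,
                    {k: v for k, v in f.metadata.items() if k not in _INTERNAL_METADATA_KEYS},
                )
                for f in dt.fields
            ]
        )
    elif isinstance(dt, ArrayType):
        return ArrayType(drop_internal_metadata(dt.elementType), dt.containsNull)
    elif isinstance(dt, MapType):
        return MapType(
            drop_internal_metadata(dt.keyType),
            drop_internal_metadata(dt.valueType),
            dt.valueContainsNull,
        )
    return dt
```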
### Does this PR introduce _any_ user-facing change?
Yes. Internal metadata is now dropped from `DataFrame.schema`, e.g. the `__autoGeneratedAlias` entry in the example query.
before:
```
In [9]: spark.range(10).select(sf.lit(1).alias("key"), "id").groupBy("key").agg(sf.max("id")).schema.json()
Out[9]:
'{"fields":[{"metadata":{},"name":"key","nullable":false,"type":"integer"},{"metadata":{"__autoGeneratedAlias":"true"},"name":"max(id)","nullable":true,"type":"long"}],"type":"struct"}'
```
after:
```
In [2]: spark.range(10).select(sf.lit(1).alias("key"), "id").groupBy("key").agg(sf.max("id")).schema.json()
Out[2]:
'{"fields":[{"metadata":{},"name":"key","nullable":false,"type":"integer"},{"metadata":{},"name":"max(id)","nullable":true,"type":"long"}],"type":"struct"}'
```
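An illustrative check (not part of the patch): after this change, the auto-aliased column's metadata should be an empty dict on both Spark Classic and Spark Connect:
```
df = spark.range(10).select(sf.lit(1).alias("key"), "id").groupBy("key").agg(sf.max("id"))
assert df.schema["max(id)"].metadata == {}  # '__autoGeneratedAlias' no longer leaks
```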
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #46636 from zhengruifeng/connect_remove_internal_meta.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/internal.py | 37 +++++-----------------
.../sql/tests/connect/test_connect_function.py | 4 +--
python/pyspark/sql/types.py | 13 --------
.../main/scala/org/apache/spark/sql/Dataset.scala | 4 +--
.../apache/spark/sql/DataFrameAggregateSuite.scala | 5 ++-
5 files changed, 13 insertions(+), 50 deletions(-)
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 04285aa2d879..fd0f28e50b2f 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -33,7 +33,6 @@ from pyspark.sql import (
     Window,
 )
 from pyspark.sql.types import (  # noqa: F401
-    _drop_metadata,
     BooleanType,
     DataType,
     LongType,
@@ -757,20 +756,10 @@ class InternalFrame:
         if is_testing():
             struct_fields = spark_frame.select(index_spark_columns).schema.fields
-            if is_remote():
-                # TODO(SPARK-42965): For some reason, the metadata of StructField is different
-                # in a few tests when using Spark Connect. However, the function works properly.
-                # Therefore, we temporarily perform Spark Connect tests by excluding metadata
-                # until the issue is resolved.
-                assert all(
-                    _drop_metadata(index_field.struct_field) == _drop_metadata(struct_field)
-                    for index_field, struct_field in zip(index_fields, struct_fields)
-                ), (index_fields, struct_fields)
-            else:
-                assert all(
-                    index_field.struct_field == struct_field
-                    for index_field, struct_field in zip(index_fields, struct_fields)
-                ), (index_fields, struct_fields)
+            assert all(
+                index_field.struct_field == struct_field
+                for index_field, struct_field in zip(index_fields, struct_fields)
+            ), (index_fields, struct_fields)

         self._index_fields: List[InternalField] = index_fields
@@ -785,20 +774,10 @@ class InternalFrame:
         if is_testing():
             struct_fields = spark_frame.select(data_spark_columns).schema.fields
-            if is_remote():
-                # TODO(SPARK-42965): For some reason, the metadata of StructField is different
-                # in a few tests when using Spark Connect. However, the function works properly.
-                # Therefore, we temporarily perform Spark Connect tests by excluding metadata
-                # until the issue is resolved.
-                assert all(
-                    _drop_metadata(data_field.struct_field) == _drop_metadata(struct_field)
-                    for data_field, struct_field in zip(data_fields, struct_fields)
-                ), (data_fields, struct_fields)
-            else:
-                assert all(
-                    data_field.struct_field == struct_field
-                    for data_field, struct_field in zip(data_fields, struct_fields)
-                ), (data_fields, struct_fields)
+            assert all(
+                data_field.struct_field == struct_field
+                for data_field, struct_field in zip(data_fields, struct_fields)
+            ), (data_fields, struct_fields)

         self._data_fields: List[InternalField] = data_fields
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 0f0abfd4b856..1fb0195b5203 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -22,7 +22,6 @@ from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql import SparkSession as PySparkSession
 from pyspark.sql.types import (
-    _drop_metadata,
     StringType,
     StructType,
     StructField,
@@ -1674,8 +1673,7 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
             )
         )
-        # TODO: 'cdf.schema' has an extra metadata '{'__autoGeneratedAlias': 'true'}'
-        self.assertEqual(_drop_metadata(cdf.schema), _drop_metadata(sdf.schema))
+        self.assertEqual(cdf.schema, sdf.schema)
         self.assertEqual(cdf.collect(), sdf.collect())

     def test_csv_functions(self):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c72ff72ce426..62f09e948792 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1770,19 +1770,6 @@ _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
 _COLLATIONS_METADATA_KEY = "__COLLATIONS"


-def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]:
-    assert isinstance(d, (DataType, StructField))
-    if isinstance(d, StructField):
-        return StructField(d.name, _drop_metadata(d.dataType), d.nullable, None)
-    elif isinstance(d, StructType):
-        return StructType([cast(StructField, _drop_metadata(f)) for f in d.fields])
-    elif isinstance(d, ArrayType):
-        return ArrayType(_drop_metadata(d.elementType), d.containsNull)
-    elif isinstance(d, MapType):
-        return MapType(_drop_metadata(d.keyType), _drop_metadata(d.valueType), d.valueContainsNull)
-    return d
-
-
 def _parse_datatype_string(s: str) -> DataType:
     """
     Parses the given data type string to a :class:`DataType`. The data type string format equals
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c7511737b2b3..afde54fc3d11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
-import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{removeInternalMetadata, CharVarcharUtils, IntervalUtils}
 import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution._
@@ -561,7 +561,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def schema: StructType = sparkSession.withActive {
-    queryExecution.analyzed.schema
+    removeInternalMetadata(queryExecution.analyzed.schema)
   }

   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 620ee430cab2..a89cae865435 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -24,7 +24,6 @@ import scala.util.Random
 import org.scalatest.matchers.must.Matchers.the

 import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
-import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1465,7 +1464,7 @@ class DataFrameAggregateSuite extends QueryTest
         Duration.ofSeconds(14)) ::
       Nil)
     assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
-    val metadata = new MetadataBuilder().putString(AUTO_GENERATED_ALIAS, "true").build()
+    val metadata = Metadata.empty
     assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
       StructField("sum(year-month)", YearMonthIntervalType(), metadata = metadata),
       StructField("sum(year)", YearMonthIntervalType(YEAR), metadata = metadata),
@@ -1599,7 +1598,7 @@ class DataFrameAggregateSuite extends QueryTest
         Duration.ofMinutes(4).plusSeconds(20),
         Duration.ofSeconds(7)) :: Nil)
     assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
-    val metadata = new MetadataBuilder().putString(AUTO_GENERATED_ALIAS, "true").build()
+    val metadata = Metadata.empty
     assert(avgDF2.schema == StructType(Seq(
       StructField("class", IntegerType, false),
       StructField("avg(year-month)", YearMonthIntervalType(), metadata = metadata),