This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 955349f6d970 [SPARK-48620][PYTHON] Fix internal raw data leak in
`YearMonthIntervalType` and `CalendarIntervalType`
955349f6d970 is described below
commit 955349f6d970b64b496034087d2f2ea5fc0c161d
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Jun 20 13:44:54 2024 +0800
[SPARK-48620][PYTHON] Fix internal raw data leak in `YearMonthIntervalType`
and `CalendarIntervalType`
### What changes were proposed in this pull request?
Fix internal raw data leak in `YearMonthIntervalType/CalendarIntervalType`:
PySpark Classic: it fails collection of
`YearMonthIntervalType/CalendarIntervalType`
### Why are the changes needed?
the raw data should not be leaked
### Does this PR introduce _any_ user-facing change?
**PySpark Classic** (before):
```
In [4]: spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS
interval").first()[0]
Out[4]: 128
In [5]: spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30,
01.001001)").first()[0]
Out[5]: {'__class__': 'org.apache.spark.unsafe.types.CalendarInterval'}
```
**PySpark Classic** (after):
```
In [1]: spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS
interval").first()
---------------------------------------------------------------------------
PySparkNotImplementedError Traceback (most recent call last)
Cell In[1], line 1
----> 1 spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS
interval").first()
...
PySparkNotImplementedError: [NOT_IMPLEMENTED]
YearMonthIntervalType.fromInternal is not implemented.
In [2]: import os
In [3]: os.environ['PYSPARK_YM_INTERVAL_LEGACY'] = "1"
In [4]: spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS
interval").first()
Out[4]: Row(interval=128)
```
### How was this patch tested?
Added test
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #46975 from zhengruifeng/fail_ym_interval_collect.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../source/migration_guide/pyspark_upgrade.rst | 2 +-
.../pyspark/sql/tests/connect/test_parity_types.py | 8 +++++
python/pyspark/sql/tests/test_types.py | 15 ++++++++
python/pyspark/sql/types.py | 41 +++++++++++++++++++++-
4 files changed, 64 insertions(+), 2 deletions(-)
diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst
b/python/docs/source/migration_guide/pyspark_upgrade.rst
index 227293d83ada..529253042002 100644
--- a/python/docs/source/migration_guide/pyspark_upgrade.rst
+++ b/python/docs/source/migration_guide/pyspark_upgrade.rst
@@ -73,7 +73,7 @@ Upgrading from PySpark 3.5 to 4.0
* In Spark 4.0, the aliases ``Y``, ``M``, ``H``, ``T``, ``S`` have been
deprecated from Pandas API on Spark, use ``YE``, ``ME``, ``h``, ``min``, ``s``
instead respectively.
* In Spark 4.0, the schema of a map column is inferred by merging the schemas
of all pairs in the map. To restore the previous behavior where the schema is
only inferred from the first non-null pair, you can set
``spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled`` to ``true``.
* In Spark 4.0, `compute.ops_on_diff_frames` is on by default. To restore the
previous behavior, set `compute.ops_on_diff_frames` to `false`.
-
+* In Spark 4.0, the data type `YearMonthIntervalType` in ``DataFrame.collect``
no longer returns the underlying integers. To restore the previous behavior,
set ``PYSPARK_YM_INTERVAL_LEGACY`` environment variable to ``1``.
Upgrading from PySpark 3.3 to 3.4
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py
b/python/pyspark/sql/tests/connect/test_parity_types.py
index fd75595b3873..6d06611def6a 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -94,6 +94,14 @@ class TypesParityTests(TypesTestsMixin,
ReusedConnectTestCase):
def test_schema_with_collations_json_ser_de(self):
super().test_schema_with_collations_json_ser_de()
+ @unittest.skip("This test is dedicated for PySpark Classic.")
+ def test_ym_interval_in_collect(self):
+ super().test_ym_interval_in_collect()
+
+ @unittest.skip("This test is dedicated for PySpark Classic.")
+ def test_cal_interval_in_collect(self):
+ super().test_cal_interval_in_collect()
+
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/tests/test_types.py
b/python/pyspark/sql/tests/test_types.py
index 13c64b4bdc28..4810cf40e231 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -32,6 +32,7 @@ from pyspark.errors import (
PySparkTypeError,
PySparkValueError,
PySparkRuntimeError,
+ PySparkNotImplementedError,
)
from pyspark.sql.types import (
DataType,
@@ -2241,6 +2242,20 @@ class TypesTestsMixin:
self.spark.createDataFrame([[[[1,
1.0]]]]).schema.fields[0].dataType,
)
+ def test_ym_interval_in_collect(self):
+ with self.assertRaises(PySparkNotImplementedError):
+ self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS
interval").first()
+
+ with self.temp_env({"PYSPARK_YM_INTERVAL_LEGACY": "1"}):
+ self.assertEqual(
+ self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS
interval").first(),
+ Row(interval=128),
+ )
+
+ def test_cal_interval_in_collect(self):
+ with self.assertRaises(PySparkNotImplementedError):
+ self.spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30,
01.001001)").first()[0]
+
class DataTypeTests(unittest.TestCase):
# regression test for SPARK-6055
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c0f60f839356..b7b0a977ec08 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import os
import sys
import decimal
import time
@@ -586,7 +587,12 @@ class DayTimeIntervalType(AnsiIntervalType):
class YearMonthIntervalType(AnsiIntervalType):
- """YearMonthIntervalType, represents year-month intervals of the SQL
standard"""
+ """YearMonthIntervalType, represents year-month intervals of the SQL
standard
+
+ Notes
+ -----
+ This data type doesn't support collection: df.collect/take/head.
+ """
YEAR = 0
MONTH = 1
@@ -628,6 +634,24 @@ class YearMonthIntervalType(AnsiIntervalType):
jsonValue = _str_repr
+ def needConversion(self) -> bool:
+ # If PYSPARK_YM_INTERVAL_LEGACY is not set, needConversion is true,
+ # 'df.collect' fails with PySparkNotImplementedError;
+ # otherwise, no conversion is needed, and 'df.collect' returns the
internal integers.
+ return not os.environ.get("PYSPARK_YM_INTERVAL_LEGACY") == "1"
+
+ def toInternal(self, obj: Any) -> Any:
+ raise PySparkNotImplementedError(
+ error_class="NOT_IMPLEMENTED",
+ message_parameters={"feature": "YearMonthIntervalType.toInternal"},
+ )
+
+ def fromInternal(self, obj: Any) -> Any:
+ raise PySparkNotImplementedError(
+ error_class="NOT_IMPLEMENTED",
+ message_parameters={"feature":
"YearMonthIntervalType.fromInternal"},
+ )
+
def __repr__(self) -> str:
return "%s(%d, %d)" % (type(self).__name__, self.startField,
self.endField)
@@ -645,6 +669,21 @@ class CalendarIntervalType(DataType,
metaclass=DataTypeSingleton):
def typeName(cls) -> str:
return "interval"
+ def needConversion(self) -> bool:
+ return True
+
+ def toInternal(self, obj: Any) -> Any:
+ raise PySparkNotImplementedError(
+ error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "CalendarIntervalType.toInternal"},
+ )
+
+ def fromInternal(self, obj: Any) -> Any:
+ raise PySparkNotImplementedError(
+ error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature":
"CalendarIntervalType.fromInternal"},
+ )
+
class ArrayType(DataType):
"""Array data type.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]