This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 955349f6d970 [SPARK-48620][PYTHON] Fix internal raw data leak in 
`YearMonthIntervalType` and `CalendarIntervalType`
955349f6d970 is described below

commit 955349f6d970b64b496034087d2f2ea5fc0c161d
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Jun 20 13:44:54 2024 +0800

    [SPARK-48620][PYTHON] Fix internal raw data leak in `YearMonthIntervalType` 
and `CalendarIntervalType`
    
    ### What changes were proposed in this pull request?
    Fix internal raw data leak in `YearMonthIntervalType/CalendarIntervalType`:
    
    PySpark Classic: it fails collection of 
`YearMonthIntervalType/CalendarIntervalType`
    
    ### Why are the changes needed?
    The internal raw data should not be leaked to users.
    
    ### Does this PR introduce _any_ user-facing change?
    **PySpark Classic** (before):
    ```
    In [4]: spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS 
interval").first()[0]
    Out[4]: 128
    
    In [5]: spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 
01.001001)").first()[0]
    Out[5]: {'__class__': 'org.apache.spark.unsafe.types.CalendarInterval'}
    ```
    
    **PySpark Classic** (after):
    ```
    In [1]: spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS 
interval").first()
    ---------------------------------------------------------------------------
    PySparkNotImplementedError                Traceback (most recent call last)
    Cell In[1], line 1
    ----> 1 spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS 
interval").first()
    
    ...
    
    PySparkNotImplementedError: [NOT_IMPLEMENTED] 
YearMonthIntervalType.fromInternal is not implemented.
    
    In [2]: import os
    
    In [3]: os.environ['PYSPARK_YM_INTERVAL_LEGACY'] = "1"
    
    In [4]: spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS 
interval").first()
    Out[4]: Row(interval=128)
    ```
    
    ### How was this patch tested?
    Added test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46975 from zhengruifeng/fail_ym_interval_collect.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../source/migration_guide/pyspark_upgrade.rst     |  2 +-
 .../pyspark/sql/tests/connect/test_parity_types.py |  8 +++++
 python/pyspark/sql/tests/test_types.py             | 15 ++++++++
 python/pyspark/sql/types.py                        | 41 +++++++++++++++++++++-
 4 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst 
b/python/docs/source/migration_guide/pyspark_upgrade.rst
index 227293d83ada..529253042002 100644
--- a/python/docs/source/migration_guide/pyspark_upgrade.rst
+++ b/python/docs/source/migration_guide/pyspark_upgrade.rst
@@ -73,7 +73,7 @@ Upgrading from PySpark 3.5 to 4.0
 * In Spark 4.0, the aliases ``Y``, ``M``, ``H``, ``T``, ``S`` have been 
deprecated from Pandas API on Spark, use ``YE``, ``ME``, ``h``, ``min``, ``s`` 
instead respectively.
 * In Spark 4.0, the schema of a map column is inferred by merging the schemas 
of all pairs in the map. To restore the previous behavior where the schema is 
only inferred from the first non-null pair, you can set 
``spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled`` to ``true``.
 * In Spark 4.0, `compute.ops_on_diff_frames` is on by default. To restore the 
previous behavior, set `compute.ops_on_diff_frames` to `false`.
-
+* In Spark 4.0, the data type ``YearMonthIntervalType`` in ``DataFrame.collect`` 
no longer returns the underlying integers. To restore the previous behavior, 
set the ``PYSPARK_YM_INTERVAL_LEGACY`` environment variable to ``1``.
 
 
 Upgrading from PySpark 3.3 to 3.4
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py 
b/python/pyspark/sql/tests/connect/test_parity_types.py
index fd75595b3873..6d06611def6a 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -94,6 +94,14 @@ class TypesParityTests(TypesTestsMixin, 
ReusedConnectTestCase):
     def test_schema_with_collations_json_ser_de(self):
         super().test_schema_with_collations_json_ser_de()
 
+    @unittest.skip("This test is dedicated for PySpark Classic.")
+    def test_ym_interval_in_collect(self):
+        super().test_ym_interval_in_collect()
+
+    @unittest.skip("This test is dedicated for PySpark Classic.")
+    def test_cal_interval_in_collect(self):
+        super().test_cal_interval_in_collect()
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_types.py 
b/python/pyspark/sql/tests/test_types.py
index 13c64b4bdc28..4810cf40e231 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -32,6 +32,7 @@ from pyspark.errors import (
     PySparkTypeError,
     PySparkValueError,
     PySparkRuntimeError,
+    PySparkNotImplementedError,
 )
 from pyspark.sql.types import (
     DataType,
@@ -2241,6 +2242,20 @@ class TypesTestsMixin:
                 self.spark.createDataFrame([[[[1, 
1.0]]]]).schema.fields[0].dataType,
             )
 
+    def test_ym_interval_in_collect(self):
+        with self.assertRaises(PySparkNotImplementedError):
+            self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS 
interval").first()
+
+        with self.temp_env({"PYSPARK_YM_INTERVAL_LEGACY": "1"}):
+            self.assertEqual(
+                self.spark.sql("SELECT INTERVAL '10-8' YEAR TO MONTH AS 
interval").first(),
+                Row(interval=128),
+            )
+
+    def test_cal_interval_in_collect(self):
+        with self.assertRaises(PySparkNotImplementedError):
+            self.spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 
01.001001)").first()[0]
+
 
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index c0f60f839356..b7b0a977ec08 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import os
 import sys
 import decimal
 import time
@@ -586,7 +587,12 @@ class DayTimeIntervalType(AnsiIntervalType):
 
 
 class YearMonthIntervalType(AnsiIntervalType):
-    """YearMonthIntervalType, represents year-month intervals of the SQL 
standard"""
+    """YearMonthIntervalType, represents year-month intervals of the SQL 
standard
+
+    Notes
+    -----
+    This data type doesn't support collection: df.collect/take/head.
+    """
 
     YEAR = 0
     MONTH = 1
@@ -628,6 +634,24 @@ class YearMonthIntervalType(AnsiIntervalType):
 
     jsonValue = _str_repr
 
+    def needConversion(self) -> bool:
+        # If PYSPARK_YM_INTERVAL_LEGACY is not set, needConversion is true,
+        # 'df.collect' fails with PySparkNotImplementedError;
+        # otherwise, no conversion is needed, and 'df.collect' returns the 
internal integers.
+        return os.environ.get("PYSPARK_YM_INTERVAL_LEGACY") != "1"
+
+    def toInternal(self, obj: Any) -> Any:
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "YearMonthIntervalType.toInternal"},
+        )
+
+    def fromInternal(self, obj: Any) -> Any:
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": 
"YearMonthIntervalType.fromInternal"},
+        )
+
     def __repr__(self) -> str:
         return "%s(%d, %d)" % (type(self).__name__, self.startField, 
self.endField)
 
@@ -645,6 +669,21 @@ class CalendarIntervalType(DataType, 
metaclass=DataTypeSingleton):
     def typeName(cls) -> str:
         return "interval"
 
+    def needConversion(self) -> bool:
+        return True
+
+    def toInternal(self, obj: Any) -> Any:
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "CalendarIntervalType.toInternal"},
+        )
+
+    def fromInternal(self, obj: Any) -> Any:
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": 
"CalendarIntervalType.fromInternal"},
+        )
+
 
 class ArrayType(DataType):
     """Array data type.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to