Yicong-Huang commented on code in PR #53778:
URL: https://github.com/apache/spark/pull/53778#discussion_r2692140694


##########
python/pyspark/tests/upstream/pyarrow/test_pyarrow_array_cast.py:
##########
@@ -0,0 +1,856 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Tests for PyArrow's pa.Array.cast() method with default arguments.
+
+This test suite is part of SPARK-54936 to monitor upstream PyArrow behavior.
+It tests all combinations of source type -> target type to ensure PySpark's
+assumptions about PyArrow's casting behavior remain valid across versions.
+
+## Type Conversion Matrix (pa.Array.cast with default safe=True)
+
+### Comprehensive Type Coverage:
+- **Integers**: int8, int16, int32, int64, uint8, uint16, uint32, uint64
+- **Floats**: float16, float32, float64
+- **Strings**: string, large_string
+- **Binary**: binary, large_binary
+- **Decimals**: decimal128, decimal256
+- **Dates**: date32, date64
+- **Timestamps**: timestamp[s/ms/us/ns]
+- **Durations**: duration[s/ms/us/ns]
+- **Times**: time32[s/ms], time64[us/ns]
+- **Lists**: list, large_list, fixed_size_list
+- **Complex**: struct, map
+- **NumPy**: np.int8-64, np.uint8-64, np.float16/32/64
+- **Pandas**: pd.Int64Dtype(), pd.Float64Dtype(), pd.ArrowDtype()
+
+### Conversion Matrix:
+
+| From \\ To         | int8-64 | uint8-64 | float16-64 | bool | string/large | binary/large | decimal128/256 | date32/64 | timestamp | duration | time | list/large | struct | map |
+|-------------------|---------|----------|------------|------|--------------|--------------|----------------|-----------|-----------|----------|------|------------|--------|-----|
+| **int8-64**       | ✓       | ✓        | ✓          | ✓    | ✓            | ✗            | ✓              | ✗         | ✗         | ✗        | ✗    | ✗          | ✗      | ✗   |
+| **uint8-64**      | ✓       | ✓        | ✓          | ✓    | ✓            | ✗            | ✓              | ✗         | ✗         | ✗        | ✗    | ✗          | ✗      | ✗   |
+| **float16-64**    | ✓ᴺᵀ     | ✓ᴺᵀ      | ✓          | ✓    | ✓            | ✗            | ✓              | ✗         | ✗         | ✗        | ✗    | ✗          | ✗      | ✗   |
+| **bool**          | ✓       | ✓        | ✓⁽ᶠ³²⁺⁾    | ✓    | ✓            | ✗            | ✗              | ✗         | ✗         | ✗        | ✗    | ✗          | ✗      | ✗   |
+| **string/large**  | ✓       | ✓        | ✓          | ✓    | ✓            | ✓            | ✓              | ✓         | ✓         | ✗        | ✗    | ✗          | ✗      | ✗   |
+| **binary/large**  | ✗       | ✗        | ✗          | ✗    | ✓            | ✓            | ✗              | ✗         | ✗         | ✗        | ✗    | ✗          | ✗      | ✗   |
+| **decimal128/256**| ✓       | ✓        | ✓          | ✗    | ✓            | ✗            | ✓              | ✗         | ✗         | ✗        | ✗    | ✗          | ✗      | ✗   |
+| **date32/64**     | ✗       | ✗        | ✗          | ✗    | ✓            | ✗            | ✗              | ✓         | ✓         | ✗        | ✗    | ✗          | ✗      | ✗   |
+| **timestamp**     | ✗       | ✗        | ✗          | ✗    | ✓            | ✗            | ✗              | ✓         | ✓ᵁᴾ       | ✗        | ✗    | ✗          | ✗      | ✗   |
+| **duration**      | ✗       | ✗        | ✗          | ✗    | ✓            | ✗            | ✗              | ✗         | ✗         | ✓ᵁᴾ      | ✗    | ✗          | ✗      | ✗   |
+| **time32/64**     | ✗       | ✗        | ✗          | ✗    | ✓            | ✗            | ✗              | ✗         | ✗         | ✗        | ✓ᵁᴾ  | ✗          | ✗      | ✗   |
+| **list/large**    | ✗       | ✗        | ✗          | ✗    | ✗            | ✗            | ✗              | ✗         | ✗         | ✗        | ✗    | ✓ᴱᴸ        | ✗      | ✗   |
+| **struct**        | ✗       | ✗        | ✗          | ✗    | ✗            | ✗            | ✗              | ✗         | ✗         | ✗        | ✗    | ✗          | ✓ᴱᴸ    | ✗   |
+| **map**           | ✗       | ✗        | ✗          | ✗    | ✗            | ✗            | ✗              | ✗         | ✗         | ✗        | ✗    | ✗          | ✗      | ✓ᴱᴸ |
+
+Legend:
+- ✓      = Allowed (no precision loss with safe=True)
+- ✓ᴺᵀ    = No truncation: allowed only if no value is truncated (e.g., 1.0→1 ok, 1.5→1 fails)
+- ✓ᵁᴾ    = Upcast only: converting to a finer unit works (s→ms ok); converting to a coarser unit (ms→s) fails if data would be lost
+- ✓ᴱᴸ    = Element-wise: element/field types must be castable
+- ✓⁽ᶠ³²⁺⁾ = float32 and above (bool->float16 not supported)
+- ✗      = Not allowed / raises an error (ArrowInvalid or ArrowNotImplementedError)
+
+Notes:
+1. With the default safe=True, PyArrow raises an error rather than silently losing precision
+2. Float→Int requires whole numbers (1.0 ok, 1.5 fails)
+3. Timestamp/Duration/Time conversions to a coarser unit fail if any value would be truncated
+4. Large int64 values may exceed the range float64 represents exactly (±2^53)
+5. Nested type casts recursively cast element types
+6. string/large_string and binary/large_binary are interchangeable
+7. decimal128 and decimal256 require sufficient precision for int64 (≥21 digits)
+"""
+
+import unittest
+import math
+from datetime import datetime, date
+
+from pyspark.testing.utils import (
+    have_pandas,
+    have_numpy,
+    have_pyarrow,
+    pandas_requirement_message,
+    numpy_requirement_message,
+    pyarrow_requirement_message,
+)
+
+
+@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+class PyArrowArrayCastTests(unittest.TestCase):
+    """Test pa.Array.cast() with default arguments for all type combinations."""
+
+    def test_all_integer_type_casts(self):
+        """Test casting between all integer types (int8-64, uint8-64)."""
+        import pyarrow as pa
+
+        # All integer types with test values
+        # (type, test_values, type_name)
+        int_types = [
+            (pa.int8(), [1, 2, 3, None], "int8"),
+            (pa.int16(), [1, 2, 3, None], "int16"),
+            (pa.int32(), [1, 2, 3, None], "int32"),
+            (pa.int64(), [1, 2, 3, None], "int64"),
+            (pa.uint8(), [1, 2, 3, None], "uint8"),
+            (pa.uint16(), [1, 2, 3, None], "uint16"),
+            (pa.uint32(), [1, 2, 3, None], "uint32"),
+            (pa.uint64(), [1, 2, 3, None], "uint64"),
+        ]
+
+        # Test all int -> int conversions
+        for source_type, source_values, source_name in int_types:
+            source_arr = pa.array(source_values, type=source_type)
+            for target_type, _, target_name in int_types:
+                result = source_arr.cast(target_type)
+                self.assertEqual(result.type, target_type)
+                self.assertEqual(result[0].as_py(), 1)
+                self.assertEqual(result[1].as_py(), 2)
+                self.assertEqual(result[2].as_py(), 3)
+                self.assertIsNone(result[3].as_py())
+
+    def test_all_float_type_casts(self):
+        """Test casting between all float types (float16, float32, float64)."""
+        import pyarrow as pa
+
+        # All float types with test values
+        float_types = [
+            (pa.float16(), [1.5, 2.5, 3.5, None], "float16"),
+            (pa.float32(), [1.5, 2.5, 3.5, None], "float32"),
+            (pa.float64(), [1.5, 2.5, 3.5, None], "float64"),
+        ]
+
+        # Test all float -> float conversions
+        for source_type, source_values, source_name in float_types:
+            source_arr = pa.array(source_values, type=source_type)
+            for target_type, _, target_name in float_types:
+                result = source_arr.cast(target_type)
+                self.assertEqual(result.type, target_type)
+                # float16 has lower precision
+                if source_name == "float16" or target_name == "float16":
+                    self.assertAlmostEqual(result[0].as_py(), 1.5, places=2)
+                    self.assertAlmostEqual(result[1].as_py(), 2.5, places=2)
+                    self.assertAlmostEqual(result[2].as_py(), 3.5, places=2)
+                else:
+                    self.assertAlmostEqual(result[0].as_py(), 1.5, places=5)
+                    self.assertAlmostEqual(result[1].as_py(), 2.5, places=5)
+                    self.assertAlmostEqual(result[2].as_py(), 3.5, places=5)
+                self.assertIsNone(result[3].as_py())
+
+    def test_integer_to_float_casts(self):
+        """Test casting all integer types to all float types."""
+        import pyarrow as pa
+
+        int_types = [
+            (pa.int8(), [1, 2, 3, None], "int8"),
+            (pa.int16(), [1, 2, 3, None], "int16"),
+            (pa.int32(), [1, 2, 3, None], "int32"),
+            (pa.int64(), [1, 2, 3, None], "int64"),
+            (pa.uint8(), [1, 2, 3, None], "uint8"),
+            (pa.uint16(), [1, 2, 3, None], "uint16"),
+            (pa.uint32(), [1, 2, 3, None], "uint32"),
+            (pa.uint64(), [1, 2, 3, None], "uint64"),
+        ]
+
+        float_types = [
+            (pa.float16(), "float16"),
+            (pa.float32(), "float32"),
+            (pa.float64(), "float64"),
+        ]
+
+        for source_type, source_values, source_name in int_types:
+            source_arr = pa.array(source_values, type=source_type)
+            for target_type, target_name in float_types:
+                result = source_arr.cast(target_type)
+                self.assertEqual(result.type, target_type)
+                self.assertAlmostEqual(result[0].as_py(), 1.0, places=2)
+                self.assertAlmostEqual(result[1].as_py(), 2.0, places=2)
+                self.assertAlmostEqual(result[2].as_py(), 3.0, places=2)
+                self.assertIsNone(result[3].as_py())
+
+    def test_float_to_integer_casts(self):
+        """Test casting float types to integer types (only whole numbers with safe=True)."""
+        import pyarrow as pa
+
+        # Use whole numbers for safe casting
+        float_types = [
+            (pa.float16(), [1.0, 2.0, 3.0, None], "float16"),
+            (pa.float32(), [1.0, 2.0, 3.0, None], "float32"),
+            (pa.float64(), [1.0, 2.0, 3.0, None], "float64"),
+        ]
+
+        int_types = [
+            (pa.int8(), "int8"),
+            (pa.int16(), "int16"),
+            (pa.int32(), "int32"),
+            (pa.int64(), "int64"),
+            (pa.uint8(), "uint8"),
+            (pa.uint16(), "uint16"),
+            (pa.uint32(), "uint32"),
+            (pa.uint64(), "uint64"),
+        ]
+
+        for source_type, source_values, source_name in float_types:
+            source_arr = pa.array(source_values, type=source_type)
+            for target_type, target_name in int_types:
+                result = source_arr.cast(target_type)
+                self.assertEqual(result.type, target_type)
+                self.assertEqual(result[0].as_py(), 1)
+                self.assertEqual(result[1].as_py(), 2)
+                self.assertEqual(result[2].as_py(), 3)
+                self.assertIsNone(result[3].as_py())
+
+    def test_numeric_to_bool_casts(self):
+        """Test casting all numeric types to boolean."""
+        import pyarrow as pa
+
+        numeric_types = [
+            (pa.int8(), [0, 1, 2, None], "int8"),
+            (pa.int32(), [0, 1, 2, None], "int32"),
+            (pa.int64(), [0, 1, 2, None], "int64"),
+            (pa.float32(), [0.0, 1.0, 2.0, None], "float32"),
+            (pa.float64(), [0.0, 1.0, 2.0, None], "float64"),
+        ]
+
+        for source_type, source_values, source_name in numeric_types:
+            arr = pa.array(source_values, type=source_type)
+            result = arr.cast(pa.bool_())
+            self.assertEqual(result.type, pa.bool_())
+            self.assertEqual(result[0].as_py(), False)
+            self.assertEqual(result[1].as_py(), True)
+            self.assertEqual(result[2].as_py(), True)
+            self.assertIsNone(result[3].as_py())
+
+    def test_bool_to_numeric_casts(self):
+        """Test casting boolean to all numeric types."""
+        import pyarrow as pa
+
+        arr_bool = pa.array([True, False, True, None], type=pa.bool_())
+
+        numeric_types = [
+            (pa.int8(), "int8"),
+            (pa.int16(), "int16"),
+            (pa.int32(), "int32"),
+            (pa.int64(), "int64"),
+            (pa.uint8(), "uint8"),
+            (pa.uint16(), "uint16"),
+            (pa.uint32(), "uint32"),
+            (pa.uint64(), "uint64"),
+            # float16 not supported for bool->float cast
+            (pa.float32(), "float32"),
+            (pa.float64(), "float64"),
+        ]
+
+        for target_type, target_name in numeric_types:
+            result = arr_bool.cast(target_type)
+            self.assertEqual(result.type, target_type)
+            self.assertEqual(result[0].as_py(), 1)
+            self.assertEqual(result[1].as_py(), 0)
+            self.assertEqual(result[2].as_py(), 1)
+            self.assertIsNone(result[3].as_py())
+
+    def test_numeric_to_string_casts(self):
+        """Test casting all numeric types to string."""
+        import pyarrow as pa
+
+        test_cases = [
+            (pa.int8(), [1, 2, 3, None], ["1", "2", "3", None]),
+            (pa.int32(), [1, 2, 3, None], ["1", "2", "3", None]),
+            (pa.int64(), [1, 2, 3, None], ["1", "2", "3", None]),
+            (pa.uint32(), [1, 2, 3, None], ["1", "2", "3", None]),
+            (pa.float32(), [1.5, 2.5, 3.5, None], ["1.5", "2.5", "3.5", None]),
+            (pa.float64(), [1.5, 2.5, 3.5, None], ["1.5", "2.5", "3.5", None]),
+            (pa.bool_(), [True, False, None], ["true", "false", None]),
+        ]
+
+        for source_type, source_values, expected_values in test_cases:
+            arr = pa.array(source_values, type=source_type)
+            result = arr.cast(pa.string())
+            self.assertEqual(result.type, pa.string())
+            for i in range(len(expected_values) - 1):
+                self.assertEqual(result[i].as_py(), expected_values[i])
+            self.assertIsNone(result[len(expected_values) - 1].as_py())
+
+    def test_string_to_numeric_casts(self):
+        """Test casting string to all numeric types."""
+        import pyarrow as pa
+
+        test_cases = [
+            (["1", "2", "3", None], pa.int8(), [1, 2, 3, None]),
+            (["1", "2", "3", None], pa.int32(), [1, 2, 3, None]),
+            (["1", "2", "3", None], pa.int64(), [1, 2, 3, None]),
+            (["1", "2", "3", None], pa.uint32(), [1, 2, 3, None]),
+            (["1.5", "2.5", "3.5", None], pa.float32(), [1.5, 2.5, 3.5, None]),
+            (["1.5", "2.5", "3.5", None], pa.float64(), [1.5, 2.5, 3.5, None]),
+            (["true", "false", "1", "0", None], pa.bool_(), [True, False, True, False, None]),
+        ]
+
+        for source_values, target_type, expected_values in test_cases:
+            arr = pa.array(source_values, type=pa.string())
+            result = arr.cast(target_type)
+            self.assertEqual(result.type, target_type)
+            for i in range(len(expected_values) - 1):
+                if isinstance(expected_values[i], float):
+                    self.assertAlmostEqual(result[i].as_py(), expected_values[i], places=5)
+                else:
+                    self.assertEqual(result[i].as_py(), expected_values[i])
+            self.assertIsNone(result[len(expected_values) - 1].as_py())
+
+    def test_string_binary_casts(self):

Review Comment:
   Moved to the next PR. This PR focuses on numerical source types.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.