This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 76f45a6892 ARROW-15365: [Python] Expose full cast options in the pyarrow.compute.cast function (#13109)
76f45a6892 is described below

commit 76f45a6892b13391fdede4c72934f75f6d56143c
Author: Jabari Booker <[email protected]>
AuthorDate: Tue Jun 28 16:44:12 2022 -0400

    ARROW-15365: [Python] Expose full cast options in the pyarrow.compute.cast function (#13109)
    
    Added CastOptions to pc.cast and related cast functions
    
    Authored-by: JabariBooker <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 python/pyarrow/_compute.pyx          | 19 ++++++++++++++++---
 python/pyarrow/array.pxi             |  8 +++++---
 python/pyarrow/compute.py            | 22 +++++++++++++++-------
 python/pyarrow/table.pxi             | 14 +++++++++-----
 python/pyarrow/tests/test_array.py   |  2 +-
 python/pyarrow/tests/test_compute.py | 20 ++++++++++++++++++++
 6 files changed, 66 insertions(+), 19 deletions(-)

diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 04b57859ad..b9594d90e8 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2320,7 +2320,7 @@ cdef class Expression(_Weakrefable):
         options = NullOptions(nan_is_null=nan_is_null)
         return Expression._call("is_null", [self], options)
 
-    def cast(self, type, bint safe=True):
+    def cast(self, type=None, safe=None, options=None):
         """
         Explicitly set or change the expression's data type.
 
@@ -2329,16 +2329,29 @@ cdef class Expression(_Weakrefable):
 
         Parameters
         ----------
-        type : DataType
+        type : DataType, default None
             Type to cast array to.
         safe : boolean, default True
             Whether to check for conversion errors such as overflow.
+        options : CastOptions, default None
+            Additional checks pass by CastOptions
 
         Returns
         -------
         cast : Expression
         """
-        options = CastOptions.safe(ensure_type(type))
+        safe_vars_passed = (safe is not None) or (type is not None)
+
+        if safe_vars_passed and (options is not None):
+            raise ValueError("Must either pass values for 'type' and 'safe' or 
pass a "
+                             "value for 'options'")
+
+        if options is None:
+            type = ensure_type(type, allow_none=False)
+            if safe is False:
+                options = CastOptions.unsafe(type)
+            else:
+                options = CastOptions.safe(type)
         return Expression._call("cast", [self], options)
 
     def isin(self, values):
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index 67a28e629a..7a4f040061 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -897,7 +897,7 @@ cdef class Array(_PandasConvertible):
             result = self.ap.Diff(deref(other.ap))
         return frombytes(result, safe=True)
 
-    def cast(self, object target_type, safe=True):
+    def cast(self, object target_type=None, safe=None, options=None):
         """
         Cast array values to another data type
 
@@ -905,16 +905,18 @@ cdef class Array(_PandasConvertible):
 
         Parameters
         ----------
-        target_type : DataType
+        target_type : DataType, default None
             Type to cast array to.
         safe : boolean, default True
             Whether to check for conversion errors such as overflow.
+        options : CastOptions, default None
+            Additional checks pass by CastOptions
 
         Returns
         -------
         cast : Array
         """
-        return _pc().cast(self, target_type, safe=safe)
+        return _pc().cast(self, target_type, safe=safe, options=options)
 
     def view(self, object target_type):
         """
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 526f0e4f7b..5873571c5a 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -323,7 +323,7 @@ def _make_global_functions():
 _make_global_functions()
 
 
-def cast(arr, target_type, safe=True):
+def cast(arr, target_type=None, safe=None, options=None):
     """
     Cast array values to another data type. Can also be invoked as an array
     instance method.
@@ -335,6 +335,8 @@ def cast(arr, target_type, safe=True):
         Type to cast to
     safe : bool, default True
         Check for overflows or other unsafe conversions
+    options : CastOptions, default None
+        Additional checks pass by CastOptions
 
     Examples
     --------
@@ -372,12 +374,18 @@ def cast(arr, target_type, safe=True):
     -------
     casted : Array
     """
-    if target_type is None:
-        raise ValueError("Cast target type must not be None")
-    if safe:
-        options = CastOptions.safe(target_type)
-    else:
-        options = CastOptions.unsafe(target_type)
+    safe_vars_passed = (safe is not None) or (target_type is not None)
+
+    if safe_vars_passed and (options is not None):
+        raise ValueError("Must either pass values for 'target_type' and 'safe'"
+                         " or pass a value for 'options'")
+
+    if options is None:
+        target_type = pa.types.lib.ensure_type(target_type)
+        if safe is False:
+            options = CastOptions.unsafe(target_type)
+        else:
+            options = CastOptions.safe(target_type)
     return call_function("cast", [arr], options)
 
 
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 3af10b46cc..17ea6d3558 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -488,7 +488,7 @@ cdef class ChunkedArray(_PandasConvertible):
             return values
         return values.astype(dtype)
 
-    def cast(self, object target_type, safe=True):
+    def cast(self, object target_type=None, safe=None, options=None):
         """
         Cast array values to another data type
 
@@ -496,10 +496,12 @@ cdef class ChunkedArray(_PandasConvertible):
 
         Parameters
         ----------
-        target_type : DataType
+        target_type : DataType, None
             Type to cast array to.
         safe : boolean, default True
             Whether to check for conversion errors such as overflow.
+        options : CastOptions, default None
+            Additional checks pass by CastOptions
 
         Returns
         -------
@@ -518,7 +520,7 @@ cdef class ChunkedArray(_PandasConvertible):
         >>> n_legs_seconds.type
         DurationType(duration[s])
         """
-        return _pc().cast(self, target_type, safe=safe)
+        return _pc().cast(self, target_type, safe=safe, options=options)
 
     def dictionary_encode(self, null_encoding='mask'):
         """
@@ -3352,7 +3354,7 @@ cdef class Table(_PandasConvertible):
 
         return result
 
-    def cast(self, Schema target_schema, bint safe=True):
+    def cast(self, Schema target_schema, safe=None, options=None):
         """
         Cast table values to another schema.
 
@@ -3362,6 +3364,8 @@ cdef class Table(_PandasConvertible):
             Schema to cast to, the names and order of fields must match.
         safe : bool, default True
             Check for overflows or other unsafe conversions.
+        options : CastOptions, default None
+            Additional checks pass by CastOptions
 
         Returns
         -------
@@ -3405,7 +3409,7 @@ cdef class Table(_PandasConvertible):
                              .format(self.schema.names, target_schema.names))
 
         for column, field in zip(self.itercolumns(), target_schema):
-            casted = column.cast(field.type, safe=safe)
+            casted = column.cast(field.type, safe=safe, options=options)
             newcols.append(casted)
 
         return Table.from_arrays(newcols, schema=target_schema)
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index 1421e60bb3..814691c92d 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -1281,7 +1281,7 @@ def test_cast_none():
     # ARROW-3735: Ensure that calling cast(None) doesn't segfault.
     arr = pa.array([1, 2, 3])
 
-    with pytest.raises(ValueError):
+    with pytest.raises(TypeError):
         arr.cast(None)
 
 
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 67857ed6ec..6664f2f824 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -1680,13 +1680,33 @@ def test_logical():
 
 
 def test_cast():
+    arr = pa.array([1, 2, 3, 4], type='int64')
+    options = pc.CastOptions(pa.int8())
+
+    with pytest.raises(TypeError):
+        pc.cast(arr, target_type=None)
+
+    with pytest.raises(ValueError):
+        pc.cast(arr, 'int32', options=options)
+
+    with pytest.raises(ValueError):
+        pc.cast(arr, safe=True, options=options)
+
+    assert pc.cast(arr, options=options) == pa.array(
+        [1, 2, 3, 4], type='int8')
+
     arr = pa.array([2 ** 63 - 1], type='int64')
+    allow_overflow_options = pc.CastOptions(
+        pa.int32(), allow_int_overflow=True)
 
     with pytest.raises(pa.ArrowInvalid):
         pc.cast(arr, 'int32')
 
     assert pc.cast(arr, 'int32', safe=False) == pa.array([-1], type='int32')
 
+    assert pc.cast(arr, options=allow_overflow_options) == pa.array(
+        [-1], type='int32')
+
     arr = pa.array([datetime(2010, 1, 1), datetime(2015, 1, 1)])
     expected = pa.array([1262304000000, 1420070400000], type='timestamp[ms]')
     assert pc.cast(arr, 'timestamp[ms]') == expected

Reply via email to