This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 76f45a6892 ARROW-15365: [Python] Expose full cast options in the
pyarrow.compute.cast function (#13109)
76f45a6892 is described below
commit 76f45a6892b13391fdede4c72934f75f6d56143c
Author: Jabari Booker <[email protected]>
AuthorDate: Tue Jun 28 16:44:12 2022 -0400
ARROW-15365: [Python] Expose full cast options in the pyarrow.compute.cast
function (#13109)
Added CastOptions to pc.cast and related cast functions
Authored-by: JabariBooker <[email protected]>
Signed-off-by: David Li <[email protected]>
---
python/pyarrow/_compute.pyx | 19 ++++++++++++++++---
python/pyarrow/array.pxi | 8 +++++---
python/pyarrow/compute.py | 22 +++++++++++++++-------
python/pyarrow/table.pxi | 14 +++++++++-----
python/pyarrow/tests/test_array.py | 2 +-
python/pyarrow/tests/test_compute.py | 20 ++++++++++++++++++++
6 files changed, 66 insertions(+), 19 deletions(-)
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 04b57859ad..b9594d90e8 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2320,7 +2320,7 @@ cdef class Expression(_Weakrefable):
options = NullOptions(nan_is_null=nan_is_null)
return Expression._call("is_null", [self], options)
- def cast(self, type, bint safe=True):
+ def cast(self, type=None, safe=None, options=None):
"""
Explicitly set or change the expression's data type.
@@ -2329,16 +2329,29 @@ cdef class Expression(_Weakrefable):
Parameters
----------
- type : DataType
+ type : DataType, default None
Type to cast array to.
safe : boolean, default True
Whether to check for conversion errors such as overflow.
+ options : CastOptions, default None
+ Additional checks pass by CastOptions
Returns
-------
cast : Expression
"""
- options = CastOptions.safe(ensure_type(type))
+ safe_vars_passed = (safe is not None) or (type is not None)
+
+ if safe_vars_passed and (options is not None):
+ raise ValueError("Must either pass values for 'type' and 'safe' or pass a "
+ "value for 'options'")
+
+ if options is None:
+ type = ensure_type(type, allow_none=False)
+ if safe is False:
+ options = CastOptions.unsafe(type)
+ else:
+ options = CastOptions.safe(type)
return Expression._call("cast", [self], options)
def isin(self, values):
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index 67a28e629a..7a4f040061 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -897,7 +897,7 @@ cdef class Array(_PandasConvertible):
result = self.ap.Diff(deref(other.ap))
return frombytes(result, safe=True)
- def cast(self, object target_type, safe=True):
+ def cast(self, object target_type=None, safe=None, options=None):
"""
Cast array values to another data type
@@ -905,16 +905,18 @@ cdef class Array(_PandasConvertible):
Parameters
----------
- target_type : DataType
+ target_type : DataType, default None
Type to cast array to.
safe : boolean, default True
Whether to check for conversion errors such as overflow.
+ options : CastOptions, default None
+ Additional checks pass by CastOptions
Returns
-------
cast : Array
"""
- return _pc().cast(self, target_type, safe=safe)
+ return _pc().cast(self, target_type, safe=safe, options=options)
def view(self, object target_type):
"""
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 526f0e4f7b..5873571c5a 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -323,7 +323,7 @@ def _make_global_functions():
_make_global_functions()
-def cast(arr, target_type, safe=True):
+def cast(arr, target_type=None, safe=None, options=None):
"""
Cast array values to another data type. Can also be invoked as an array
instance method.
@@ -335,6 +335,8 @@ def cast(arr, target_type, safe=True):
Type to cast to
safe : bool, default True
Check for overflows or other unsafe conversions
+ options : CastOptions, default None
+ Additional checks pass by CastOptions
Examples
--------
@@ -372,12 +374,18 @@ def cast(arr, target_type, safe=True):
-------
casted : Array
"""
- if target_type is None:
- raise ValueError("Cast target type must not be None")
- if safe:
- options = CastOptions.safe(target_type)
- else:
- options = CastOptions.unsafe(target_type)
+ safe_vars_passed = (safe is not None) or (target_type is not None)
+
+ if safe_vars_passed and (options is not None):
+ raise ValueError("Must either pass values for 'target_type' and 'safe'"
+ " or pass a value for 'options'")
+
+ if options is None:
+ target_type = pa.types.lib.ensure_type(target_type)
+ if safe is False:
+ options = CastOptions.unsafe(target_type)
+ else:
+ options = CastOptions.safe(target_type)
return call_function("cast", [arr], options)
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 3af10b46cc..17ea6d3558 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -488,7 +488,7 @@ cdef class ChunkedArray(_PandasConvertible):
return values
return values.astype(dtype)
- def cast(self, object target_type, safe=True):
+ def cast(self, object target_type=None, safe=None, options=None):
"""
Cast array values to another data type
@@ -496,10 +496,12 @@ cdef class ChunkedArray(_PandasConvertible):
Parameters
----------
- target_type : DataType
+ target_type : DataType, None
Type to cast array to.
safe : boolean, default True
Whether to check for conversion errors such as overflow.
+ options : CastOptions, default None
+ Additional checks pass by CastOptions
Returns
-------
@@ -518,7 +520,7 @@ cdef class ChunkedArray(_PandasConvertible):
>>> n_legs_seconds.type
DurationType(duration[s])
"""
- return _pc().cast(self, target_type, safe=safe)
+ return _pc().cast(self, target_type, safe=safe, options=options)
def dictionary_encode(self, null_encoding='mask'):
"""
@@ -3352,7 +3354,7 @@ cdef class Table(_PandasConvertible):
return result
- def cast(self, Schema target_schema, bint safe=True):
+ def cast(self, Schema target_schema, safe=None, options=None):
"""
Cast table values to another schema.
@@ -3362,6 +3364,8 @@ cdef class Table(_PandasConvertible):
Schema to cast to, the names and order of fields must match.
safe : bool, default True
Check for overflows or other unsafe conversions.
+ options : CastOptions, default None
+ Additional checks pass by CastOptions
Returns
-------
@@ -3405,7 +3409,7 @@ cdef class Table(_PandasConvertible):
.format(self.schema.names, target_schema.names))
for column, field in zip(self.itercolumns(), target_schema):
- casted = column.cast(field.type, safe=safe)
+ casted = column.cast(field.type, safe=safe, options=options)
newcols.append(casted)
return Table.from_arrays(newcols, schema=target_schema)
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index 1421e60bb3..814691c92d 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -1281,7 +1281,7 @@ def test_cast_none():
# ARROW-3735: Ensure that calling cast(None) doesn't segfault.
arr = pa.array([1, 2, 3])
- with pytest.raises(ValueError):
+ with pytest.raises(TypeError):
arr.cast(None)
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 67857ed6ec..6664f2f824 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -1680,13 +1680,33 @@ def test_logical():
def test_cast():
+ arr = pa.array([1, 2, 3, 4], type='int64')
+ options = pc.CastOptions(pa.int8())
+
+ with pytest.raises(TypeError):
+ pc.cast(arr, target_type=None)
+
+ with pytest.raises(ValueError):
+ pc.cast(arr, 'int32', options=options)
+
+ with pytest.raises(ValueError):
+ pc.cast(arr, safe=True, options=options)
+
+ assert pc.cast(arr, options=options) == pa.array(
+ [1, 2, 3, 4], type='int8')
+
arr = pa.array([2 ** 63 - 1], type='int64')
+ allow_overflow_options = pc.CastOptions(
+ pa.int32(), allow_int_overflow=True)
with pytest.raises(pa.ArrowInvalid):
pc.cast(arr, 'int32')
assert pc.cast(arr, 'int32', safe=False) == pa.array([-1], type='int32')
+ assert pc.cast(arr, options=allow_overflow_options) == pa.array(
+ [-1], type='int32')
+
arr = pa.array([datetime(2010, 1, 1), datetime(2015, 1, 1)])
expected = pa.array([1262304000000, 1420070400000], type='timestamp[ms]')
assert pc.cast(arr, 'timestamp[ms]') == expected