This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 30cc5dd788 GH-45653: [Python] Scalar subclasses should implement
Python protocols (#45818)
30cc5dd788 is described below
commit 30cc5dd788d275dcec9ea55779803291663a863f
Author: Nic Crane <[email protected]>
AuthorDate: Tue Jun 10 09:38:34 2025 +0100
GH-45653: [Python] Scalar subclasses should implement Python protocols
(#45818)
### Rationale for this change
Implement dunder methods on Scalar objects
### What changes are included in this PR?
* integer scalars implement `__int__`
* floating-point scalars implement `__float__`
* binary scalars implement
[`__bytes__`](https://docs.python.org/3.13/reference/datamodel.html#object.__bytes__)
* binary scalars implement the [buffer
protocol](https://docs.python.org/3.13/reference/datamodel.html#object.__buffer__)
* we explicitly test that Struct scalars implement Sequences
* Map scalar implement mapping
### Are these changes tested?
Yes
### Are there any user-facing changes?
Yes
* GitHub Issue: #45653
Lead-authored-by: Nic Crane <[email protected]>
Co-authored-by: Alenka Frim <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
docs/source/python/compute.rst | 15 ++++++
python/pyarrow/scalar.pxi | 90 +++++++++++++++++++++++++++++++++---
python/pyarrow/tests/test_scalars.py | 48 +++++++++++++++++--
3 files changed, 144 insertions(+), 9 deletions(-)
diff --git a/docs/source/python/compute.rst b/docs/source/python/compute.rst
index c2b46c8f3f..397af9d2c5 100644
--- a/docs/source/python/compute.rst
+++ b/docs/source/python/compute.rst
@@ -63,6 +63,21 @@ Below are a few simple examples::
>>> pc.multiply(x, y)
<pyarrow.DoubleScalar: 72.54>
+If you are using a compute function which returns more than one value, results
+will be returned as a ``StructScalar``. You can extract the individual values by
+calling the :meth:`pyarrow.StructScalar.values` method::
+
+ >>> import pyarrow as pa
+ >>> import pyarrow.compute as pc
+ >>> a = pa.array([1, 1, 2, 3])
+ >>> pc.min_max(a)
+ <pyarrow.StructScalar: [('min', 1), ('max', 3)]>
+ >>> a, b = pc.min_max(a).values()
+ >>> a
+ <pyarrow.Int64Scalar: 1>
+ >>> b
+ <pyarrow.Int64Scalar: 3>
+
These functions can do more than just element-by-element operations.
Here is an example of sorting a table::
diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi
index a9cdcff4e4..5934a7aa8c 100644
--- a/python/pyarrow/scalar.pxi
+++ b/python/pyarrow/scalar.pxi
@@ -18,6 +18,7 @@
import collections
import warnings
from uuid import UUID
+from collections.abc import Sequence, Mapping
cdef class Scalar(_Weakrefable):
@@ -219,6 +220,8 @@ cdef class BooleanScalar(Scalar):
cdef CBooleanScalar* sp = <CBooleanScalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __bool__(self):
+ return self.as_py() or False
cdef class UInt8Scalar(Scalar):
"""
@@ -238,6 +241,9 @@ cdef class UInt8Scalar(Scalar):
cdef CUInt8Scalar* sp = <CUInt8Scalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __index__(self):
+ return self.as_py()
+
cdef class Int8Scalar(Scalar):
"""
@@ -257,6 +263,9 @@ cdef class Int8Scalar(Scalar):
cdef CInt8Scalar* sp = <CInt8Scalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __index__(self):
+ return self.as_py()
+
cdef class UInt16Scalar(Scalar):
"""
@@ -276,6 +285,9 @@ cdef class UInt16Scalar(Scalar):
cdef CUInt16Scalar* sp = <CUInt16Scalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __index__(self):
+ return self.as_py()
+
cdef class Int16Scalar(Scalar):
"""
@@ -295,6 +307,9 @@ cdef class Int16Scalar(Scalar):
cdef CInt16Scalar* sp = <CInt16Scalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __index__(self):
+ return self.as_py()
+
cdef class UInt32Scalar(Scalar):
"""
@@ -314,6 +329,9 @@ cdef class UInt32Scalar(Scalar):
cdef CUInt32Scalar* sp = <CUInt32Scalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __index__(self):
+ return self.as_py()
+
cdef class Int32Scalar(Scalar):
"""
@@ -333,6 +351,9 @@ cdef class Int32Scalar(Scalar):
cdef CInt32Scalar* sp = <CInt32Scalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __index__(self):
+ return self.as_py()
+
cdef class UInt64Scalar(Scalar):
"""
@@ -352,6 +373,9 @@ cdef class UInt64Scalar(Scalar):
cdef CUInt64Scalar* sp = <CUInt64Scalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __index__(self):
+ return self.as_py()
+
cdef class Int64Scalar(Scalar):
"""
@@ -371,6 +395,9 @@ cdef class Int64Scalar(Scalar):
cdef CInt64Scalar* sp = <CInt64Scalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __index__(self):
+ return self.as_py()
+
cdef class HalfFloatScalar(Scalar):
"""
@@ -390,6 +417,12 @@ cdef class HalfFloatScalar(Scalar):
cdef CHalfFloatScalar* sp = <CHalfFloatScalar*> self.wrapped.get()
return PyFloat_FromHalf(sp.value) if sp.is_valid else None
+ def __float__(self):
+ return self.as_py()
+
+ def __int__(self):
+ return int(self.as_py())
+
cdef class FloatScalar(Scalar):
"""
@@ -409,6 +442,12 @@ cdef class FloatScalar(Scalar):
cdef CFloatScalar* sp = <CFloatScalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __float__(self):
+ return self.as_py()
+
+ def __int__(self):
+ return int(float(self))
+
cdef class DoubleScalar(Scalar):
"""
@@ -428,6 +467,12 @@ cdef class DoubleScalar(Scalar):
cdef CDoubleScalar* sp = <CDoubleScalar*> self.wrapped.get()
return sp.value if sp.is_valid else None
+ def __float__(self):
+ return self.as_py()
+
+ def __int__(self):
+ return int(float(self))
+
cdef class Decimal32Scalar(Scalar):
"""
@@ -843,6 +888,15 @@ cdef class BinaryScalar(Scalar):
buffer = self.as_buffer()
return None if buffer is None else buffer.to_pybytes()
+ def __bytes__(self):
+ return self.as_py()
+
+ def __getbuffer__(self, cp.Py_buffer* buffer, int flags):
+ buf = self.as_buffer()
+ if buf is None:
+ raise ValueError("Cannot export buffer from null Arrow Scalar")
+ cp.PyObject_GetBuffer(buf, buffer, flags)
+
cdef class LargeBinaryScalar(BinaryScalar):
pass
@@ -883,7 +937,7 @@ cdef class StringViewScalar(StringScalar):
pass
-cdef class ListScalar(Scalar):
+cdef class ListScalar(Scalar, Sequence):
"""
Concrete class for list-like scalars.
"""
@@ -952,7 +1006,7 @@ cdef class LargeListViewScalar(ListScalar):
pass
-cdef class StructScalar(Scalar, collections.abc.Mapping):
+cdef class StructScalar(Scalar, Mapping):
"""
Concrete class for struct scalars.
"""
@@ -1051,20 +1105,34 @@ cdef class StructScalar(Scalar, collections.abc.Mapping):
return str(self._as_py_tuple())
-cdef class MapScalar(ListScalar):
+cdef class MapScalar(ListScalar, Mapping):
"""
Concrete class for map scalars.
"""
def __getitem__(self, i):
"""
- Return the value at the given index.
+ Return the value at the given index or key.
"""
+
arr = self.values
if arr is None:
- raise IndexError(i)
+ raise IndexError(i) if isinstance(i, int) else KeyError(i)
+
+ key_field = self.type.key_field.name
+ item_field = self.type.item_field.name
+
+ if isinstance(i, (bytes, str)):
+ try:
+ key_index = list(self.keys()).index(i)
+ except ValueError:
+ raise KeyError(i)
+
+ dct = arr[_normalize_index(key_index, len(arr))]
+ return dct[item_field]
+
dct = arr[_normalize_index(i, len(arr))]
- return (dct[self.type.key_field.name], dct[self.type.item_field.name])
+ return (dct[key_field], dct[item_field])
def __iter__(self):
"""
@@ -1118,6 +1186,16 @@ cdef class MapScalar(ListScalar):
result_dict[key] = value
return result_dict
+ def keys(self):
+ """
+ Return the keys of the map as a list.
+ """
+ arr = self.values
+ if arr is None:
+ return []
+ key_field = self.type.key_field.name
+ return [k.as_py() for k in arr.field(key_field)]
+
cdef class DictionaryScalar(Scalar):
"""
diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py
index 14f6ccef62..0f62dd98f8 100644
--- a/python/pyarrow/tests/test_scalars.py
+++ b/python/pyarrow/tests/test_scalars.py
@@ -19,6 +19,7 @@ import datetime
import decimal
import pytest
import weakref
+from collections.abc import Sequence, Mapping
try:
import numpy as np
@@ -208,17 +209,26 @@ def test_timestamp_scalar():
def test_bool():
false = pa.scalar(False)
true = pa.scalar(True)
+ null = pa.scalar(None, type=pa.bool_())
assert isinstance(false, pa.BooleanScalar)
assert isinstance(true, pa.BooleanScalar)
+ assert isinstance(null, pa.BooleanScalar)
assert repr(true) == "<pyarrow.BooleanScalar: True>"
assert str(true) == "True"
assert repr(false) == "<pyarrow.BooleanScalar: False>"
assert str(false) == "False"
+ assert repr(null) == "<pyarrow.BooleanScalar: None>"
+ assert str(null) == "None"
assert true.as_py() is True
assert false.as_py() is False
+ assert null.as_py() is None
+
+ assert bool(true) is True
+ assert bool(false) is False
+ assert bool(null) is False
def test_numerics():
@@ -228,6 +238,7 @@ def test_numerics():
assert repr(s) == "<pyarrow.Int64Scalar: 1>"
assert str(s) == "1"
assert s.as_py() == 1
+ assert int(s) == 1
with pytest.raises(OverflowError):
pa.scalar(-1, type='uint8')
@@ -238,6 +249,8 @@ def test_numerics():
assert repr(s) == "<pyarrow.DoubleScalar: 1.5>"
assert str(s) == "1.5"
assert s.as_py() == 1.5
+ assert float(s) == 1.5
+ assert int(s) == 1
# float16
s = pa.scalar(0.5, type='float16')
@@ -245,6 +258,8 @@ def test_numerics():
assert repr(s) == "<pyarrow.HalfFloatScalar: 0.5>"
assert str(s) == "0.5"
assert s.as_py() == 0.5
+ assert float(s) == 0.5
+ assert int(s) == 0
def test_decimal128():
@@ -540,7 +555,7 @@ def test_string(value, ty, scalar_typ):
assert buf.to_pybytes() == value.encode()
-@pytest.mark.parametrize('value', [b'foo', b'bar'])
+@pytest.mark.parametrize('value', [b'foo', b'bar', b'', None])
@pytest.mark.parametrize(('ty', 'scalar_typ'), [
(pa.binary(), pa.BinaryScalar),
(pa.large_binary(), pa.LargeBinaryScalar),
@@ -556,14 +571,30 @@ def test_binary(value, ty, scalar_typ):
assert s != b'xxxxx'
buf = s.as_buffer()
- assert isinstance(buf, pa.Buffer)
- assert buf.to_pybytes() == value
+
+ if value is None:
+ assert buf is None
+ with pytest.raises(ValueError):
+ memoryview(s)
+ else:
+ assert buf.to_pybytes() == value
+ assert isinstance(buf, pa.Buffer)
+ assert bytes(s) == value
+
+ memview = memoryview(s)
+ assert memview.tobytes() == value
+ assert memview.format == 'b'
+ assert memview.itemsize == 1
+ assert memview.ndim == 1
+ assert memview.shape == (len(value),)
+ assert memview.strides == (1,)
def test_fixed_size_binary():
s = pa.scalar(b'foof', type=pa.binary(4))
assert isinstance(s, pa.FixedSizeBinaryScalar)
assert s.as_py() == b'foof'
+ assert bytes(s) == b'foof'
with pytest.raises(pa.ArrowInvalid):
pa.scalar(b'foof5', type=pa.binary(4))
@@ -593,6 +624,7 @@ def test_list(ty, klass):
s[-3]
with pytest.raises(IndexError):
s[2]
+ assert isinstance(s, Sequence)
@pytest.mark.numpy
@@ -666,6 +698,7 @@ def test_struct():
v = {'x': 2, 'y': 3.5}
s = pa.scalar(v, type=ty)
assert list(s) == list(s.keys()) == ['x', 'y']
+
assert list(s.values()) == [
pa.scalar(2, type=pa.int16()),
pa.scalar(3.5, type=pa.float32())
@@ -687,6 +720,7 @@ def test_struct():
assert isinstance(s['y'], pa.FloatScalar)
assert s['x'].as_py() == 2
assert s['y'].as_py() == 3.5
+ assert isinstance(s, Mapping)
with pytest.raises(KeyError):
s['nonexistent']
@@ -698,10 +732,13 @@ def test_struct():
assert 'y' in s
assert isinstance(s['x'], pa.Int16Scalar)
assert isinstance(s['y'], pa.FloatScalar)
+ assert isinstance(s[0], pa.Int16Scalar)
+ assert isinstance(s[1], pa.FloatScalar)
assert s['x'].is_valid is False
assert s['y'].is_valid is False
assert s['x'].as_py() is None
assert s['y'].as_py() is None
+ assert isinstance(s, Mapping)
def test_struct_duplicate_fields():
@@ -776,16 +813,21 @@ def test_map(pickle_module):
)
assert s[-1] == s[1]
assert s[-2] == s[0]
+ assert s['b'] == pa.scalar(2, type=pa.int8())
with pytest.raises(IndexError):
s[-3]
with pytest.raises(IndexError):
s[2]
+ with pytest.raises(KeyError):
+ s['fake_key']
restored = pickle_module.loads(pickle_module.dumps(s))
assert restored.equals(s)
assert s.as_py(maps_as_pydicts="strict") == {'a': 1, 'b': 2}
+ assert isinstance(s, Mapping)
+
def test_map_duplicate_fields():
ty = pa.map_(pa.string(), pa.int8())