This is an automated email from the ASF dual-hosted git repository.
kszucs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 7de93af ARROW-9394: [Python] Support pickling of Scalars
7de93af is described below
commit 7de93af56562ec81a8e0f29446b7cb70458a518a
Author: Krisztián Szűcs <[email protected]>
AuthorDate: Wed Jul 29 13:14:41 2020 +0200
ARROW-9394: [Python] Support pickling of Scalars
Since there are no sequence converters available for Dictionary and Union
types, we cannot construct them directly; thus `pa.scalar` fails as the reducer
function to reconstruct them.
We can add custom reducers for them later, so I'm leaving them as
NotImplemented for now.
Closes #7852 from kszucs/ARROW-9394
Authored-by: Krisztián Szűcs <[email protected]>
Signed-off-by: Krisztián Szűcs <[email protected]>
---
python/pyarrow/scalar.pxi | 3 ++
python/pyarrow/tests/test_scalars.py | 73 +++++++++++++++++++++++++++++++++++-
2 files changed, 75 insertions(+), 1 deletion(-)
diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi
index f607070..7f35419 100644
--- a/python/pyarrow/scalar.pxi
+++ b/python/pyarrow/scalar.pxi
@@ -97,6 +97,9 @@ cdef class Scalar(_Weakrefable):
cdef CScalarHash hasher
return hasher(self.wrapped)
+ def __reduce__(self):
+ return scalar, (self.as_py(), self.type)
+
def as_py(self):
raise NotImplementedError()
diff --git a/python/pyarrow/tests/test_scalars.py
b/python/pyarrow/tests/test_scalars.py
index 3a19a1c..091ae38 100644
--- a/python/pyarrow/tests/test_scalars.py
+++ b/python/pyarrow/tests/test_scalars.py
@@ -17,6 +17,7 @@
import datetime
import decimal
+import pickle
import pytest
import weakref
@@ -41,16 +42,29 @@ import pyarrow as pa
(1.0, None, pa.DoubleScalar, pa.DoubleValue),
(np.float16(1.0), pa.float16(), pa.HalfFloatScalar, pa.HalfFloatValue),
(1.0, pa.float32(), pa.FloatScalar, pa.FloatValue),
+ (decimal.Decimal("1.123"), None, pa.Decimal128Scalar, pa.Decimal128Value),
("string", None, pa.StringScalar, pa.StringValue),
(b"bytes", None, pa.BinaryScalar, pa.BinaryValue),
+ ("largestring", pa.large_string(), pa.LargeStringScalar,
+ pa.LargeStringValue),
+ (b"largebytes", pa.large_binary(), pa.LargeBinaryScalar,
+ pa.LargeBinaryValue),
+ (b"abc", pa.binary(3), pa.FixedSizeBinaryScalar, pa.FixedSizeBinaryValue),
([1, 2, 3], None, pa.ListScalar, pa.ListValue),
([1, 2, 3, 4], pa.large_list(pa.int8()), pa.LargeListScalar,
pa.LargeListValue),
+ ([1, 2, 3, 4, 5], pa.list_(pa.int8(), 5), pa.FixedSizeListScalar,
+ pa.FixedSizeListValue),
(datetime.date.today(), None, pa.Date32Scalar, pa.Date32Value),
+ (datetime.date.today(), pa.date64(), pa.Date64Scalar, pa.Date64Value),
(datetime.datetime.now(), None, pa.TimestampScalar, pa.TimestampValue),
+ (datetime.datetime.now().time().replace(microsecond=0), pa.time32('s'),
+ pa.Time32Scalar, pa.Time32Value),
(datetime.datetime.now().time(), None, pa.Time64Scalar, pa.Time64Value),
(datetime.timedelta(days=1), None, pa.DurationScalar, pa.DurationValue),
- ({'a': 1, 'b': [1, 2]}, None, pa.StructScalar, pa.StructValue)
+ ({'a': 1, 'b': [1, 2]}, None, pa.StructScalar, pa.StructValue),
+ ([('a', 1), ('b', 2)], pa.map_(pa.string(), pa.int8()), pa.MapScalar,
+ pa.MapValue),
])
def test_basics(value, ty, klass, deprecated):
s = pa.scalar(value, type=ty)
@@ -69,6 +83,11 @@ def test_basics(value, ty, klass, deprecated):
assert s.as_py() is None
assert s != pa.scalar(value, type=ty)
+ # test pickle roundtrip
+ restored = pickle.loads(pickle.dumps(s))
+ assert s.equals(restored)
+
+ # test that scalars are weak-referenceable
wr = weakref.ref(s)
assert wr() is not None
del s
@@ -95,6 +114,11 @@ def test_nulls():
assert v is pa.NA
assert v.as_py() is None
+ # test pickle roundtrip
+ restored = pickle.loads(pickle.dumps(null))
+ assert restored.equals(null)
+
+ # test that scalars are weak-referenceable
wr = weakref.ref(null)
assert wr() is not None
del null
@@ -502,6 +526,9 @@ def test_map():
with pytest.raises(IndexError):
s[2]
+ restored = pickle.loads(pickle.dumps(s))
+ assert restored.equals(s)
+
def test_dictionary():
indices = [2, 1, 2, 0]
@@ -522,3 +549,47 @@ def test_dictionary():
assert s.index_value.as_py() == i
with pytest.warns(FutureWarning):
assert s.dictionary_value.as_py() == v
+
+ with pytest.raises(pa.ArrowNotImplementedError):
+ pickle.loads(pickle.dumps(s))
+
+
+def test_union():
+ # sparse
+ arr = pa.UnionArray.from_sparse(
+ pa.array([0, 0, 1, 1], type=pa.int8()),
+ [
+ pa.array(["a", "b", "c", "d"]),
+ pa.array([1, 2, 3, 4])
+ ]
+ )
+ for s in arr:
+ assert isinstance(s, pa.UnionScalar)
+ assert s.type.equals(arr.type)
+ assert s.is_valid is True
+ with pytest.raises(pa.ArrowNotImplementedError):
+ pickle.loads(pickle.dumps(s))
+
+ assert arr[0].as_py() == "a"
+ assert arr[1].as_py() == "b"
+ assert arr[2].as_py() == 3
+ assert arr[3].as_py() == 4
+
+ # dense
+ arr = pa.UnionArray.from_dense(
+ types=pa.array([0, 1, 0, 0, 1, 1, 0], type='int8'),
+ value_offsets=pa.array([0, 0, 2, 1, 1, 2, 3], type='int32'),
+ children=[
+ pa.array([b'a', b'b', b'c', b'd'], type='binary'),
+ pa.array([1, 2, 3], type='int64')
+ ]
+ )
+ for s in arr:
+ assert isinstance(s, pa.UnionScalar)
+ assert s.type.equals(arr.type)
+ assert s.is_valid is True
+ with pytest.raises(pa.ArrowNotImplementedError):
+ pickle.loads(pickle.dumps(s))
+
+ assert arr[0].as_py() == b'a'
+ assert arr[5].as_py() == 3