[
https://issues.apache.org/jira/browse/ARROW-1654?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16215955#comment-16215955
]
ASF GitHub Bot commented on ARROW-1654:
---------------------------------------
wesm closed pull request #1238: ARROW-1654: [Python] Implement pickling for
DataType, Field, Schema
URL: https://github.com/apache/arrow/pull/1238
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/python/pyarrow/includes/libarrow.pxd
b/python/pyarrow/includes/libarrow.pxd
index 60aa4d694..0e5d4a8ed 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -182,6 +182,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
CListType(const shared_ptr[CDataType]& value_type)
CListType(const shared_ptr[CField]& field)
shared_ptr[CDataType] value_type()
+ shared_ptr[CField] value_field()
cdef cppclass CStringType" arrow::StringType"(CDataType):
pass
diff --git a/python/pyarrow/tests/test_schema.py
b/python/pyarrow/tests/test_schema.py
index c77be9805..d6b2655b7 100644
--- a/python/pyarrow/tests/test_schema.py
+++ b/python/pyarrow/tests/test_schema.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+import pickle
+
import pytest
import numpy as np
@@ -304,3 +306,40 @@ def test_schema_repr_with_dictionaries():
two: int32""")
assert repr(sch) == expected
+
+
+def test_type_schema_pickling():
+ cases = [
+ pa.int8(),
+ pa.string(),
+ pa.binary(),
+ pa.binary(10),
+ pa.list_(pa.string()),
+ pa.struct([
+ pa.field('a', 'int8'),
+ pa.field('b', 'string')
+ ]),
+ pa.time32('s'),
+ pa.time64('us'),
+ pa.date32(),
+ pa.date64(),
+ pa.timestamp('ms'),
+ pa.timestamp('ns'),
+ pa.decimal(12, 2),
+ pa.field('a', 'string', metadata={b'foo': b'bar'})
+ ]
+
+ for val in cases:
+ roundtripped = pickle.loads(pickle.dumps(val))
+ assert val == roundtripped
+
+ fields = []
+ for i, f in enumerate(cases):
+ if isinstance(f, pa.Field):
+ fields.append(f)
+ else:
+ fields.append(pa.field('_f{}'.format(i), f))
+
+ schema = pa.schema(fields, metadata={b'foo': b'bar'})
+ roundtripped = pickle.loads(pickle.dumps(schema))
+ assert schema == roundtripped
diff --git a/python/pyarrow/tests/test_types.py
b/python/pyarrow/tests/test_types.py
index d8eea622c..e6ff5b156 100644
--- a/python/pyarrow/tests/test_types.py
+++ b/python/pyarrow/tests/test_types.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-
import pyarrow as pa
import pyarrow.types as types
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 0bef1aa60..7b95b1563 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -69,6 +69,16 @@ cdef class DataType:
)
return frombytes(self.type.ToString())
+ def __reduce__(self):
+ return self.__class__, (), self.__getstate__()
+
+ def __getstate__(self):
+ return str(self),
+
+ def __setstate__(self, state):
+ cdef DataType reconstituted = type_for_alias(state[0])
+ self.init(reconstituted.sp_type)
+
def __repr__(self):
return '{0.__class__.__name__}({0})'.format(self)
@@ -117,6 +127,15 @@ cdef class ListType(DataType):
DataType.init(self, type)
self.list_type = <const CListType*> type.get()
+ def __getstate__(self):
+ cdef CField* field = self.list_type.value_field().get()
+ name = field.name()
+ return name, self.value_type
+
+ def __setstate__(self, state):
+ cdef DataType reconstituted = list_(field(state[0], state[1]))
+ self.init(reconstituted.sp_type)
+
property value_type:
def __get__(self):
@@ -128,6 +147,25 @@ cdef class StructType(DataType):
cdef void init(self, const shared_ptr[CDataType]& type):
DataType.init(self, type)
+ def __getitem__(self, i):
+ if i < 0 or i >= self.num_children:
+ raise IndexError(i)
+
+ return pyarrow_wrap_field(self.type.child(i))
+
+ property num_children:
+
+ def __get__(self):
+ return self.type.num_children()
+
+ def __getstate__(self):
+ cdef CStructType* type = <CStructType*> self.sp_type.get()
+ return [self[i] for i in range(self.num_children)]
+
+ def __setstate__(self, state):
+ cdef DataType reconstituted = struct(state)
+ self.init(reconstituted.sp_type)
+
cdef class UnionType(DataType):
@@ -196,6 +234,13 @@ cdef class FixedSizeBinaryType(DataType):
self.fixed_size_binary_type = (
<const CFixedSizeBinaryType*> type.get())
+ def __getstate__(self):
+ return self.byte_width
+
+ def __setstate__(self, state):
+ cdef DataType reconstituted = binary(state)
+ self.init(reconstituted.sp_type)
+
property byte_width:
def __get__(self):
@@ -208,6 +253,13 @@ cdef class DecimalType(FixedSizeBinaryType):
DataType.init(self, type)
self.decimal_type = <const CDecimalType*> type.get()
+ def __getstate__(self):
+ return (self.precision, self.scale)
+
+ def __setstate__(self, state):
+ cdef DataType reconstituted = decimal(*state)
+ self.init(reconstituted.sp_type)
+
property precision:
def __get__(self):
@@ -242,6 +294,24 @@ cdef class Field:
"""
return self.field.Equals(deref(other.field))
+ def __richcmp__(Field self, Field other, int op):
+ if op == cp.Py_EQ:
+ return self.equals(other)
+ elif op == cp.Py_NE:
+ return not self.equals(other)
+ else:
+ raise TypeError('Invalid comparison')
+
+ def __reduce__(self):
+ return Field, (), self.__getstate__()
+
+ def __getstate__(self):
+ return (self.name, self.type, self.metadata)
+
+ def __setstate__(self, state):
+ cdef Field reconstituted = field(state[0], state[1], metadata=state[2])
+ self.init(reconstituted.sp_field)
+
def __str__(self):
self._check_null()
return 'pyarrow.Field<{0}>'.format(frombytes(self.field.ToString()))
@@ -354,6 +424,16 @@ cdef class Schema:
self.schema = schema.get()
self.sp_schema = schema
+ def __reduce__(self):
+ return Schema, (), self.__getstate__()
+
+ def __getstate__(self):
+ return ([self[i] for i in range(len(self))], self.metadata)
+
+ def __setstate__(self, state):
+ cdef Schema reconstituted = schema(state[0], metadata=state[1])
+ self.init_schema(reconstituted.sp_schema)
+
property names:
def __get__(self):
@@ -372,6 +452,14 @@ cdef class Schema:
self.schema.metadata())
return box_metadata(metadata.get())
+ def __richcmp__(self, other, int op):
+ if op == cp.Py_EQ:
+ return self.equals(other)
+ elif op == cp.Py_NE:
+ return not self.equals(other)
+ else:
+ raise TypeError('Invalid comparison')
+
def equals(self, other):
"""
Test if this schema is equal to the other
@@ -518,7 +606,7 @@ cdef int convert_metadata(dict metadata,
return 0
-def field(name, DataType type, bint nullable=True, dict metadata=None):
+def field(name, type, bint nullable=True, dict metadata=None):
"""
Create a pyarrow.Field instance
@@ -537,17 +625,29 @@ def field(name, DataType type, bint nullable=True, dict
metadata=None):
cdef:
shared_ptr[CKeyValueMetadata] c_meta
Field result = Field()
+ DataType _type
if metadata is not None:
convert_metadata(metadata, &c_meta)
- result.sp_field.reset(new CField(tobytes(name), type.sp_type,
+ _type = _as_type(type)
+
+ result.sp_field.reset(new CField(tobytes(name), _type.sp_type,
nullable == 1, c_meta))
result.field = result.sp_field.get()
- result.type = type
+ result.type = _type
return result
+cdef _as_type(type):
+ if isinstance(type, DataType):
+ return type
+ if not isinstance(type, six.string_types):
+ raise TypeError(type)
+ return type_for_alias(type)
+
+
+
cdef set PRIMITIVE_TYPES = set([
_Type_NA, _Type_BOOL,
_Type_UINT8, _Type_INT8,
@@ -970,6 +1070,8 @@ cdef dict _type_aliases = {
'binary': binary,
'date32': date32,
'date64': date64,
+ 'date32[day]': date32,
+ 'date64[ms]': date64,
'time32[s]': time32('s'),
'time32[ms]': time32('ms'),
'time64[us]': time64('us'),
@@ -1000,19 +1102,23 @@ def type_for_alias(name):
return alias()
-def schema(fields):
+def schema(fields, dict metadata=None):
"""
Construct pyarrow.Schema from collection of fields
Parameters
----------
field : list or iterable
+ metadata : dict, default None
+ Keys and values must be coercible to bytes
Returns
-------
schema : pyarrow.Schema
"""
cdef:
+ shared_ptr[CKeyValueMetadata] c_meta
+ shared_ptr[CSchema] c_schema
Schema result
Field field
vector[shared_ptr[CField]] c_fields
@@ -1020,8 +1126,12 @@ def schema(fields):
for i, field in enumerate(fields):
c_fields.push_back(field.sp_field)
+ if metadata is not None:
+ convert_metadata(metadata, &c_meta)
+
+ c_schema.reset(new CSchema(c_fields, c_meta))
result = Schema()
- result.init(c_fields)
+ result.init_schema(c_schema)
return result
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
> [Python] pa.DataType cannot be pickled
> --------------------------------------
>
> Key: ARROW-1654
> URL: https://issues.apache.org/jira/browse/ARROW-1654
> Project: Apache Arrow
> Issue Type: Improvement
> Reporter: Li Jin
> Assignee: Wes McKinney
> Labels: pull-request-available
> Fix For: 0.8.0
>
>
> In [26]: t
> Out[26]: DataType(int64)
> In [25]: pickle.dumps(t)
> ---------------------------------------------------------------------------
> TypeError Traceback (most recent call last)
> <ipython-input-25-f90063f6658b> in <module>()
> ----> 1 pickle.dumps(t)
> /home/icexelloss/miniconda3/envs/spark-dev/lib/python3.5/site-packages/pyarrow/lib.cpython-35m-x86_64-linux-gnu.so
> in pyarrow.lib.DataType.__reduce_cython__()
> TypeError: no default __reduce__ due to non-trivial __cinit__
> This was discovered when trying to send a pa.DataType along with a UDF in
> PySpark. The workaround is to send a PySpark DataType and convert it to a
> pa.DataType. It would be nice to be able to pickle pa.DataType.
--
This message was sent by Atlassian JIRA
(v6.4.14#64029)