This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 57e4dd8  ARROW-2265: [Python] Use CheckExact when serializing lists and numpy arrays.
57e4dd8 is described below

commit 57e4dd8bfb756bb3db7676c37fc05d16260bd82d
Author: Robert Nishihara <[email protected]>
AuthorDate: Mon Mar 5 21:11:32 2018 -0500

    ARROW-2265: [Python] Use CheckExact when serializing lists and numpy arrays.
    
    cc @mitar
    
    Author: Robert Nishihara <[email protected]>
    
    Closes #1704 from robertnishihara/subclassingnparray and squashes the following commits:
    
    459e1d95 <Robert Nishihara> Remove PySet_CheckExact because it does not exist.
    8b2c7dcd <Robert Nishihara> Use CheckExact when serializing lists, sets, numpy arrays.
---
 cpp/src/arrow/python/python_to_arrow.cc    |  4 ++--
 python/pyarrow/tests/test_serialization.py | 27 +++++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/cpp/src/arrow/python/python_to_arrow.cc b/cpp/src/arrow/python/python_to_arrow.cc
index 6d4f646..d781d9f 100644
--- a/cpp/src/arrow/python/python_to_arrow.cc
+++ b/cpp/src/arrow/python/python_to_arrow.cc
@@ -501,7 +501,7 @@ Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder,
       return Status::Invalid("Cannot writes bytes over 2GB");
     }
     RETURN_NOT_OK(builder->AppendString(data, static_cast<int32_t>(size)));
-  } else if (PyList_Check(elem)) {
+  } else if (PyList_CheckExact(elem)) {
     RETURN_NOT_OK(builder->AppendList(PyList_Size(elem)));
     sublists->push_back(elem);
   } else if (PyDict_CheckExact(elem)) {
@@ -515,7 +515,7 @@ Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder,
     subsets->push_back(elem);
   } else if (PyArray_IsScalar(elem, Generic)) {
     RETURN_NOT_OK(AppendScalar(elem, builder));
-  } else if (PyArray_Check(elem)) {
+  } else if (PyArray_CheckExact(elem)) {
     RETURN_NOT_OK(SerializeArray(context, reinterpret_cast<PyArrayObject*>(elem), builder,
                                  subdicts, blobs_out));
   } else if (elem == Py_None) {
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index 72315d2..c174084 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -410,6 +410,33 @@ def test_serialization_callback_numpy():
     pa.serialize(DummyClass(), context=context)
 
 
+def test_numpy_subclass_serialization():
+    # Check that we can properly serialize subclasses of np.ndarray.
+    class CustomNDArray(np.ndarray):
+        def __new__(cls, input_array):
+            array = np.asarray(input_array).view(cls)
+            return array
+
+    def serializer(obj):
+        return {'numpy': obj.view(np.ndarray)}
+
+    def deserializer(data):
+        array = data['numpy'].view(CustomNDArray)
+        return array
+
+    context = pa.default_serialization_context()
+
+    context.register_type(CustomNDArray, 'CustomNDArray',
+                          custom_serializer=serializer,
+                          custom_deserializer=deserializer)
+
+    x = CustomNDArray(np.zeros(3))
+    serialized = pa.serialize(x, context=context).to_buffer()
+    new_x = pa.deserialize(serialized, context=context)
+    assert type(new_x) == CustomNDArray
+    assert np.alltrue(new_x.view(np.ndarray) == np.zeros(3))
+
+
 def test_buffer_serialization():
 
     class BufferClass(object):

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to