Repository: arrow
Updated Branches:
  refs/heads/master 9d12c7c92 -> 05788d035


ARROW-1701: [Serialization] Support zero copy PyTorch Tensor serialization

This also restructures the code such that it is easier to reset the default 
serialization context back to the initial state after more handlers have been 
registered.

Author: Philipp Moritz <pcmor...@gmail.com>

Closes #1223 from pcmoritz/torch-tensor and squashes the following commits:

db09afce [Philipp Moritz] fix test
ba0856c1 [Philipp Moritz] don't run pytorch test on python 2.7
264d1992 [Philipp Moritz] remove import that is not required
882d9a56 [Philipp Moritz] small restructuring and support all PyTorch tensor 
types
c6dac9e6 [Philipp Moritz] add -q flag
23de67b6 [Philipp Moritz] add -y to torch installation
9814897f [Philipp Moritz] test torch tensor conversion
375bbfa5 [Philipp Moritz] support serializing torch tensors


Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/05788d03
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/05788d03
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/05788d03

Branch: refs/heads/master
Commit: 05788d035f4aa918d80c9db7a1bf74fe38309c60
Parents: 9d12c7c
Author: Philipp Moritz <pcmor...@gmail.com>
Authored: Sat Oct 21 12:26:08 2017 -0400
Committer: Wes McKinney <wes.mckin...@twosigma.com>
Committed: Sat Oct 21 12:26:08 2017 -0400

----------------------------------------------------------------------
 ci/travis_script_python.sh                 |   6 +
 python/pyarrow/__init__.py                 |   3 +-
 python/pyarrow/serialization.py            | 165 ++++++++++++++----------
 python/pyarrow/tests/test_serialization.py |  19 +++
 4 files changed, 124 insertions(+), 69 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/arrow/blob/05788d03/ci/travis_script_python.sh
----------------------------------------------------------------------
diff --git a/ci/travis_script_python.sh b/ci/travis_script_python.sh
index 97bde1a..603201b 100755
--- a/ci/travis_script_python.sh
+++ b/ci/travis_script_python.sh
@@ -46,6 +46,12 @@ conda install -y -q pip \
       sphinx \
       sphinx_bootstrap_theme
 
+if [ "$PYTHON_VERSION" != "2.7" ] || [ $TRAVIS_OS_NAME != "osx" ]; then
+  # Install pytorch for torch tensor conversion tests
+  # PyTorch seems to be broken on Python 2.7 on macOS so we skip it
+  conda install -y -q pytorch torchvision -c soumith
+fi
+
 # Build C++ libraries
 pushd $ARROW_CPP_BUILD_DIR
 

http://git-wip-us.apache.org/repos/asf/arrow/blob/05788d03/python/pyarrow/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index 0a1575f..ffc833a 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -120,7 +120,8 @@ from pyarrow.ipc import (Message, MessageReader,
 
 localfs = LocalFileSystem.get_instance()
 
-from pyarrow.serialization import _default_serialization_context
+from pyarrow.serialization import (_default_serialization_context,
+                                   register_default_serialization_handlers)
 
 import pyarrow.types as types
 

http://git-wip-us.apache.org/repos/asf/arrow/blob/05788d03/python/pyarrow/serialization.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py
index d08ae89..248b51c 100644
--- a/python/pyarrow/serialization.py
+++ b/python/pyarrow/serialization.py
@@ -20,107 +20,136 @@ import sys
 
 import numpy as np
 
+import pyarrow as pa
 from pyarrow import serialize_pandas, deserialize_pandas
 from pyarrow.lib import _default_serialization_context
 
-# ----------------------------------------------------------------------
-# Set up serialization for primitive datatypes
+def register_default_serialization_handlers(serialization_context):
 
-# TODO(pcm): This is currently a workaround until arrow supports
-# arbitrary precision integers. This is only called on long integers,
-# see the associated case in the append method in python_to_arrow.cc
-_default_serialization_context.register_type(
-    int, "int",
-    custom_serializer=lambda obj: str(obj),
-    custom_deserializer=lambda data: int(data))
+    # ----------------------------------------------------------------------
+    # Set up serialization for primitive datatypes
 
-if (sys.version_info < (3, 0)):
-    _default_serialization_context.register_type(
-        long, "long",  # noqa: F821
+    # TODO(pcm): This is currently a workaround until arrow supports
+    # arbitrary precision integers. This is only called on long integers,
+    # see the associated case in the append method in python_to_arrow.cc
+    serialization_context.register_type(
+        int, "int",
         custom_serializer=lambda obj: str(obj),
-        custom_deserializer=lambda data: long(data))  # noqa: F821
+        custom_deserializer=lambda data: int(data))
 
+    if (sys.version_info < (3, 0)):
+        serialization_context.register_type(
+            long, "long",  # noqa: F821
+            custom_serializer=lambda obj: str(obj),
+            custom_deserializer=lambda data: long(data))  # noqa: F821
 
-def _serialize_ordered_dict(obj):
-    return list(obj.keys()), list(obj.values())
 
+    def _serialize_ordered_dict(obj):
+        return list(obj.keys()), list(obj.values())
 
-def _deserialize_ordered_dict(data):
-    return OrderedDict(zip(data[0], data[1]))
 
+    def _deserialize_ordered_dict(data):
+        return OrderedDict(zip(data[0], data[1]))
 
-_default_serialization_context.register_type(
-    OrderedDict, "OrderedDict",
-    custom_serializer=_serialize_ordered_dict,
-    custom_deserializer=_deserialize_ordered_dict)
 
+    serialization_context.register_type(
+        OrderedDict, "OrderedDict",
+        custom_serializer=_serialize_ordered_dict,
+        custom_deserializer=_deserialize_ordered_dict)
 
-def _serialize_default_dict(obj):
-    return list(obj.keys()), list(obj.values()), obj.default_factory
 
+    def _serialize_default_dict(obj):
+        return list(obj.keys()), list(obj.values()), obj.default_factory
 
-def _deserialize_default_dict(data):
-    return defaultdict(data[2], zip(data[0], data[1]))
 
+    def _deserialize_default_dict(data):
+        return defaultdict(data[2], zip(data[0], data[1]))
 
-_default_serialization_context.register_type(
-     defaultdict, "defaultdict",
-     custom_serializer=_serialize_default_dict,
-     custom_deserializer=_deserialize_default_dict)
 
+    serialization_context.register_type(
+        defaultdict, "defaultdict",
+        custom_serializer=_serialize_default_dict,
+        custom_deserializer=_deserialize_default_dict)
 
-_default_serialization_context.register_type(
-     type(lambda: 0), "function",
-     pickle=True)
 
-# ----------------------------------------------------------------------
-# Set up serialization for numpy with dtype object (primitive types are
-# handled efficiently with Arrow's Tensor facilities, see python_to_arrow.cc)
+    serialization_context.register_type(
+        type(lambda: 0), "function",
+        pickle=True)
 
+    # ----------------------------------------------------------------------
+    # Set up serialization for numpy with dtype object (primitive types are
+    # handled efficiently with Arrow's Tensor facilities, see python_to_arrow.cc)
 
-def _serialize_numpy_array(obj):
-    return obj.tolist(), obj.dtype.str
 
+    def _serialize_numpy_array(obj):
+        return obj.tolist(), obj.dtype.str
 
-def _deserialize_numpy_array(data):
-    return np.array(data[0], dtype=np.dtype(data[1]))
 
+    def _deserialize_numpy_array(data):
+        return np.array(data[0], dtype=np.dtype(data[1]))
 
-_default_serialization_context.register_type(
-    np.ndarray, 'np.array',
-    custom_serializer=_serialize_numpy_array,
-    custom_deserializer=_deserialize_numpy_array)
 
+    serialization_context.register_type(
+        np.ndarray, 'np.array',
+        custom_serializer=_serialize_numpy_array,
+        custom_deserializer=_deserialize_numpy_array)
 
-# ----------------------------------------------------------------------
-# Set up serialization for pandas Series and DataFrame
 
-try:
-    import pandas as pd
+    # ----------------------------------------------------------------------
+    # Set up serialization for pandas Series and DataFrame
 
-    def _serialize_pandas_series(obj):
-        # TODO: serializing Series without extra copy
-        return serialize_pandas(pd.DataFrame({obj.name: obj})).to_pybytes()
+    try:
+        import pandas as pd
 
-    def _deserialize_pandas_series(data):
-        deserialized = deserialize_pandas(data)
-        return deserialized[deserialized.columns[0]]
+        def _serialize_pandas_series(obj):
+            # TODO: serializing Series without extra copy
+            return serialize_pandas(pd.DataFrame({obj.name: obj})).to_pybytes()
 
-    def _serialize_pandas_dataframe(obj):
-        return serialize_pandas(obj).to_pybytes()
+        def _deserialize_pandas_series(data):
+            deserialized = deserialize_pandas(data)
+            return deserialized[deserialized.columns[0]]
 
-    def _deserialize_pandas_dataframe(data):
-        return deserialize_pandas(data)
+        def _serialize_pandas_dataframe(obj):
+            return serialize_pandas(obj).to_pybytes()
 
-    _default_serialization_context.register_type(
-        pd.Series, 'pd.Series',
-        custom_serializer=_serialize_pandas_series,
-        custom_deserializer=_deserialize_pandas_series)
+        def _deserialize_pandas_dataframe(data):
+            return deserialize_pandas(data)
 
-    _default_serialization_context.register_type(
-        pd.DataFrame, 'pd.DataFrame',
-        custom_serializer=_serialize_pandas_dataframe,
-        custom_deserializer=_deserialize_pandas_dataframe)
-except ImportError:
-    # no pandas
-    pass
+        serialization_context.register_type(
+            pd.Series, 'pd.Series',
+            custom_serializer=_serialize_pandas_series,
+            custom_deserializer=_deserialize_pandas_series)
+
+        serialization_context.register_type(
+            pd.DataFrame, 'pd.DataFrame',
+            custom_serializer=_serialize_pandas_dataframe,
+            custom_deserializer=_deserialize_pandas_dataframe)
+    except ImportError:
+        # no pandas
+        pass
+
+    # ----------------------------------------------------------------------
+    # Set up serialization for pytorch tensors
+
+    try:
+        import torch
+
+        def _serialize_torch_tensor(obj):
+            return obj.numpy()
+
+        def _deserialize_torch_tensor(data):
+            return torch.from_numpy(data)
+
+        for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
+                  torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
+                  torch.IntTensor, torch.LongTensor]:
+            serialization_context.register_type(
+                t, "torch." + t.__name__,
+                custom_serializer=_serialize_torch_tensor,
+                custom_deserializer=_deserialize_torch_tensor)
+    except ImportError:
+        # no torch
+        pass
+
+
+register_default_serialization_handlers(_default_serialization_context)

http://git-wip-us.apache.org/repos/asf/arrow/blob/05788d03/python/pyarrow/tests/test_serialization.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index fea7cea..3932948 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -29,6 +29,13 @@ import numpy as np
 
 
 def assert_equal(obj1, obj2):
+    try:
+        import torch
+        if torch.is_tensor(obj1) and torch.is_tensor(obj2):
+            assert torch.equal(obj1, obj2)
+            return
+    except ImportError:
+        pass
     module_numpy = (type(obj1).__module__ == np.__name__ or
                     type(obj2).__module__ == np.__name__)
     if module_numpy:
@@ -57,6 +64,8 @@ def assert_equal(obj1, obj2):
                 return
         except:
             pass
+        if obj1.__dict__ == {}:
+            print("WARNING: Empty dict in ", obj1)
         for key in obj1.__dict__.keys():
             if key not in special_keys:
                 assert_equal(obj1.__dict__[key], obj2.__dict__[key])
@@ -285,6 +294,16 @@ def test_datetime_serialization(large_memory_map):
         for d in data:
             serialization_roundtrip(d, mmap)
 
+def test_torch_serialization(large_memory_map):
+    pytest.importorskip("torch")
+    import torch
+    with pa.memory_map(large_memory_map, mode="r+") as mmap:
+        # These are the only types that are supported for the
+        # PyTorch to NumPy conversion
+        for t in ["float32", "float64",
+                  "uint8", "int16", "int32", "int64"]:
+            obj = torch.from_numpy(np.random.randn(1000).astype(t))
+            serialization_roundtrip(obj, mmap)
 
 def test_numpy_immutable(large_memory_map):
     with pa.memory_map(large_memory_map, mode="r+") as mmap:

Reply via email to