Repository: arrow
Updated Branches:
  refs/heads/master 9d12c7c92 -> 05788d035
ARROW-1701: [Serialization] Support zero copy PyTorch Tensor serialization

This also restructures the code such that it is easier to reset the default serialization context back to the initial state after more handlers have been registered.

Author: Philipp Moritz <pcmor...@gmail.com>

Closes #1223 from pcmoritz/torch-tensor and squashes the following commits:

db09afce [Philipp Moritz] fix test
ba0856c1 [Philipp Moritz] don't run pytorch test on python 2.7
264d1992 [Philipp Moritz] remove import that is not required
882d9a56 [Philipp Moritz] small restructuring and support all PyTorch tensor types
c6dac9e6 [Philipp Moritz] add -q flag
23de67b6 [Philipp Moritz] add -y to torch installation
9814897f [Philipp Moritz] test torch tensor conversion
375bbfa5 [Philipp Moritz] support serializing torch tensors

Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/05788d03
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/05788d03
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/05788d03

Branch: refs/heads/master
Commit: 05788d035f4aa918d80c9db7a1bf74fe38309c60
Parents: 9d12c7c
Author: Philipp Moritz <pcmor...@gmail.com>
Authored: Sat Oct 21 12:26:08 2017 -0400
Committer: Wes McKinney <wes.mckin...@twosigma.com>
Committed: Sat Oct 21 12:26:08 2017 -0400

----------------------------------------------------------------------
 ci/travis_script_python.sh | 6 +
 python/pyarrow/__init__.py | 3 +-
 python/pyarrow/serialization.py | 165 ++++++++++++++----------
 python/pyarrow/tests/test_serialization.py | 19 +++
 4 files changed, 124 insertions(+), 69 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/arrow/blob/05788d03/ci/travis_script_python.sh
----------------------------------------------------------------------
diff --git a/ci/travis_script_python.sh b/ci/travis_script_python.sh
index 97bde1a..603201b 100755
--- 
a/ci/travis_script_python.sh +++ b/ci/travis_script_python.sh @@ -46,6 +46,12 @@ conda install -y -q pip \ sphinx \ sphinx_bootstrap_theme +if [ "$PYTHON_VERSION" != "2.7" ] || [ $TRAVIS_OS_NAME != "osx" ]; then + # Install pytorch for torch tensor conversion tests + # PyTorch seems to be broken on Python 2.7 on macOS so we skip it + conda install -y -q pytorch torchvision -c soumith +fi + # Build C++ libraries pushd $ARROW_CPP_BUILD_DIR http://git-wip-us.apache.org/repos/asf/arrow/blob/05788d03/python/pyarrow/__init__.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 0a1575f..ffc833a 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -120,7 +120,8 @@ from pyarrow.ipc import (Message, MessageReader, localfs = LocalFileSystem.get_instance() -from pyarrow.serialization import _default_serialization_context +from pyarrow.serialization import (_default_serialization_context, + register_default_serialization_handlers) import pyarrow.types as types http://git-wip-us.apache.org/repos/asf/arrow/blob/05788d03/python/pyarrow/serialization.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py index d08ae89..248b51c 100644 --- a/python/pyarrow/serialization.py +++ b/python/pyarrow/serialization.py @@ -20,107 +20,136 @@ import sys import numpy as np +import pyarrow as pa from pyarrow import serialize_pandas, deserialize_pandas from pyarrow.lib import _default_serialization_context -# ---------------------------------------------------------------------- -# Set up serialization for primitive datatypes +def register_default_serialization_handlers(serialization_context): -# TODO(pcm): This is currently a workaround until arrow supports -# arbitrary precision integers. 
This is only called on long integers, -# see the associated case in the append method in python_to_arrow.cc -_default_serialization_context.register_type( - int, "int", - custom_serializer=lambda obj: str(obj), - custom_deserializer=lambda data: int(data)) + # ---------------------------------------------------------------------- + # Set up serialization for primitive datatypes -if (sys.version_info < (3, 0)): - _default_serialization_context.register_type( - long, "long", # noqa: F821 + # TODO(pcm): This is currently a workaround until arrow supports + # arbitrary precision integers. This is only called on long integers, + # see the associated case in the append method in python_to_arrow.cc + serialization_context.register_type( + int, "int", custom_serializer=lambda obj: str(obj), - custom_deserializer=lambda data: long(data)) # noqa: F821 + custom_deserializer=lambda data: int(data)) + if (sys.version_info < (3, 0)): + serialization_context.register_type( + long, "long", # noqa: F821 + custom_serializer=lambda obj: str(obj), + custom_deserializer=lambda data: long(data)) # noqa: F821 -def _serialize_ordered_dict(obj): - return list(obj.keys()), list(obj.values()) + def _serialize_ordered_dict(obj): + return list(obj.keys()), list(obj.values()) -def _deserialize_ordered_dict(data): - return OrderedDict(zip(data[0], data[1])) + def _deserialize_ordered_dict(data): + return OrderedDict(zip(data[0], data[1])) -_default_serialization_context.register_type( - OrderedDict, "OrderedDict", - custom_serializer=_serialize_ordered_dict, - custom_deserializer=_deserialize_ordered_dict) + serialization_context.register_type( + OrderedDict, "OrderedDict", + custom_serializer=_serialize_ordered_dict, + custom_deserializer=_deserialize_ordered_dict) -def _serialize_default_dict(obj): - return list(obj.keys()), list(obj.values()), obj.default_factory + def _serialize_default_dict(obj): + return list(obj.keys()), list(obj.values()), obj.default_factory -def 
_deserialize_default_dict(data): - return defaultdict(data[2], zip(data[0], data[1])) + def _deserialize_default_dict(data): + return defaultdict(data[2], zip(data[0], data[1])) -_default_serialization_context.register_type( - defaultdict, "defaultdict", - custom_serializer=_serialize_default_dict, - custom_deserializer=_deserialize_default_dict) + serialization_context.register_type( + defaultdict, "defaultdict", + custom_serializer=_serialize_default_dict, + custom_deserializer=_deserialize_default_dict) -_default_serialization_context.register_type( - type(lambda: 0), "function", - pickle=True) -# ---------------------------------------------------------------------- -# Set up serialization for numpy with dtype object (primitive types are -# handled efficiently with Arrow's Tensor facilities, see python_to_arrow.cc) + serialization_context.register_type( + type(lambda: 0), "function", + pickle=True) + # ---------------------------------------------------------------------- + # Set up serialization for numpy with dtype object (primitive types are + # handled efficiently with Arrow's Tensor facilities, see python_to_arrow.cc) -def _serialize_numpy_array(obj): - return obj.tolist(), obj.dtype.str + def _serialize_numpy_array(obj): + return obj.tolist(), obj.dtype.str -def _deserialize_numpy_array(data): - return np.array(data[0], dtype=np.dtype(data[1])) + def _deserialize_numpy_array(data): + return np.array(data[0], dtype=np.dtype(data[1])) -_default_serialization_context.register_type( - np.ndarray, 'np.array', - custom_serializer=_serialize_numpy_array, - custom_deserializer=_deserialize_numpy_array) + serialization_context.register_type( + np.ndarray, 'np.array', + custom_serializer=_serialize_numpy_array, + custom_deserializer=_deserialize_numpy_array) -# ---------------------------------------------------------------------- -# Set up serialization for pandas Series and DataFrame -try: - import pandas as pd + # 
---------------------------------------------------------------------- + # Set up serialization for pandas Series and DataFrame - def _serialize_pandas_series(obj): - # TODO: serializing Series without extra copy - return serialize_pandas(pd.DataFrame({obj.name: obj})).to_pybytes() + try: + import pandas as pd - def _deserialize_pandas_series(data): - deserialized = deserialize_pandas(data) - return deserialized[deserialized.columns[0]] + def _serialize_pandas_series(obj): + # TODO: serializing Series without extra copy + return serialize_pandas(pd.DataFrame({obj.name: obj})).to_pybytes() - def _serialize_pandas_dataframe(obj): - return serialize_pandas(obj).to_pybytes() + def _deserialize_pandas_series(data): + deserialized = deserialize_pandas(data) + return deserialized[deserialized.columns[0]] - def _deserialize_pandas_dataframe(data): - return deserialize_pandas(data) + def _serialize_pandas_dataframe(obj): + return serialize_pandas(obj).to_pybytes() - _default_serialization_context.register_type( - pd.Series, 'pd.Series', - custom_serializer=_serialize_pandas_series, - custom_deserializer=_deserialize_pandas_series) + def _deserialize_pandas_dataframe(data): + return deserialize_pandas(data) - _default_serialization_context.register_type( - pd.DataFrame, 'pd.DataFrame', - custom_serializer=_serialize_pandas_dataframe, - custom_deserializer=_deserialize_pandas_dataframe) -except ImportError: - # no pandas - pass + serialization_context.register_type( + pd.Series, 'pd.Series', + custom_serializer=_serialize_pandas_series, + custom_deserializer=_deserialize_pandas_series) + + serialization_context.register_type( + pd.DataFrame, 'pd.DataFrame', + custom_serializer=_serialize_pandas_dataframe, + custom_deserializer=_deserialize_pandas_dataframe) + except ImportError: + # no pandas + pass + + # ---------------------------------------------------------------------- + # Set up serialization for pytorch tensors + + try: + import torch + + def 
_serialize_torch_tensor(obj): + return obj.numpy() + + def _deserialize_torch_tensor(data): + return torch.from_numpy(data) + + for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor, + torch.ByteTensor, torch.CharTensor, torch.ShortTensor, + torch.IntTensor, torch.LongTensor]: + serialization_context.register_type( + t, "torch." + t.__name__, + custom_serializer=_serialize_torch_tensor, + custom_deserializer=_deserialize_torch_tensor) + except ImportError: + # no torch + pass + + +register_default_serialization_handlers(_default_serialization_context) http://git-wip-us.apache.org/repos/asf/arrow/blob/05788d03/python/pyarrow/tests/test_serialization.py ---------------------------------------------------------------------- diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py index fea7cea..3932948 100644 --- a/python/pyarrow/tests/test_serialization.py +++ b/python/pyarrow/tests/test_serialization.py @@ -29,6 +29,13 @@ import numpy as np def assert_equal(obj1, obj2): + try: + import torch + if torch.is_tensor(obj1) and torch.is_tensor(obj2): + assert torch.equal(obj1, obj2) + return + except ImportError: + pass module_numpy = (type(obj1).__module__ == np.__name__ or type(obj2).__module__ == np.__name__) if module_numpy: @@ -57,6 +64,8 @@ def assert_equal(obj1, obj2): return except: pass + if obj1.__dict__ == {}: + print("WARNING: Empty dict in ", obj1) for key in obj1.__dict__.keys(): if key not in special_keys: assert_equal(obj1.__dict__[key], obj2.__dict__[key]) @@ -285,6 +294,16 @@ def test_datetime_serialization(large_memory_map): for d in data: serialization_roundtrip(d, mmap) +def test_torch_serialization(large_memory_map): + pytest.importorskip("torch") + import torch + with pa.memory_map(large_memory_map, mode="r+") as mmap: + # These are the only types that are supported for the + # PyTorch to NumPy conversion + for t in ["float32", "float64", + "uint8", "int16", "int32", "int64"]: + obj = 
torch.from_numpy(np.random.randn(1000).astype(t)) + serialization_roundtrip(obj, mmap) def test_numpy_immutable(large_memory_map): with pa.memory_map(large_memory_map, mode="r+") as mmap: