This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9f14a68  [MXNET-779]Add DLPack Transformation API (#12047)
9f14a68 is described below

commit 9f14a68045a07d95cb8c86c3c614c84261507755
Author: JackieWu <[email protected]>
AuthorDate: Sat Sep 22 11:33:33 2018 +0800

    [MXNET-779]Add DLPack Transformation API (#12047)
---
 include/mxnet/c_api.h                 |  36 +++++++
 include/mxnet/ndarray.h               |  40 ++++++++
 include/mxnet/tensor_blob.h           |  65 +++++++++++-
 python/mxnet/base.py                  |   4 +
 python/mxnet/ndarray/ndarray.py       | 179 +++++++++++++++++++++++++++++++++-
 src/c_api/c_api.cc                    |  25 +++++
 src/ndarray/ndarray.cc                |  28 ++++++
 tests/python/unittest/test_ndarray.py |  31 ++++++
 8 files changed, 403 insertions(+), 5 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 0043996..a01cc6a 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -93,6 +93,8 @@ typedef void *CudaModuleHandle;
 typedef void *CudaKernelHandle;
 /*! \brief handle to a Profile object (domain, duration, counter, etc.) */
 typedef void *ProfileHandle;
+/*! \brief handle to DLManagedTensor*/
+typedef void *DLManagedTensorHandle;
 
 typedef void (*ExecutorMonitorCallback)(const char*,
                                         NDArrayHandle,
@@ -747,6 +749,40 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
 MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
                                void **out_pdata);
 /*!
+* \brief Create a reference view of NDArray that
+*  represents as DLManagedTensor
+*  Notice: MXNet uses asynchronous execution. Please call MXNDArrayWaitToRead or
+*          MXNDArrayWaitToWrite before calling MXNDArrayToDLPack.
+* \param handle the handle to the ndarray
+* \param out_dlpack pointer holder to get pointer of DLManagedTensor
+* \return 0 when success, -1 when failure happens
+*/
+MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle,
+                                       DLManagedTensorHandle *out_dlpack);
+
+/*!
+* \brief Create a NDArray backed by a dlpack tensor.
+*
+* This allows us to create a NDArray using the memory
+* allocated by an external deep learning framework
+* that is DLPack compatible.
+*
+* The memory is retained until the NDArray went out of scope.
+*
+* \param dlpack the pointer of the input DLManagedTensor
+* \param out_handle pointer holder to get pointer of NDArray
+* \return 0 when success, -1 when failure happens
+*/
+MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
+                                  NDArrayHandle *out_handle);
+/*!
+ * \brief Delete a dlpack tensor
+ * \param dlpack the pointer of the input DLManagedTensor
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack);
+
+/*!
  * \brief get the type of the data in NDArray
  * \param handle the handle to the narray
  * \param out_dtype pointer holder to get type of data
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 6141a4d..afae5dc 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -116,6 +116,26 @@ class NDArray {
         dtype_(data.type_flag_), storage_type_(kDefaultStorage),
         entry_({nullptr, 0, 0}) {
   }
+
+  /*!
+   * \brief constructing a static NDArray that shares data with TBlob which is with deleter
+   *  Use with caution: allocate ONLY ONE NDArray for each TBlob,
+   *  make sure the memory region is available through out the life of NDArray
+   * \param data the memory content of static data
+   * \param dev_id the device id this tensor sits at
+   * \param deleter the function pointer of custom deleter
+   */
+  NDArray(const TBlob &data, int dev_id, const std::function<void()>& deleter)
+      : ptr_(new Chunk(data, dev_id),
+        [deleter](Chunk *p) {
+          deleter();    // call custom deleter
+          delete p;     // delete Chunk object
+        }),
+        shape_(data.shape_),
+        dtype_(data.type_flag_), storage_type_(kDefaultStorage),
+        entry_({nullptr, 0, 0}) {
+  }
+
   /*! \brief create ndarray from shared memory */
   NDArray(int shared_pid, int shared_id, const TShape& shape, int dtype)
       : ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)), shape_(shape),
@@ -524,6 +544,26 @@ class NDArray {
   }
 
   /*!
+   * \brief Create a reference view of NDArray that
+   *  represents as DLManagedTensor.
+   * \return A DLManagedTensor
+   */
+  DLManagedTensor* ToDLPack() const;
+
+  /*!
+   * \brief Create a NDArray backed by a dlpack tensor.
+   *
+   * This allows us to create a NDArray using the memory
+   * allocated by an external deep learning framework
+   * that is DLPack compatible.
+   *
+   * The memory is retained until the NDArray went out of scope.
+   *
+   * \return The created NDArray view.
+   */
+  static NDArray FromDLPack(const DLManagedTensor* tensor);
+
+  /*!
    * \brief Update ndarray chunk storage handles using existing ndarray storage handles
    * Also update the aux_handle, aux_shapes and aux_types.
    * This is specifically used for custom op to update the inputs and outputs from
diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index 6f604a5..496e8c7 100755
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -105,6 +105,39 @@ class TBlob {
     SetDLTensor(dev_mask, dev_id);
   }
   /*!
+   * \brief constructor that construct TBlob from DLTensor
+   * \param DLTensor Object
+   */
+  explicit TBlob(const DLTensor &dltensor)
+      : dptr_(dltensor.data),
+        shape_(TShape(dltensor.shape, dltensor.shape + dltensor.ndim)),
+        type_flag_(DLDataTypeTransform(dltensor.dtype)),
+        dltensor_(dltensor) {
+    // compactness check for DLTensor
+    if (dltensor.strides != nullptr) {
+      // check strides
+      const int &ndim = dltensor.ndim;
+      const int64_t *shape = dltensor.shape;
+      const int64_t *strides = dltensor.strides;
+      if (ndim >= 1) {
+        bool err = false;
+        if (strides[ndim - 1] != 1) {
+          err = true;
+        } else {
+          for (int i = ndim - 2; i >= 0; --i) {
+            if (strides[i] != shape[i + 1] * strides[i + 1]) {
+              err = true;
+              break;
+            }
+          }
+        }
+        if (err) {
+          LOG(FATAL) << "Unsupported DLPack because MXNet only support compact tensor now";
+        }
+      }
+    }
+  }
+  /*!
    * \brief constructor from tensor
    * \param src source tensor
    * \tparam Device which device the tensor is on
@@ -336,6 +369,36 @@ class TBlob {
       }
     }
   }
+  static int DLDataTypeTransform(DLDataType dldata_type) {
+    if (dldata_type.lanes != 1) {
+      LOG(FATAL) << "Unsupported DLDataType whose lanes != 1";
+    }
+    switch (dldata_type.code) {
+      case kDLFloat:
+        switch (dldata_type.bits) {
+          case 16: return mshadow::kFloat16;
+          case 32: return mshadow::kFloat32;
+          case 64: return mshadow::kFloat64;
+        }
+        break;
+      case kDLUInt:
+        switch (dldata_type.bits) {
+          case 8: return mshadow::kUint8;
+        }
+        break;
+      case kDLInt:
+        switch (dldata_type.bits) {
+          case 8: return mshadow::kInt8;
+          case 32: return mshadow::kInt32;
+          case 64: return mshadow::kInt64;
+        }
+        break;
+    }
+    LOG(FATAL) << "Unknown DLDataType{" << dldata_type.code
+               << ", " << dldata_type.bits
+               << ", " << dldata_type.lanes << "}";
+    return mshadow::kFloat32;
+  }
 
   inline void SetDLTensor(int dev_mask, int dev_id) {
     dltensor_.data = dptr_;
@@ -343,7 +406,7 @@ class TBlob {
     dltensor_.ndim = shape_.ndim();
     dltensor_.dtype = DTypeTransform(type_flag_);
     dltensor_.shape = shape_.data();
-    dltensor_.strides = NULL;
+    dltensor_.strides = nullptr;
     dltensor_.byte_offset = 0;
   }
 
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 89e1c9e..84b9e58 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -232,6 +232,7 @@ RtcHandle = ctypes.c_void_p
 CudaModuleHandle = ctypes.c_void_p
 CudaKernelHandle = ctypes.c_void_p
 ProfileHandle = ctypes.c_void_p
+DLPackHandle = ctypes.c_void_p
 
 
 #----------------------------
@@ -726,3 +727,6 @@ def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func)
     module_op_file.close()
     write_all_str(module_internal_file, module_internal_all)
     module_internal_file.close()
+
+ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
+ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 93f2bc4..de2ad69 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -34,8 +34,8 @@ import operator
 from functools import reduce # pylint: disable=redefined-builtin
 import numpy as np
 from ..base import _LIB, numeric_types, integer_types
-from ..base import c_array, c_array_buf, c_handle_array, mx_real_t
-from ..base import mx_uint, NDArrayHandle, check_call
+from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
+from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle
 from ..base import ctypes2buffer
 from ..context import Context, current_context
 from . import _internal
@@ -46,7 +46,8 @@ __all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRA
            "ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
            "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
            "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
-           "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram"]
+           "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
+           "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"]
 
 _STORAGE_TYPE_UNDEFINED = -1
 _STORAGE_TYPE_DEFAULT = 0
@@ -178,7 +179,6 @@ fixed-size items.
     # See C++ side of definition(kTVMNDArrayTypeCode) at include/mxmet/tensor_blob.h
     _tvm_tcode = 19
     # pylint: disable= no-member, undefined-variable
-
     @property
     def _tvm_handle(self):
         return self.handle.value
@@ -2213,6 +2213,52 @@ fixed-size items.
         """
         return op.cast_storage(self, stype=stype)
 
+    def to_dlpack_for_read(self):
+        """Returns a reference view of NDArray that represents as DLManagedTensor until
+        all previous write operations on the current array are finished.
+
+        Returns
+        -------
+        PyCapsule (the pointer of DLManagedTensor)
+            a reference view of NDArray that represents as DLManagedTensor.
+
+        Examples
+        --------
+        >>> x = mx.nd.ones((2,3))
+        >>> y = mx.nd.to_dlpack_for_read(x)
+        >>> type(y)
+        <class 'PyCapsule'>
+        >>> z = mx.nd.from_dlpack(y)
+        >>> z
+        [[1. 1. 1.]
+         [1. 1. 1.]]
+        <NDArray 2x3 @cpu(0)>
+        """
+        return to_dlpack_for_read(self)
+
+    def to_dlpack_for_write(self):
+        """Returns a reference view of NDArray that represents as DLManagedTensor until
+        all previous read/write operations on the current array are finished.
+
+        Returns
+        -------
+        PyCapsule (the pointer of DLManagedTensor)
+            a reference view of NDArray that represents as DLManagedTensor.
+
+        Examples
+        --------
+        >>> x = mx.nd.ones((2,3))
+        >>> w = mx.nd.to_dlpack_for_write(x)
+        >>> type(w)
+        <class 'PyCapsule'>
+        >>> u = mx.nd.from_dlpack(w)
+        >>> u += 1
+        >>> x
+        [[2. 2. 2.]
+         [2. 2. 2.]]
+        <NDArray 2x3 @cpu(0)>
+        """
+        return to_dlpack_for_write(self)
 
 def _get_indexing_dispatch_code(key):
     """Returns a dispatch code for calling basic or advanced indexing functions."""
@@ -3859,3 +3905,128 @@ def histogram(a, bins=10, range=None):
         return _internal._histogram(data=a, bin_cnt=bins, range=range)
     raise ValueError("bins argument should be either an integer or an NDArray")
     # pylint: enable= no-member, protected-access, redefined-builtin
+
+PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+_c_str_dltensor = c_str('dltensor')
+_c_str_used_dltensor = c_str('used_dltensor')
+
+def _dlpack_deleter(pycapsule):
+    pycapsule = ctypes.c_void_p(pycapsule)
+    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
+        ptr = ctypes.c_void_p(
+            ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor))
+        check_call(_LIB.MXNDArrayCallDLPackDeleter(ptr))
+
+_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter)
+
+def to_dlpack_for_read(data):
+    """Returns a reference view of NDArray that represents as DLManagedTensor until
+       all previous write operations on the current array are finished.
+
+    Parameters
+    ----------
+    data: NDArray
+        input data.
+
+    Returns
+    -------
+    PyCapsule (the pointer of DLManagedTensor)
+        a reference view of NDArray that represents as DLManagedTensor.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> y = mx.nd.to_dlpack_for_read(x)
+    >>> type(y)
+    <class 'PyCapsule'>
+    >>> z = mx.nd.from_dlpack(y)
+    >>> z
+    [[1. 1. 1.]
+     [1. 1. 1.]]
+    <NDArray 2x3 @cpu(0)>
+    """
+    data.wait_to_read()
+    dlpack = DLPackHandle()
+    check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
+    return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)
+
+def to_dlpack_for_write(data):
+    """Returns a reference view of NDArray that represents as DLManagedTensor until
+       all previous read/write operations on the current array are finished.
+
+    Parameters
+    ----------
+    data: NDArray
+        input data.
+
+    Returns
+    -------
+    PyCapsule (the pointer of DLManagedTensor)
+        a reference view of NDArray that represents as DLManagedTensor.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> w = mx.nd.to_dlpack_for_write(x)
+    >>> type(w)
+    <class 'PyCapsule'>
+    >>> u = mx.nd.from_dlpack(w)
+    >>> u += 1
+    >>> x
+    [[2. 2. 2.]
+     [2. 2. 2.]]
+    <NDArray 2x3 @cpu(0)>
+    """
+    check_call(_LIB.MXNDArrayWaitToWrite(data.handle))
+    dlpack = DLPackHandle()
+    check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
+    return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)
+
+def from_dlpack(dlpack):
+    """Returns a NDArray backed by a dlpack tensor.
+
+    Parameters
+    ----------
+    dlpack: PyCapsule (the pointer of DLManagedTensor)
+        input data
+
+    Returns
+    -------
+    NDArray
+        a NDArray backed by a dlpack tensor
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> y = mx.nd.to_dlpack_for_read(x)
+    >>> type(y)
+    <class 'PyCapsule'>
+    >>> z = mx.nd.from_dlpack(y)
+    >>> type(z)
+    <class 'mxnet.ndarray.ndarray.NDArray'>
+    >>> z
+    [[ 1.  1.  1.]
+     [ 1.  1.  1.]]
+    <NDArray 2x3 @cpu(0)>
+
+    >>> w = mx.nd.to_dlpack_for_write(x)
+    >>> type(w)
+    <class 'PyCapsule'>
+    >>> u = mx.nd.from_dlpack(w)
+    >>> u += 1
+    >>> x
+    [[2. 2. 2.]
+     [2. 2. 2.]]
+    <NDArray 2x3 @cpu(0)>
+    """
+    handle = NDArrayHandle()
+    dlpack = ctypes.py_object(dlpack)
+    assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError(
+        'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.')
+    dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor))
+    check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, ctypes.byref(handle)))
+    # Rename PyCapsule (DLPack)
+    ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor)
+    # delete the deleter of the old dlpack
+    ctypes.pythonapi.PyCapsule_SetDestructor(dlpack, None)
+    return NDArray(handle=handle)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 1ef3f0f..56e3180 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -500,6 +500,31 @@ int MXNDArrayGetData(NDArrayHandle handle,
   API_END();
 }
 
+int MXNDArrayToDLPack(NDArrayHandle handle,
+                      DLManagedTensorHandle *out_dlpack) {
+  API_BEGIN();
+  NDArray *arr = static_cast<NDArray*>(handle);
+  *out_dlpack = arr->ToDLPack();
+  API_END();
+}
+
+int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
+                        NDArrayHandle *out_handle) {
+  API_BEGIN();
+  *out_handle = new NDArray(NDArray::FromDLPack(
+              static_cast<DLManagedTensor*>(dlpack)));
+  API_END();
+}
+
+int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack) {
+  API_BEGIN();
+  if (dlpack != nullptr) {
+    DLManagedTensor *p_dlpack = static_cast<DLManagedTensor*>(dlpack);
+    p_dlpack->deleter(p_dlpack);
+  }
+  API_END();
+}
+
 int MXNDArrayGetDType(NDArrayHandle handle,
                      int *out_dtype) {
   API_BEGIN();
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index b443d5d..47e0c5b 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -312,6 +312,34 @@ NDArray NDArray::data_ndarray() const {
   return ret;
 }
 
+struct NDArrayDLManager {
+    NDArray handle;  // ref NDArray
+    DLManagedTensor tensor;
+};
+
+DLManagedTensor* NDArray::ToDLPack() const {
+  NDArrayDLManager* dlmanager(new NDArrayDLManager);
+  dlmanager->handle = *this;
+  if (!is_none()) {
+    dlmanager->tensor.dl_tensor = data().dltensor();
+  }
+  dlmanager->tensor.manager_ctx = dlmanager;
+  dlmanager->tensor.deleter = [](DLManagedTensor* dlmanager){
+    delete static_cast<NDArrayDLManager*>(dlmanager->manager_ctx);
+  };
+  return &(dlmanager->tensor);
+}
+
+NDArray NDArray::FromDLPack(const DLManagedTensor* tensor) {
+  const DLTensor &dl_tensor = tensor->dl_tensor;
+  auto deleter = [tensor](){
+    if (tensor->deleter != nullptr) {
+      tensor->deleter(const_cast<DLManagedTensor*>(tensor));
+    }
+  };
+  return NDArray(TBlob(dl_tensor), dl_tensor.ctx.device_id, deleter);
+}
+
 bool NDArray::fresh_out_grad() const {
   if (Imperative::AGInfo::IsNone(*this)) return false;
   Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 7a5c7ca..a1c178f 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1439,6 +1439,37 @@ def test_ndarray_cpu_shared_ctx():
     res = mx.nd.zeros((1, 2, 3), ctx=ctx)
     assert(res.context == ctx)
 
+@with_seed()
+def test_dlpack():
+    for dtype in [np.float32, np.int32]:
+        for shape in [(3, 4, 5, 6), (2, 10), (15,)]:
+            a = mx.nd.random.uniform(shape = shape)
+            a_np = a.asnumpy()
+
+            pack = a.to_dlpack_for_read()
+            b = mx.nd.from_dlpack(pack)
+
+            a_copy = a.copy()
+            pack2 = a_copy.to_dlpack_for_write()
+            c = mx.nd.from_dlpack(pack2)
+
+            pack3 = mx.nd.to_dlpack_for_read(a)
+            d = mx.nd.from_dlpack(pack3)
+
+            a_copy = a.copy()
+            pack4 = mx.nd.to_dlpack_for_write(a_copy)
+            e = mx.nd.from_dlpack(pack4)
+
+            del a, pack, pack2, pack3, pack4
+
+            b_np = b.asnumpy()
+            c_np = c.asnumpy()
+            d_np = d.asnumpy()
+            e_np = e.asnumpy()
+            mx.test_utils.assert_almost_equal(a_np, b_np)
+            mx.test_utils.assert_almost_equal(a_np, c_np)
+            mx.test_utils.assert_almost_equal(a_np, d_np)
+            mx.test_utils.assert_almost_equal(a_np, e_np)
 
 if __name__ == '__main__':
     import nose

Reply via email to