This is an automated email from the ASF dual-hosted git repository.

zhengda pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 90b66b8  [MXNET-1406] [BUG] Fix DLManagedTensor deleter (#15016)
90b66b8 is described below

commit 90b66b8f654fd0446e57a92eab37d32c73c7e53e
Author: Junru Shao <[email protected]>
AuthorDate: Tue May 21 18:14:12 2019 -0700

    [MXNET-1406] [BUG] Fix DLManagedTensor deleter (#15016)
    
    * Fix
    
    * Fix
    
    * Retrigger
---
 include/mxnet/c_api.h           |  2 ++
 include/mxnet/ndarray.h         |  2 +-
 python/mxnet/ndarray/ndarray.py |  6 ++----
 src/c_api/c_api.cc              |  4 +++-
 src/ndarray/ndarray.cc          | 17 +++++++++++------
 5 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 511bff2..335154c 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -818,10 +818,12 @@ MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle,
 * The memory is retained until the NDArray went out of scope.
 *
 * \param dlpack the pointer of the input DLManagedTensor
+* \param transient_handle whether the handle will be destructed before calling the deleter
 * \param out_handle pointer holder to get pointer of NDArray
 * \return 0 when success, -1 when failure happens
 */
 MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
+                                  const bool transient_handle,
                                   NDArrayHandle *out_handle);
 
 /*!
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 340c380..e694573 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -587,7 +587,7 @@ class NDArray {
    *
    * \return The created NDArray view.
    */
-  static NDArray FromDLPack(const DLManagedTensor* tensor);
+  static NDArray FromDLPack(const DLManagedTensor* tensor, bool transient_handle);
 
   /*!
   * \brief Update ndarray chunk storage handles using existing ndarray storage handles
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 2325890..4b717e2 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -4156,7 +4156,7 @@ def from_dlpack(dlpack):
     assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError(
         'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.')
     dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor))
-    check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, ctypes.byref(handle)))
+    check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, False, ctypes.byref(handle)))
     # Rename PyCapsule (DLPack)
     ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor)
     # delete the deleter of the old dlpack
@@ -4268,8 +4268,6 @@ def from_numpy(ndarray, zero_copy=True):
         raise ValueError("Only c-contiguous arrays are supported for zero-copy")
     ndarray.flags['WRITEABLE'] = False
     c_obj = _make_dl_managed_tensor(ndarray)
-    address = ctypes.addressof(c_obj)
-    address = ctypes.cast(address, ctypes.c_void_p)
     handle = NDArrayHandle()
-    check_call(_LIB.MXNDArrayFromDLPack(address, ctypes.byref(handle)))
+    check_call(_LIB.MXNDArrayFromDLPack(ctypes.byref(c_obj), True, ctypes.byref(handle)))
     return NDArray(handle=handle)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index f549ddd..536c535 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -562,10 +562,12 @@ int MXNDArrayToDLPack(NDArrayHandle handle,
 }
 
 int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
+                        const bool transient_handle,
                         NDArrayHandle *out_handle) {
   API_BEGIN();
   *out_handle = new NDArray(NDArray::FromDLPack(
-              static_cast<DLManagedTensor*>(dlpack)));
+              static_cast<DLManagedTensor*>(dlpack),
+              transient_handle));
   API_END();
 }
 
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 0bfca8c..60de62d 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -355,14 +355,19 @@ DLManagedTensor* NDArray::ToDLPack() const {
   return &(dlmanager->tensor);
 }
 
-NDArray NDArray::FromDLPack(const DLManagedTensor* tensor) {
-  DLManagedTensor tensor_copy = *tensor;
-  auto deleter = [tensor_copy](){
-    if (tensor_copy.deleter != nullptr) {
-      tensor_copy.deleter(const_cast<DLManagedTensor*>(&tensor_copy));
+NDArray NDArray::FromDLPack(const DLManagedTensor* tensor, bool transient_handle) {
+  DLManagedTensor *tensor_copy = transient_handle
+                               ? new DLManagedTensor(*tensor)
+                               : const_cast<DLManagedTensor*>(tensor);
+  auto deleter = [tensor_copy, transient_handle](){
+    if (tensor_copy->deleter != nullptr) {
+      tensor_copy->deleter(tensor_copy);
+    }
+    if (transient_handle) {
+      delete tensor_copy;
     }
   };
-  return NDArray(TBlob(tensor_copy.dl_tensor), tensor_copy.dl_tensor.ctx.device_id, deleter);
+  return NDArray(TBlob(tensor_copy->dl_tensor), tensor_copy->dl_tensor.ctx.device_id, deleter);
 }
 
 bool NDArray::fresh_out_grad() const {

Reply via email to