This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ab326c  [numpy] add dlpack functions to npx (#18342)
7ab326c is described below

commit 7ab326c588f6b6a85fa2b96ac0476b057972fea0
Author: Sheng Zha <[email protected]>
AuthorDate: Sun May 17 14:44:42 2020 -0700

    [numpy] add dlpack functions to npx (#18342)
    
    * add dlpack functions to npx
    
    * improve tests
    
    * further improve test
    
    * fix comment
---
 python/mxnet/numpy_extension/utils.py       | 125 +++++++++++++++++++++++++++-
 tests/python/unittest/test_numpy_ndarray.py |  26 ++++++
 2 files changed, 149 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/numpy_extension/utils.py b/python/mxnet/numpy_extension/utils.py
index 23d28ea..f625439 100644
--- a/python/mxnet/numpy_extension/utils.py
+++ b/python/mxnet/numpy_extension/utils.py
@@ -21,12 +21,24 @@
 
 import ctypes
 from .. util import is_np_array, is_np_shape
-from .. base import _LIB, check_call, string_types, c_str_array
+from .. base import _LIB, check_call, string_types, c_str_array, DLPackHandle
 from .. base import c_handle_array, c_str, mx_uint, NDArrayHandle, py_str
 from ..numpy import ndarray
 
-__all__ = ['save', 'load']
+__all__ = ['save', 'load', 'to_dlpack_for_read', 'to_dlpack_for_write', 'from_dlpack']
 
+# C signature of a PyCapsule destructor: void (*)(PyObject *capsule).
+PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+# Capsule names: 'dltensor' marks an unconsumed DLPack capsule, and
+# from_dlpack renames a capsule to 'used_dltensor' once it has consumed it.
+_c_str_dltensor = c_str('dltensor')
+_c_str_used_dltensor = c_str('used_dltensor')
+
+# ctypes assumes a foreign function returns a C ``int`` unless told otherwise.
+# The code in this module needs PyCapsule_New to hand back the capsule object
+# itself and PyCapsule_GetPointer to return a full-width pointer, so declare
+# both return types here rather than depending on another module having done
+# it first. Re-assigning the same restypes is harmless if that already happened.
+ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
+ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
+
+def _dlpack_deleter(pycapsule):
+    """Destructor installed on capsules created by ``to_dlpack_for_*``.
+
+    Receives the capsule's address (as an integer) when the capsule is
+    destroyed. If the capsule is still named ``'dltensor'`` (it was never
+    consumed by ``from_dlpack``), the wrapped DLManagedTensor pointer is handed
+    to ``MXNDArrayCallDLPackDeleter`` so MXNet can release it. Consumed
+    capsules are renamed to ``'used_dltensor'`` and fail the validity check,
+    so they are skipped.
+    """
+    pycapsule = ctypes.c_void_p(pycapsule)
+    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
+        ptr = ctypes.c_void_p(
+            ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor))
+        check_call(_LIB.MXNDArrayCallDLPackDeleter(ptr))
+
+# Module-level reference to the ctypes callback: it must outlive every capsule
+# that points at it, otherwise the destructor would call freed memory.
+_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter)
 
 def save(file, arr):
     """Saves a list of `ndarray`s or a dict of `str`->`ndarray` to file.
@@ -119,3 +131,112 @@ def load(file):
         return dict(
             (py_str(names[i]), ndarray(NDArrayHandle(handles[i])))
             for i in range(out_size.value))
+
+
+def from_dlpack(dlpack):
+    """Returns a np.ndarray backed by a dlpack tensor.
+
+    The capsule is consumed: on success it is renamed to ``'used_dltensor'``
+    and its destructor is cleared, so the same capsule cannot be passed to
+    this function again.
+
+    Parameters
+    ----------
+    dlpack: PyCapsule (the pointer of DLManagedTensor)
+        input data; must be an unconsumed capsule named ``'dltensor'``.
+
+    Returns
+    -------
+    np.ndarray
+        an ndarray backed by a dlpack tensor
+
+    Raises
+    ------
+    ValueError
+        If ``dlpack`` is not a valid ``'dltensor'`` capsule, for example
+        because it has already been consumed.
+
+    Examples
+    --------
+    >>> x = mx.np.ones((2,3))
+    >>> y = mx.npx.to_dlpack_for_read(x)
+    >>> type(y)
+    <class 'PyCapsule'>
+    >>> z = mx.npx.from_dlpack(y)
+    >>> type(z)
+    <class 'mxnet.numpy.ndarray'>
+    >>> z
+    array([[1., 1., 1.],
+           [1., 1., 1.]])
+
+    >>> w = mx.npx.to_dlpack_for_write(x)
+    >>> type(w)
+    <class 'PyCapsule'>
+    >>> u = mx.npx.from_dlpack(w)
+    >>> u += 1
+    >>> x
+    array([[2., 2., 2.],
+           [2., 2., 2.]])
+    """
+    handle = NDArrayHandle()
+    dlpack = ctypes.py_object(dlpack)
+    # Explicit raise rather than ``assert``: asserts are stripped under
+    # ``python -O`` and would raise AssertionError instead of ValueError.
+    if not ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor):
+        raise ValueError(
+            'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.')
+    dlpack_handle = ctypes.c_void_p(
+        ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor))
+    # NOTE(review): the meaning of the ``False`` flag is defined by the C API
+    # MXNDArrayFromDLPackEx; confirm there before changing it.
+    check_call(_LIB.MXNDArrayFromDLPackEx(
+        dlpack_handle, False, ctypes.byref(handle)))
+    # Rename the capsule so any later consumer sees it as already used.
+    ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor)
+    # Clear the capsule's destructor; presumably the new ndarray now owns the
+    # DLManagedTensor, so the capsule must not release it a second time.
+    ctypes.pythonapi.PyCapsule_SetDestructor(dlpack, None)
+    return ndarray(handle=handle)
+
+def to_dlpack_for_read(data):
+    """Returns a reference view of np.ndarray as a DLManagedTensor capsule.
+
+    Blocks until all previously issued write operations on ``data`` have
+    finished, so the exported tensor observes their results. The capsule
+    shares memory with ``data``. Only pending writes are waited for, so use
+    :func:`to_dlpack_for_write` if the consumer will modify the data.
+
+    Parameters
+    ----------
+    data: np.ndarray
+        input data.
+
+    Returns
+    -------
+    PyCapsule (the pointer of DLManagedTensor)
+        a capsule named ``'dltensor'`` wrapping a reference view of ``data``.
+        If it is never passed to :func:`from_dlpack`, its destructor releases
+        the wrapped tensor.
+
+    Examples
+    --------
+    >>> x = mx.np.ones((2,3))
+    >>> y = mx.npx.to_dlpack_for_read(x)
+    >>> type(y)
+    <class 'PyCapsule'>
+    >>> z = mx.npx.from_dlpack(y)
+    >>> z
+    array([[1., 1., 1.],
+           [1., 1., 1.]])
+    """
+    # Wait for pending writes so the exported view sees up-to-date values.
+    data.wait_to_read()
+    dlpack = DLPackHandle()
+    check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
+    return ctypes.pythonapi.PyCapsule_New(
+        dlpack, _c_str_dltensor, _c_dlpack_deleter)
+
+def to_dlpack_for_write(data):
+    """Returns a reference view of np.ndarray as a DLManagedTensor capsule.
+
+    Blocks until all previously issued read and write operations on ``data``
+    have finished, so the consumer may modify the shared memory in place (the
+    changes become visible through ``data``).
+
+    Parameters
+    ----------
+    data: np.ndarray
+        input data.
+
+    Returns
+    -------
+    PyCapsule (the pointer of DLManagedTensor)
+        a capsule named ``'dltensor'`` wrapping a reference view of ``data``.
+        If it is never passed to :func:`from_dlpack`, its destructor releases
+        the wrapped tensor.
+
+    Examples
+    --------
+    >>> x = mx.np.ones((2,3))
+    >>> w = mx.npx.to_dlpack_for_write(x)
+    >>> type(w)
+    <class 'PyCapsule'>
+    >>> u = mx.npx.from_dlpack(w)
+    >>> u += 1
+    >>> x
+    array([[2., 2., 2.],
+           [2., 2., 2.]])
+    """
+    # Wait for pending reads and writes before handing out writable memory.
+    check_call(_LIB.MXNDArrayWaitToWrite(data.handle))
+    dlpack = DLPackHandle()
+    check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
+    return ctypes.pythonapi.PyCapsule_New(
+        dlpack, _c_str_dltensor, _c_dlpack_deleter)
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index e6db131..966b26d 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -1343,3 +1343,29 @@ def test_np_ndarray_pickle():
             a_load = pickle.load(f)
         same(a.asnumpy(), a_load.asnumpy())
 
+@pytest.mark.parametrize('dtype', [np.float32, np.int32])
+@pytest.mark.parametrize('size', [
+    (3, 4, 5, 6),
+    (2, 10),
+    (15,),
+    ()
+])
+@use_np
+def test_dlpack(dtype, size):
+    """Round-trip arrays through DLPack capsules for reading and writing."""
+    # Cast so that the ``dtype`` parameter is actually exercised.
+    a = mx.np.random.uniform(size=size).astype(dtype)
+    a_np = a.copy()
+    a += 1
+
+    # Read view: must keep its data alive after ``a`` and the capsule are gone.
+    pack = mx.npx.to_dlpack_for_read(a)
+    b = mx.npx.from_dlpack(pack)
+
+    # Write view: in-place updates through ``c`` must show up in ``a_copy``.
+    a_copy = a.copy()
+    pack2 = mx.npx.to_dlpack_for_write(a_copy)
+    c = mx.npx.from_dlpack(pack2)
+    c += 1
+
+    del a, pack, pack2
+
+    # ``same`` returns a bool rather than raising, so assert its result.
+    # ``c`` is checked before ``a_copy`` so its pending write has completed.
+    assert same((a_np + 1).asnumpy(), b.asnumpy())
+    assert same((a_np + 2).asnumpy(), c.asnumpy())
+    assert same((a_np + 2).asnumpy(), a_copy.asnumpy())

Reply via email to