wkcn commented on a change in pull request #12047: [MXNET-779]Add DLPack Transformation API
URL: https://github.com/apache/incubator-mxnet/pull/12047#discussion_r209418176
 
 

 ##########
 File path: python/mxnet/ndarray/ndarray.py
 ##########
 @@ -3851,3 +3898,117 @@ def histogram(a, bins=10, range=None):
         return _internal._histogram(data=a, bin_cnt=bins, range=range)
     raise ValueError("bins argument should be either an integer or an NDArray")
     # pylint: enable= no-member, protected-access, redefined-builtin
+
+pycapsule_dlpack_deleter = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
+    _LIB.MXNDArrayCallDLPackCapsuleDeleter)
+
+def to_dlpack_for_read(data):
+    """Returns a reference view of NDArray that represents as DLManagedTensor until
+       all previous write operations on the current array are finished.
+
+    Parameters
+    ----------
+    data: NDArray
+        input data.
+
+    Returns
+    -------
+    PyCapsule (the pointer of DLManagedTensor)
+        a reference view of NDArray that represents as DLManagedTensor.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> y = mx.nd.to_dlpack_for_read(x)
+    >>> type(y)
+    <class 'PyCapsule'>
+    >>> z = mx.nd.from_dlpack(y)
+    >>> z
+    [[1. 1. 1.]
+     [1. 1. 1.]]
+    <NDArray 2x3 @cpu(0)>
+    """
+    data.wait_to_read()
+    dlpack = DLPackHandle()
+    check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
+    return ctypes.pythonapi.PyCapsule_New(dlpack, b'dltensor', pycapsule_dlpack_deleter)
+
+def to_dlpack_for_write(data):
+    """Returns a reference view of NDArray that represents as DLManagedTensor until
+       all previous read/write operations on the current array are finished.
+
+    Parameters
+    ----------
+    data: NDArray
+        input data.
+
+    Returns
+    -------
+    PyCapsule (the pointer of DLManagedTensor)
+        a reference view of NDArray that represents as DLManagedTensor.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> w = mx.nd.to_dlpack_for_write(x)
+    >>> type(w)
+    <class 'PyCapsule'>
+    >>> u = mx.nd.from_dlpack(w)
+    >>> u += 1
+    >>> x
+    [[2. 2. 2.]
+     [2. 2. 2.]]
+    <NDArray 2x3 @cpu(0)>
+    """
+    check_call(_LIB.MXNDArrayWaitToWrite(data.handle))
+    dlpack = DLPackHandle()
+    check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
+    return ctypes.pythonapi.PyCapsule_New(dlpack, b'dltensor', pycapsule_dlpack_deleter)
 
 Review comment:
   Solved it. Thank you!

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to