tqchen commented on code in PR #96:
URL: https://github.com/apache/tvm-ffi/pull/96#discussion_r2422377406
##########
python/tvm_ffi/cython/tensor.pxi:
##########
@@ -275,33 +275,72 @@ _set_class_tensor(Tensor)
_register_object_by_index(kTVMFFITensor, Tensor)
-cdef int _dltensor_test_wrapper_c_dlpack_from_pyobject(
- void* obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream
+cdef int _dltensor_test_wrapper_from_pyobject(
+ void* obj, DLManagedTensorVersioned** out
) except -1:
+ """DLPackExchangeAPI: managed_tensor_from_py_object_no_sync"""
cdef PyObject* py_obj = <PyObject*>obj
cdef DLTensorTestWrapper wrapper = <DLTensorTestWrapper>py_obj
- cdef TVMFFIStreamHandle current_stream
- cdef DLManagedTensorVersioned* temp_managed_tensor
- if env_stream != NULL:
- env_stream[0] = TVMFFIEnvGetStream(
- wrapper.tensor.cdltensor.device.device_type,
- wrapper.tensor.cdltensor.device.device_id
- )
-
return TVMFFITensorToDLPackVersioned(wrapper.tensor.chandle, out)
-def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr():
- cdef DLPackFromPyObject converter_func =
_dltensor_test_wrapper_c_dlpack_from_pyobject
- cdef void* temp_ptr = <void*>converter_func
- cdef long long temp_int_ptr = <long long>temp_ptr
- return temp_int_ptr
+cdef int _dltensor_test_wrapper_to_pyobject(
+ DLManagedTensorVersioned* tensor, void** out_py_object
+) except -1:
+ """DLPackExchangeAPI: managed_tensor_to_py_object_no_sync"""
+ cdef TVMFFIObjectHandle temp_chandle
+ if TVMFFITensorFromDLPackVersioned(tensor, 0, 0, &temp_chandle) != 0:
+ return -1
+ py_tensor = make_tensor_from_chandle(temp_chandle)
+ Py_INCREF(py_tensor)
+ out_py_object[0] = <void*>(<PyObject*>py_tensor)
+ return 0
+
+
+cdef int _dltensor_test_wrapper_current_work_stream(
+ int device_type, int32_t device_id, void** out_stream
+) except -1:
+ """DLPackExchangeAPI: current_work_stream"""
+ if device_type != kDLCPU:
+ out_stream[0] = <void*>TVMFFIEnvGetStream(device_type, device_id)
+ return 0
+
+
+# Module-level static DLPackExchangeAPI for DLTensorTestWrapper
+cdef DLPackExchangeAPI _dltensor_test_wrapper_static_api
+cdef bint _dltensor_test_wrapper_api_initialized = False
+
+cdef const DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api()
noexcept:
+ """Get the static DLPackExchangeAPI instance for DLTensorTestWrapper."""
+ global _dltensor_test_wrapper_static_api,
_dltensor_test_wrapper_api_initialized
+
+ if not _dltensor_test_wrapper_api_initialized:
+ # Initialize header using macros from dlpack.h
Review Comment:
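Specifically, define it as a class-level (static) member, so it is initialized only once, instead of using a module-level global guarded by an "initialized" flag:
```python
class DLTensorTestWrapper:
    # this is static member init and only called once
    __dlpack_exchange__api = xx
```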
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]