tqchen commented on code in PR #96:
URL: https://github.com/apache/tvm-ffi/pull/96#discussion_r2422319302
##########
python/tvm_ffi/cython/tensor.pxi:
##########
@@ -275,33 +275,72 @@ _set_class_tensor(Tensor)
_register_object_by_index(kTVMFFITensor, Tensor)
-cdef int _dltensor_test_wrapper_c_dlpack_from_pyobject(
- void* obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream
+cdef int _dltensor_test_wrapper_from_pyobject(
+ void* obj, DLManagedTensorVersioned** out
) except -1:
+ """DLPackExchangeAPI: managed_tensor_from_py_object_no_sync"""
cdef PyObject* py_obj = <PyObject*>obj
cdef DLTensorTestWrapper wrapper = <DLTensorTestWrapper>py_obj
- cdef TVMFFIStreamHandle current_stream
- cdef DLManagedTensorVersioned* temp_managed_tensor
- if env_stream != NULL:
- env_stream[0] = TVMFFIEnvGetStream(
- wrapper.tensor.cdltensor.device.device_type,
- wrapper.tensor.cdltensor.device.device_id
- )
-
return TVMFFITensorToDLPackVersioned(wrapper.tensor.chandle, out)
-def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr():
- cdef DLPackFromPyObject converter_func = _dltensor_test_wrapper_c_dlpack_from_pyobject
- cdef void* temp_ptr = <void*>converter_func
- cdef long long temp_int_ptr = <long long>temp_ptr
- return temp_int_ptr
+cdef int _dltensor_test_wrapper_to_pyobject(
+ DLManagedTensorVersioned* tensor, void** out_py_object
+) except -1:
+ """DLPackExchangeAPI: managed_tensor_to_py_object_no_sync"""
+ cdef TVMFFIObjectHandle temp_chandle
+ if TVMFFITensorFromDLPackVersioned(tensor, 0, 0, &temp_chandle) != 0:
+ return -1
+ py_tensor = make_tensor_from_chandle(temp_chandle)
+ Py_INCREF(py_tensor)
+ out_py_object[0] = <void*>(<PyObject*>py_tensor)
+ return 0
+
+
+cdef int _dltensor_test_wrapper_current_work_stream(
+ int device_type, int32_t device_id, void** out_stream
+) except -1:
+ """DLPackExchangeAPI: current_work_stream"""
+ if device_type != kDLCPU:
+ out_stream[0] = <void*>TVMFFIEnvGetStream(device_type, device_id)
+ return 0
+
+
+# Module-level static DLPackExchangeAPI for DLTensorTestWrapper
+cdef DLPackExchangeAPI _dltensor_test_wrapper_static_api
+cdef bint _dltensor_test_wrapper_api_initialized = False
+
+cdef const DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api() noexcept:
+ """Get the static DLPackExchangeAPI instance for DLTensorTestWrapper."""
+ global _dltensor_test_wrapper_static_api, _dltensor_test_wrapper_api_initialized
+
+ if not _dltensor_test_wrapper_api_initialized:
+ # Initialize header using macros from dlpack.h
Review Comment:
Given that this is called only once, there is no need for the initialization
check; just do the assignment directly.
##########
python/tvm_ffi/cython/function.pxi:
##########
@@ -142,35 +142,42 @@ cdef int TVMFFIPyArgSetterObject_(
return 0
-cdef int TVMFFIPyArgSetterDLPackCExporter_(
+cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx,
PyObject* arg, TVMFFIAny* out
) except -1:
cdef DLManagedTensorVersioned* temp_managed_tensor
cdef TVMFFIObjectHandle temp_chandle
- cdef TVMFFIStreamHandle env_stream = NULL
+ cdef void* current_stream = NULL
+ cdef const DLPackExchangeAPI* api = this.c_dlpack_exchange_api
- if this.c_dlpack_to_pyobject != NULL:
- ctx.c_dlpack_to_pyobject = this.c_dlpack_to_pyobject
- if this.c_dlpack_tensor_allocator != NULL:
- ctx.c_dlpack_tensor_allocator = this.c_dlpack_tensor_allocator
+ # Set allocator and ToPyObject converter in context if available
+ if api.managed_tensor_allocator != NULL:
+ ctx.c_dlpack_tensor_allocator = <DLPackTensorAllocator>api.managed_tensor_allocator
Review Comment:
Instead, change the ctx field so that it holds the whole exchange API:
ctx.dlpack_exchange_api = api
##########
python/tvm_ffi/_optional_torch_c_dlpack.py:
##########
@@ -513,16 +508,66 @@ def load_torch_c_dlpack_extension() -> Any:
}
}
-int64_t TorchDLPackFromPyObjectPtr() {
- return reinterpret_cast<int64_t>(TorchDLPackFromPyObject);
+int TorchDLTensorFromPyObject(void* py_obj, DLTensor* out) {
+ try {
+ // Use handle (non-owning) to avoid unnecessary refcount operations
+ py::handle handle(static_cast<PyObject*>(py_obj));
+ at::Tensor tensor = handle.cast<at::Tensor>();
+
+ // Fill in the pre-allocated DLTensor struct with direct pointers
+ // This is a non-owning conversion - the original PyObject owns the data
+ // and is kept alive by the caller for the duration of this call
+ out->data = tensor.data_ptr();
+ out->device = at::torchDeviceToDLDeviceForDLPackv1(tensor.device());
+ out->ndim = static_cast<int32_t>(tensor.dim());
Review Comment:
Move the implementation into the at:: (ATen) namespace as
at::toDLPackNonOwning(const at::Tensor&, DLTensor*), mainly because we will
likely need such a split upstream.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]