This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 73b6851a54 [FFI][ABI] Introduce generic stream exchange protocol
(#18295)
73b6851a54 is described below
commit 73b6851a54ecd09a6037454813374bd08652d6c5
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Sep 9 14:00:24 2025 -0400
[FFI][ABI] Introduce generic stream exchange protocol (#18295)
This PR adds a __tvm_ffi_env_stream__ protocol that lets generic
tensors exchange their environment stream with TVM FFI.
Also renames TVMFFIEnvSetStream to TVMFFIEnvSetCurrentStream.
---
ffi/include/tvm/ffi/extra/c_env_api.h | 6 +-
ffi/python/tvm_ffi/cython/base.pxi | 91 +++++++++++++++++++------------
ffi/python/tvm_ffi/cython/function.pxi | 29 +++++++---
ffi/python/tvm_ffi/cython/tensor.pxi | 24 ++++++++
ffi/scripts/benchmark_dlpack.py | 26 +++++++--
ffi/src/ffi/extra/stream_context.cc | 4 +-
src/runtime/device_api.cc | 3 +-
src/runtime/vm/cuda/cuda_graph_builtin.cc | 7 ++-
8 files changed, 134 insertions(+), 56 deletions(-)
diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h
b/ffi/include/tvm/ffi/extra/c_env_api.h
index 6f8e44bdfb..bd0d188155 100644
--- a/ffi/include/tvm/ffi/extra/c_env_api.h
+++ b/ffi/include/tvm/ffi/extra/c_env_api.h
@@ -49,9 +49,9 @@ typedef void* TVMFFIStreamHandle;
* \note The stream is a weak reference that is cached/owned by the module.
* \return 0 when success, nonzero when failure happens
*/
-TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
- TVMFFIStreamHandle stream,
- TVMFFIStreamHandle*
opt_out_original_stream);
+TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t
device_id,
+ TVMFFIStreamHandle stream,
+ TVMFFIStreamHandle*
opt_out_original_stream);
/*!
* \brief FFI function to get the current stream for a device
diff --git a/ffi/python/tvm_ffi/cython/base.pxi
b/ffi/python/tvm_ffi/cython/base.pxi
index f1cd77bc47..efb2225453 100644
--- a/ffi/python/tvm_ffi/cython/base.pxi
+++ b/ffi/python/tvm_ffi/cython/base.pxi
@@ -24,39 +24,24 @@ from cpython cimport PyErr_CheckSignals, PyGILState_Ensure,
PyGILState_Release,
from cpython cimport pycapsule, PyCapsule_Destructor
from cpython cimport PyErr_SetNone
-
-# Cython binding for TVM FFI C API
-cdef extern from "tvm/ffi/c_api.h":
- cdef enum TVMFFITypeIndex:
- kTVMFFIAny = -1
- kTVMFFINone = 0
- kTVMFFIInt = 1
- kTVMFFIBool = 2
- kTVMFFIFloat = 3
- kTVMFFIOpaquePtr = 4
- kTVMFFIDataType = 5
- kTVMFFIDevice = 6
- kTVMFFIDLTensorPtr = 7
- kTVMFFIRawStr = 8
- kTVMFFIByteArrayPtr = 9
- kTVMFFIObjectRValueRef = 10
- kTVMFFISmallStr = 11
- kTVMFFISmallBytes = 12
- kTVMFFIStaticObjectBegin = 64
- kTVMFFIObject = 64
- kTVMFFIStr = 65
- kTVMFFIBytes = 66
- kTVMFFIError = 67
- kTVMFFIFunction = 68
- kTVMFFIShape = 69
- kTVMFFITensor = 70
- kTVMFFIArray = 71
- kTVMFFIMap = 72
- kTVMFFIModule = 73
- kTVMFFIOpaquePyObject = 74
-
-
- ctypedef void* TVMFFIObjectHandle
+cdef extern from "dlpack/dlpack.h":
+ cdef enum:
+ kDLCPU = 1,
+ kDLCUDA = 2,
+ kDLCUDAHost = 3,
+ kDLOpenCL = 4,
+ kDLVulkan = 7,
+ kDLMetal = 8,
+ kDLVPI = 9,
+ kDLROCM = 10,
+ kDLROCMHost = 11,
+ kDLExtDev = 12,
+ kDLCUDAManaged = 13,
+ kDLOneAPI = 14,
+ kDLWebGPU = 15,
+ kDLHexagon = 16,
+ kDLMAIA = 17
+ kDLTrn = 18
ctypedef struct DLDataType:
uint8_t code
@@ -92,6 +77,40 @@ cdef extern from "tvm/ffi/c_api.h":
void (*deleter)(DLManagedTensorVersioned* self)
uint64_t flags
+
+# Cython binding for TVM FFI C API
+cdef extern from "tvm/ffi/c_api.h":
+ cdef enum TVMFFITypeIndex:
+ kTVMFFIAny = -1
+ kTVMFFINone = 0
+ kTVMFFIInt = 1
+ kTVMFFIBool = 2
+ kTVMFFIFloat = 3
+ kTVMFFIOpaquePtr = 4
+ kTVMFFIDataType = 5
+ kTVMFFIDevice = 6
+ kTVMFFIDLTensorPtr = 7
+ kTVMFFIRawStr = 8
+ kTVMFFIByteArrayPtr = 9
+ kTVMFFIObjectRValueRef = 10
+ kTVMFFISmallStr = 11
+ kTVMFFISmallBytes = 12
+ kTVMFFIStaticObjectBegin = 64
+ kTVMFFIObject = 64
+ kTVMFFIStr = 65
+ kTVMFFIBytes = 66
+ kTVMFFIError = 67
+ kTVMFFIFunction = 68
+ kTVMFFIShape = 69
+ kTVMFFITensor = 70
+ kTVMFFIArray = 71
+ kTVMFFIMap = 72
+ kTVMFFIModule = 73
+ kTVMFFIOpaquePyObject = 74
+
+
+ ctypedef void* TVMFFIObjectHandle
+
ctypedef struct TVMFFIObject:
int32_t type_index
int32_t ref_counter
@@ -219,9 +238,9 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil
void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id)
nogil
- int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
- TVMFFIStreamHandle stream,
- TVMFFIStreamHandle* opt_out_original_stream) nogil
+ int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id,
+ TVMFFIStreamHandle stream,
+ TVMFFIStreamHandle* opt_out_original_stream)
nogil
cdef class ByteArrayArg:
diff --git a/ffi/python/tvm_ffi/cython/function.pxi
b/ffi/python/tvm_ffi/cython/function.pxi
index 28d4ba5a00..71591d9526 100644
--- a/ffi/python/tvm_ffi/cython/function.pxi
+++ b/ffi/python/tvm_ffi/cython/function.pxi
@@ -122,10 +122,25 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out,
list temp_args,
ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
temp_args.append(arg)
elif hasattr(arg, "__dlpack__"):
- arg = from_dlpack(arg)
+ ffi_arg = from_dlpack(arg)
out[i].type_index = kTVMFFITensor
- out[i].v_ptr = (<Tensor>arg).chandle
- temp_args.append(arg)
+ out[i].v_ptr = (<Tensor>ffi_arg).chandle
+ # record the stream from the source framework context when possible
+ temp_dltensor =
TVMFFITensorGetDLTensorPtr((<Tensor>ffi_arg).chandle)
+ if (temp_dltensor.device.device_type != kDLCPU and
+ ctx_dev_type != NULL and
+ ctx_dev_type[0] == -1):
+ # __tvm_ffi_env_stream__ returns the expected stream that
should be set
+ # through TVMFFIEnvSetCurrentStream when calling a TVM FFI
function
+ if hasattr(arg, "__tvm_ffi_env_stream__"):
+ # Ideally projects should directly setup their stream
context API
+ # write through by also calling TVMFFIEnvSetCurrentStream
+ # so we do not need this protocol to do exchange
+ ctx_dev_type[0] = temp_dltensor.device.device_type
+ ctx_dev_id[0] = temp_dltensor.device.device_id
+ temp_ptr= arg.__tvm_ffi_env_stream__()
+ ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
+ temp_args.append(ffi_arg)
elif isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not
None:
arg = arg.__tvm_ffi_object__
out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
@@ -210,7 +225,7 @@ cdef inline int FuncCall3(void* chandle,
with nogil:
if ctx_dev_type != -1:
# set the stream based on ctx stream
- c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id,
ctx_stream, &prev_stream)
+ c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type,
ctx_dev_id, ctx_stream, &prev_stream)
if c_api_ret_code[0] != 0:
return 0
c_api_ret_code[0] = TVMFFIFunctionCall(
@@ -219,7 +234,7 @@ cdef inline int FuncCall3(void* chandle,
# restore the original stream if it is not the same as the context
stream
if ctx_dev_type != -1 and prev_stream != ctx_stream:
# restore the original stream
- c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id,
prev_stream, NULL)
+ c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type,
ctx_dev_id, prev_stream, NULL)
if c_api_ret_code[0] != 0:
return 0
return 0
@@ -247,13 +262,13 @@ cdef inline int FuncCall(void* chandle,
with nogil:
if ctx_dev_type != -1:
- c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id,
ctx_stream, &prev_stream)
+ c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type,
ctx_dev_id, ctx_stream, &prev_stream)
if c_api_ret_code[0] != 0:
return 0
c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0],
nargs, result)
# restore the original stream if it is not the same as the context
stream
if ctx_dev_type != -1 and prev_stream != ctx_stream:
- c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id,
prev_stream, NULL)
+ c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type,
ctx_dev_id, prev_stream, NULL)
if c_api_ret_code[0] != 0:
return 0
diff --git a/ffi/python/tvm_ffi/cython/tensor.pxi
b/ffi/python/tvm_ffi/cython/tensor.pxi
index 4658422ca5..2072ad0567 100644
--- a/ffi/python/tvm_ffi/cython/tensor.pxi
+++ b/ffi/python/tvm_ffi/cython/tensor.pxi
@@ -260,6 +260,30 @@ _set_class_tensor(Tensor)
_register_object_by_index(kTVMFFITensor, Tensor)
+cdef class DLTensorTestWrapper:
+    """Wrapper of a Tensor that exposes the DLPack protocol, for testing
+    purposes only.
+    """
+ cdef Tensor tensor
+ def __init__(self, tensor):
+ self.tensor = tensor
+
+ def __tvm_ffi_env_stream__(self):
+ cdef TVMFFIStreamHandle stream
+ cdef long long stream_as_int
+ cdef int c_api_ret_code
+ with nogil:
+ stream = TVMFFIEnvGetCurrentStream(
+ self.tensor.cdltensor.device.device_type,
self.tensor.cdltensor.device.device_id)
+ stream_as_int = <long long>stream
+ return stream_as_int
+
+ def __dlpack_device__(self):
+ return self.tensor.__dlpack_device__()
+
+ def __dlpack__(self, *, **kwargs):
+ return self.tensor.__dlpack__(**kwargs)
+
+
cdef inline object make_ret_dltensor(TVMFFIAny result):
cdef DLTensor* dltensor
dltensor = <DLTensor*>result.v_ptr
diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py
index 73fbe0f6ac..00581eb0f3 100644
--- a/ffi/scripts/benchmark_dlpack.py
+++ b/ffi/scripts/benchmark_dlpack.py
@@ -44,11 +44,11 @@ import time
def print_speed(name, speed):
- print(f"{name:<40} {speed} sec/call")
+ print(f"{name:<60} {speed} sec/call")
def print_error(name, error):
- print(f"{name:<40} {error}")
+ print(f"{name:<60} {error}")
def baseline_torch_add(repeat):
@@ -122,7 +122,7 @@ def tvm_ffi_nop(repeat):
nop(x, y, z)
start = time.time()
for i in range(repeat):
- y = tvm_ffi.from_dlpack(x)
+ nop(x, y, z)
end = time.time()
print_speed("tvm_ffi.nop", (end - start) / repeat)
@@ -275,6 +275,22 @@ def tvm_ffi_nop_autodlpack_from_numpy(repeat):
bench_tvm_ffi_nop_autodlpack("tvm_ffi.nop.autodlpack(numpy)", x, y, z,
repeat)
+def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, device):
+    """
+    Measures the overhead of running dlpack via auto conversion by directly
+    taking the test wrapper as inputs. This effectively measures DLPack
+    exchange in tvm ffi.
+    """
+ x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+ y = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+ z = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+ x = tvm_ffi.core.DLTensorTestWrapper(x)
+ y = tvm_ffi.core.DLTensorTestWrapper(y)
+ z = tvm_ffi.core.DLTensorTestWrapper(z)
+ bench_tvm_ffi_nop_autodlpack(
+ f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])", x, y, z,
repeat
+ )
+
+
def bench_to_dlpack(x, name, repeat):
x.__dlpack__()
start = time.time()
@@ -367,7 +383,6 @@ def main():
baseline_numpy_add(repeat)
baseline_torch_add(repeat)
baseline_cupy_add(repeat)
- tvm_ffi_nop(repeat)
tvm_ffi_nop_from_torch_dlpack(repeat)
tvm_ffi_nop_from_numpy_dlpack(repeat)
tvm_ffi_self_dlpack_nop(repeat)
@@ -377,6 +392,9 @@ def main():
tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)
tvm_ffi_nop_autodlpack_from_numpy(repeat)
+ tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cpu")
+ tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cuda")
+ tvm_ffi_nop(repeat)
print("-------------------------------")
print("Benchmark x.__dlpack__ overhead")
print("-------------------------------")
diff --git a/ffi/src/ffi/extra/stream_context.cc
b/ffi/src/ffi/extra/stream_context.cc
index d063efdef5..5a6afad4c1 100644
--- a/ffi/src/ffi/extra/stream_context.cc
+++ b/ffi/src/ffi/extra/stream_context.cc
@@ -66,8 +66,8 @@ class StreamContext {
} // namespace ffi
} // namespace tvm
-int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
TVMFFIStreamHandle stream,
- TVMFFIStreamHandle* out_original_stream) {
+int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id,
TVMFFIStreamHandle stream,
+ TVMFFIStreamHandle* out_original_stream) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id,
stream,
out_original_stream);
diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc
index fd7d651df2..e574ce14b0 100644
--- a/src/runtime/device_api.cc
+++ b/src/runtime/device_api.cc
@@ -165,7 +165,8 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) {
return nullptr; }
void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}
void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
- TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(dev.device_type, dev.device_id,
stream, nullptr));
+ TVM_FFI_CHECK_SAFE_CALL(
+ TVMFFIEnvSetCurrentStream(dev.device_type, dev.device_id, stream,
nullptr));
}
TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) {
diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc
b/src/runtime/vm/cuda/cuda_graph_builtin.cc
index a85ade2e1d..2528415281 100644
--- a/src/runtime/vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc
@@ -118,13 +118,14 @@ class CUDACaptureStream {
explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) {
CUDA_CALL(cudaGetDevice(&device_id_));
TVM_FFI_CHECK_SAFE_CALL(
- TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_,
-
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
+ TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, capture_stream_,
+
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
CUDA_CALL(cudaStreamBeginCapture(capture_stream_,
cudaStreamCaptureModeGlobal));
}
~CUDACaptureStream() noexcept(false) {
cudaStreamEndCapture(capture_stream_, output_graph_);
- TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_,
prev_default_stream_, nullptr));
+ TVM_FFI_CHECK_SAFE_CALL(
+ TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, prev_default_stream_,
nullptr));
}
private: