This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 73b6851a54 [FFI][ABI] Introduce generic stream exchange protocol (#18295)
73b6851a54 is described below

commit 73b6851a54ecd09a6037454813374bd08652d6c5
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Sep 9 14:00:24 2025 -0400

    [FFI][ABI] Introduce generic stream exchange protocol (#18295)
    
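    This PR adds a __tvm_ffi_env_stream__ protocol that lets generic
    (DLPack-compatible) tensors pass their framework's current stream
    to TVM FFI. When a non-CPU argument implements the protocol and no
    stream has yet been chosen for the call, the returned stream is
    installed via TVMFFIEnvSetCurrentStream for the duration of the
    call and the previous stream is restored afterwards.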
    
    Also renames TVMFFIEnvSetStream to TVMFFIEnvSetCurrentStream.
---
 ffi/include/tvm/ffi/extra/c_env_api.h     |  6 +-
 ffi/python/tvm_ffi/cython/base.pxi        | 91 +++++++++++++++++++------------
 ffi/python/tvm_ffi/cython/function.pxi    | 29 +++++++---
 ffi/python/tvm_ffi/cython/tensor.pxi      | 24 ++++++++
 ffi/scripts/benchmark_dlpack.py           | 26 +++++++--
 ffi/src/ffi/extra/stream_context.cc       |  4 +-
 src/runtime/device_api.cc                 |  3 +-
 src/runtime/vm/cuda/cuda_graph_builtin.cc |  7 ++-
 8 files changed, 134 insertions(+), 56 deletions(-)

diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h 
b/ffi/include/tvm/ffi/extra/c_env_api.h
index 6f8e44bdfb..bd0d188155 100644
--- a/ffi/include/tvm/ffi/extra/c_env_api.h
+++ b/ffi/include/tvm/ffi/extra/c_env_api.h
@@ -49,9 +49,9 @@ typedef void* TVMFFIStreamHandle;
  * \note The stream is a weak reference that is cached/owned by the module.
  * \return 0 when success, nonzero when failure happens
  */
-TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
-                                   TVMFFIStreamHandle stream,
-                                   TVMFFIStreamHandle* 
opt_out_original_stream);
+TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t 
device_id,
+                                          TVMFFIStreamHandle stream,
+                                          TVMFFIStreamHandle* 
opt_out_original_stream);
 
 /*!
  * \brief FFI function to get the current stream for a device
diff --git a/ffi/python/tvm_ffi/cython/base.pxi 
b/ffi/python/tvm_ffi/cython/base.pxi
index f1cd77bc47..efb2225453 100644
--- a/ffi/python/tvm_ffi/cython/base.pxi
+++ b/ffi/python/tvm_ffi/cython/base.pxi
@@ -24,39 +24,24 @@ from cpython cimport PyErr_CheckSignals, PyGILState_Ensure, 
PyGILState_Release,
 from cpython cimport pycapsule, PyCapsule_Destructor
 from cpython cimport PyErr_SetNone
 
-
-# Cython binding for TVM FFI C API
-cdef extern from "tvm/ffi/c_api.h":
-    cdef enum TVMFFITypeIndex:
-        kTVMFFIAny = -1
-        kTVMFFINone = 0
-        kTVMFFIInt = 1
-        kTVMFFIBool = 2
-        kTVMFFIFloat = 3
-        kTVMFFIOpaquePtr = 4
-        kTVMFFIDataType = 5
-        kTVMFFIDevice = 6
-        kTVMFFIDLTensorPtr = 7
-        kTVMFFIRawStr = 8
-        kTVMFFIByteArrayPtr = 9
-        kTVMFFIObjectRValueRef = 10
-        kTVMFFISmallStr = 11
-        kTVMFFISmallBytes = 12
-        kTVMFFIStaticObjectBegin = 64
-        kTVMFFIObject = 64
-        kTVMFFIStr = 65
-        kTVMFFIBytes = 66
-        kTVMFFIError = 67
-        kTVMFFIFunction = 68
-        kTVMFFIShape = 69
-        kTVMFFITensor = 70
-        kTVMFFIArray = 71
-        kTVMFFIMap = 72
-        kTVMFFIModule = 73
-        kTVMFFIOpaquePyObject = 74
-
-
-    ctypedef void* TVMFFIObjectHandle
+cdef extern from "dlpack/dlpack.h":
+    cdef enum:
+        kDLCPU = 1,
+        kDLCUDA = 2,
+        kDLCUDAHost = 3,
+        kDLOpenCL = 4,
+        kDLVulkan = 7,
+        kDLMetal = 8,
+        kDLVPI = 9,
+        kDLROCM = 10,
+        kDLROCMHost = 11,
+        kDLExtDev = 12,
+        kDLCUDAManaged = 13,
+        kDLOneAPI = 14,
+        kDLWebGPU = 15,
+        kDLHexagon = 16,
+        kDLMAIA = 17
+        kDLTrn = 18
 
     ctypedef struct DLDataType:
         uint8_t code
@@ -92,6 +77,40 @@ cdef extern from "tvm/ffi/c_api.h":
         void (*deleter)(DLManagedTensorVersioned* self)
         uint64_t flags
 
+
+# Cython binding for TVM FFI C API
+cdef extern from "tvm/ffi/c_api.h":
+    cdef enum TVMFFITypeIndex:
+        kTVMFFIAny = -1
+        kTVMFFINone = 0
+        kTVMFFIInt = 1
+        kTVMFFIBool = 2
+        kTVMFFIFloat = 3
+        kTVMFFIOpaquePtr = 4
+        kTVMFFIDataType = 5
+        kTVMFFIDevice = 6
+        kTVMFFIDLTensorPtr = 7
+        kTVMFFIRawStr = 8
+        kTVMFFIByteArrayPtr = 9
+        kTVMFFIObjectRValueRef = 10
+        kTVMFFISmallStr = 11
+        kTVMFFISmallBytes = 12
+        kTVMFFIStaticObjectBegin = 64
+        kTVMFFIObject = 64
+        kTVMFFIStr = 65
+        kTVMFFIBytes = 66
+        kTVMFFIError = 67
+        kTVMFFIFunction = 68
+        kTVMFFIShape = 69
+        kTVMFFITensor = 70
+        kTVMFFIArray = 71
+        kTVMFFIMap = 72
+        kTVMFFIModule = 73
+        kTVMFFIOpaquePyObject = 74
+
+
+    ctypedef void* TVMFFIObjectHandle
+
     ctypedef struct TVMFFIObject:
         int32_t type_index
         int32_t ref_counter
@@ -219,9 +238,9 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
 
     int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil
     void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) 
nogil
-    int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
-                           TVMFFIStreamHandle stream,
-                           TVMFFIStreamHandle* opt_out_original_stream) nogil
+    int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id,
+                                  TVMFFIStreamHandle stream,
+                                  TVMFFIStreamHandle* opt_out_original_stream) 
nogil
 
 
 cdef class ByteArrayArg:
diff --git a/ffi/python/tvm_ffi/cython/function.pxi 
b/ffi/python/tvm_ffi/cython/function.pxi
index 28d4ba5a00..71591d9526 100644
--- a/ffi/python/tvm_ffi/cython/function.pxi
+++ b/ffi/python/tvm_ffi/cython/function.pxi
@@ -122,10 +122,25 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, 
list temp_args,
                 ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
             temp_args.append(arg)
         elif hasattr(arg, "__dlpack__"):
-            arg = from_dlpack(arg)
+            ffi_arg = from_dlpack(arg)
             out[i].type_index = kTVMFFITensor
-            out[i].v_ptr = (<Tensor>arg).chandle
-            temp_args.append(arg)
+            out[i].v_ptr = (<Tensor>ffi_arg).chandle
+            # record the stream from the source framework context when possible
+            temp_dltensor = 
TVMFFITensorGetDLTensorPtr((<Tensor>ffi_arg).chandle)
+            if (temp_dltensor.device.device_type != kDLCPU and
+                ctx_dev_type != NULL and
+                ctx_dev_type[0] == -1):
+                # __tvm_ffi_env_stream__ returns the expected stream that 
should be set
+                # through TVMFFIEnvSetCurrentStream when calling a TVM FFI 
function
+                if hasattr(arg, "__tvm_ffi_env_stream__"):
+                    # Ideally projects should directly setup their stream 
context API
+                    # write through by also calling TVMFFIEnvSetCurrentStream
+                    # so we do not need this protocol to do exchange
+                    ctx_dev_type[0] = temp_dltensor.device.device_type
+                    ctx_dev_id[0] = temp_dltensor.device.device_id
+                    temp_ptr= arg.__tvm_ffi_env_stream__()
+                    ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
+            temp_args.append(ffi_arg)
         elif isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not 
None:
             arg = arg.__tvm_ffi_object__
             out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
@@ -210,7 +225,7 @@ cdef inline int FuncCall3(void* chandle,
     with nogil:
         if ctx_dev_type != -1:
             # set the stream based on ctx stream
-            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, 
ctx_stream, &prev_stream)
+            c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, 
ctx_dev_id, ctx_stream, &prev_stream)
             if c_api_ret_code[0] != 0:
                 return 0
         c_api_ret_code[0] = TVMFFIFunctionCall(
@@ -219,7 +234,7 @@ cdef inline int FuncCall3(void* chandle,
         # restore the original stream if it is not the same as the context 
stream
         if ctx_dev_type != -1 and prev_stream != ctx_stream:
             # restore the original stream
-            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, 
prev_stream, NULL)
+            c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, 
ctx_dev_id, prev_stream, NULL)
             if c_api_ret_code[0] != 0:
                 return 0
     return 0
@@ -247,13 +262,13 @@ cdef inline int FuncCall(void* chandle,
 
     with nogil:
         if ctx_dev_type != -1:
-            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, 
ctx_stream, &prev_stream)
+            c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, 
ctx_dev_id, ctx_stream, &prev_stream)
             if c_api_ret_code[0] != 0:
                 return 0
         c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], 
nargs, result)
         # restore the original stream if it is not the same as the context 
stream
         if ctx_dev_type != -1 and prev_stream != ctx_stream:
-            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, 
prev_stream, NULL)
+            c_api_ret_code[0] = TVMFFIEnvSetCurrentStream(ctx_dev_type, 
ctx_dev_id, prev_stream, NULL)
             if c_api_ret_code[0] != 0:
                 return 0
 
diff --git a/ffi/python/tvm_ffi/cython/tensor.pxi 
b/ffi/python/tvm_ffi/cython/tensor.pxi
index 4658422ca5..2072ad0567 100644
--- a/ffi/python/tvm_ffi/cython/tensor.pxi
+++ b/ffi/python/tvm_ffi/cython/tensor.pxi
@@ -260,6 +260,30 @@ _set_class_tensor(Tensor)
 _register_object_by_index(kTVMFFITensor, Tensor)
 
 
+cdef class DLTensorTestWrapper:
+    """Wrapper of a Tensor that exposes DLPack protocol, only for testing 
purpose.
+    """
+    cdef Tensor tensor
+    def __init__(self, tensor):
+        self.tensor = tensor
+
+    def __tvm_ffi_env_stream__(self):
+        cdef TVMFFIStreamHandle stream
+        cdef long long stream_as_int
+        cdef int c_api_ret_code
+        with nogil:
+            stream = TVMFFIEnvGetCurrentStream(
+                self.tensor.cdltensor.device.device_type, 
self.tensor.cdltensor.device.device_id)
+        stream_as_int = <long long>stream
+        return stream_as_int
+
+    def __dlpack_device__(self):
+        return self.tensor.__dlpack_device__()
+
+    def __dlpack__(self, *, **kwargs):
+        return self.tensor.__dlpack__(**kwargs)
+
+
 cdef inline object make_ret_dltensor(TVMFFIAny result):
     cdef DLTensor* dltensor
     dltensor = <DLTensor*>result.v_ptr
diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py
index 73fbe0f6ac..00581eb0f3 100644
--- a/ffi/scripts/benchmark_dlpack.py
+++ b/ffi/scripts/benchmark_dlpack.py
@@ -44,11 +44,11 @@ import time
 
 
 def print_speed(name, speed):
-    print(f"{name:<40} {speed} sec/call")
+    print(f"{name:<60} {speed} sec/call")
 
 
 def print_error(name, error):
-    print(f"{name:<40} {error}")
+    print(f"{name:<60} {error}")
 
 
 def baseline_torch_add(repeat):
@@ -122,7 +122,7 @@ def tvm_ffi_nop(repeat):
     nop(x, y, z)
     start = time.time()
     for i in range(repeat):
-        y = tvm_ffi.from_dlpack(x)
+        nop(x, y, z)
     end = time.time()
     print_speed("tvm_ffi.nop", (end - start) / repeat)
 
@@ -275,6 +275,22 @@ def tvm_ffi_nop_autodlpack_from_numpy(repeat):
     bench_tvm_ffi_nop_autodlpack("tvm_ffi.nop.autodlpack(numpy)", x, y, z, 
repeat)
 
 
+def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, device):
+    """
+    Measures overhead of running dlpack via auto convert by directly
+    take test wrapper as inputs. This effectively measure DLPack exchange in 
tvm ffi.
+    """
+    x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+    y = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+    z = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+    x = tvm_ffi.core.DLTensorTestWrapper(x)
+    y = tvm_ffi.core.DLTensorTestWrapper(y)
+    z = tvm_ffi.core.DLTensorTestWrapper(z)
+    bench_tvm_ffi_nop_autodlpack(
+        f"tvm_ffi.nop.autodlpack(DLTensorTestWrapper[{device}])", x, y, z, 
repeat
+    )
+
+
 def bench_to_dlpack(x, name, repeat):
     x.__dlpack__()
     start = time.time()
@@ -367,7 +383,6 @@ def main():
     baseline_numpy_add(repeat)
     baseline_torch_add(repeat)
     baseline_cupy_add(repeat)
-    tvm_ffi_nop(repeat)
     tvm_ffi_nop_from_torch_dlpack(repeat)
     tvm_ffi_nop_from_numpy_dlpack(repeat)
     tvm_ffi_self_dlpack_nop(repeat)
@@ -377,6 +392,9 @@ def main():
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)
 
     tvm_ffi_nop_autodlpack_from_numpy(repeat)
+    tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cpu")
+    tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cuda")
+    tvm_ffi_nop(repeat)
     print("-------------------------------")
     print("Benchmark x.__dlpack__ overhead")
     print("-------------------------------")
diff --git a/ffi/src/ffi/extra/stream_context.cc 
b/ffi/src/ffi/extra/stream_context.cc
index d063efdef5..5a6afad4c1 100644
--- a/ffi/src/ffi/extra/stream_context.cc
+++ b/ffi/src/ffi/extra/stream_context.cc
@@ -66,8 +66,8 @@ class StreamContext {
 }  // namespace ffi
 }  // namespace tvm
 
-int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, 
TVMFFIStreamHandle stream,
-                       TVMFFIStreamHandle* out_original_stream) {
+int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, 
TVMFFIStreamHandle stream,
+                              TVMFFIStreamHandle* out_original_stream) {
   TVM_FFI_SAFE_CALL_BEGIN();
   tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, 
stream,
                                                     out_original_stream);
diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc
index fd7d651df2..e574ce14b0 100644
--- a/src/runtime/device_api.cc
+++ b/src/runtime/device_api.cc
@@ -165,7 +165,8 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { 
return nullptr; }
 void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}
 
 void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
-  TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(dev.device_type, dev.device_id, 
stream, nullptr));
+  TVM_FFI_CHECK_SAFE_CALL(
+      TVMFFIEnvSetCurrentStream(dev.device_type, dev.device_id, stream, 
nullptr));
 }
 
 TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) {
diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc 
b/src/runtime/vm/cuda/cuda_graph_builtin.cc
index a85ade2e1d..2528415281 100644
--- a/src/runtime/vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc
@@ -118,13 +118,14 @@ class CUDACaptureStream {
   explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) {
     CUDA_CALL(cudaGetDevice(&device_id_));
     TVM_FFI_CHECK_SAFE_CALL(
-        TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_,
-                           
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
+        TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, capture_stream_,
+                                  
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
     CUDA_CALL(cudaStreamBeginCapture(capture_stream_, 
cudaStreamCaptureModeGlobal));
   }
   ~CUDACaptureStream() noexcept(false) {
     cudaStreamEndCapture(capture_stream_, output_graph_);
-    TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, 
prev_default_stream_, nullptr));
+    TVM_FFI_CHECK_SAFE_CALL(
+        TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, prev_default_stream_, 
nullptr));
   }
 
  private:

Reply via email to