This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 216e9e99c1 [FFI] AutoDLPack compatible with torch stream context (#18217)
216e9e99c1 is described below

commit 216e9e99c1a9c6a64709e4458a88871932f7a7cc
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Aug 19 20:57:41 2025 -0400

    [FFI] AutoDLPack compatible with torch stream context (#18217)
    
    This PR updates the autodlpack path to automatically update
    the env stream to be consistent with the torch stream context.
    
    The change helps make FFI functions compatible with
    stream-based execution.
    
    We leverage torch cpp_extension load_inline to create an efficient
    query function; the first load may take longer while the JIT module
    is built, but subsequent loads are fast once the torch JIT module
    is cached.
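
    A minimal usage sketch of the new behavior (the tvm_ffi.get_global_func
    lookup and the "testing.nop" function name below are assumptions for
    illustration, not part of this change): an FFI call issued inside a
    torch.cuda.stream context now runs on that stream.

        import torch
        from tvm import ffi as tvm_ffi

        # hypothetical registered FFI no-op, used only for illustration
        nop = tvm_ffi.get_global_func("testing.nop")

        x = torch.arange(16, device="cuda")
        s = torch.cuda.Stream()
        with torch.cuda.stream(s):
            # autodlpack converts the torch.Tensor arguments and, with this
            # change, also sets the TVM env stream to s for the call
            nop(x, x, x)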
---
 ffi/scripts/benchmark_dlpack.py    | 70 ++++++++++++++++++++++++++++-
 python/tvm/ffi/cython/base.pxi     |  8 ++++
 python/tvm/ffi/cython/function.pxi | 92 +++++++++++++++++++++++++++++++++++---
 3 files changed, 162 insertions(+), 8 deletions(-)

diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py
index b19f566364..1453aa95a6 100644
--- a/ffi/scripts/benchmark_dlpack.py
+++ b/ffi/scripts/benchmark_dlpack.py
@@ -36,6 +36,7 @@ Summary of some takeaways:
 -
 
 """
+import os
 import torch
 import numpy as np
 from tvm import ffi as tvm_ffi
@@ -244,7 +245,7 @@ def bench_tvm_ffi_nop_autodlpack(name, x, y, z, repeat):
     print_speed(name, speed)
 
 
-def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
+def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu", stream=False):
     """
     Measures overhead of running dlpack via auto convert by directly
     take torch.Tensor as inputs.
@@ -253,7 +254,13 @@ def tvm_ffi_nop_autodlpack_from_torch(repeat, device="cpu"):
     x = torch.arange(1, device=device)
     y = torch.arange(1, device=device)
     z = torch.arange(1, device=device)
-    bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
+    if stream:
+        with torch.cuda.stream(torch.cuda.Stream()):
+            bench_tvm_ffi_nop_autodlpack(
+                f"tvm.ffi.nop.autodlpack(torch[{device}][stream])", x, y, z, repeat
+            )
+    else:
+        bench_tvm_ffi_nop_autodlpack(f"tvm.ffi.nop.autodlpack(torch[{device}])", x, y, z, repeat)
 
 
 def tvm_ffi_nop_autodlpack_from_numpy(repeat):
@@ -308,6 +315,50 @@ def bench_torch_utils_to_dlpack(repeat):
     print_speed("torch.utils.dlpack.to_dlpack", speed)
 
 
+def torch_get_cuda_stream_native(device_id):
+    return torch.cuda.current_stream(device_id).cuda_stream
+
+
+def load_torch_get_current_cuda_stream():
+    """Create a faster get_current_cuda_stream for torch through cpp 
extension."""
+    from torch.utils import cpp_extension
+
+    source = """
+    #include <c10/cuda/CUDAStream.h>
+
+    int64_t get_current_cuda_stream(int device_id) {
+        at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
+        // fast invariant, default stream is always 0
+        if (stream.id() == 0) return 0;
+        // convert to cudaStream_t
+        return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
+    }
+    """
+    result = cpp_extension.load_inline(
+        name="get_current_cuda_stream",
+        cpp_sources=[source],
+        cuda_sources=[],
+        extra_cflags=["-O3"],
+        extra_include_paths=cpp_extension.include_paths("cuda"),
+        functions=["get_current_cuda_stream"],
+    )
+    return result.get_current_cuda_stream
+
+
+def bench_torch_get_current_stream(repeat, name, func):
+    """
+    Measures overhead of running torch.cuda.current_stream
+    """
+    x = torch.arange(1, device="cuda")
+    func(0)
+    start = time.time()
+    for i in range(repeat):
+        func(0)
+    end = time.time()
+    speed = (end - start) / repeat
+    print_speed(f"torch.cuda.current_stream[{name}]", speed)
+
+
 def main():
     repeat = 10000
     print("-----------------------------")
@@ -323,6 +374,8 @@ def main():
     tvm_ffi_nop_from_torch_utils_to_dlpack(repeat)
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cpu")
     tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda")
+    tvm_ffi_nop_autodlpack_from_torch(repeat, "cuda", stream=True)
+
     tvm_ffi_nop_autodlpack_from_numpy(repeat)
     print("-------------------------------")
     print("Benchmark x.__dlpack__ overhead")
@@ -339,6 +392,19 @@ def main():
     bench_to_dlpack_versioned(
         tvm_ffi.from_dlpack(torch.arange(1)), "tvm.__dlpack__(max_version=(1,1))", repeat
     )
+    print("---------------------------------------------------")
+    print("Benchmark torch.get_cuda_stream[default stream]")
+    print("---------------------------------------------------")
+    bench_torch_get_current_stream(repeat, "cpp-extension", 
load_torch_get_current_cuda_stream())
+    bench_torch_get_current_stream(repeat, "python", 
torch_get_cuda_stream_native)
+    print("---------------------------------------------------")
+    print("Benchmark torch.get_cuda_stream[non-default stream]")
+    print("---------------------------------------------------")
+    with torch.cuda.stream(torch.cuda.Stream()):
+        bench_torch_get_current_stream(
+            repeat, "cpp-extension", load_torch_get_current_cuda_stream()
+        )
+        bench_torch_get_current_stream(repeat, "python", 
torch_get_cuda_stream_native)
 
 
 if __name__ == "__main__":
diff --git a/python/tvm/ffi/cython/base.pxi b/python/tvm/ffi/cython/base.pxi
index 00b76e68f7..24c7290959 100644
--- a/python/tvm/ffi/cython/base.pxi
+++ b/python/tvm/ffi/cython/base.pxi
@@ -205,6 +205,14 @@ cdef extern from "tvm/ffi/c_api.h":
     DLTensor* TVMFFINDArrayGetDLTensorPtr(TVMFFIObjectHandle obj) nogil
     DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil
 
+cdef extern from "tvm/ffi/extra/c_env_api.h":
+    ctypedef void* TVMFFIStreamHandle
+
+    void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil
+    int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
+                           TVMFFIStreamHandle stream,
+                           TVMFFIStreamHandle* opt_out_original_stream) nogil
+
 
 cdef class ByteArrayArg:
     cdef TVMFFIByteArray cdata
diff --git a/python/tvm/ffi/cython/function.pxi b/python/tvm/ffi/cython/function.pxi
index 999c2e1338..3ab232e959 100644
--- a/python/tvm/ffi/cython/function.pxi
+++ b/python/tvm/ffi/cython/function.pxi
@@ -18,11 +18,51 @@ import ctypes
 from numbers import Real, Integral
 
 try:
+    # optionally import torch and setup torch related utils
     import torch
 except ImportError:
     torch = None
 
 
+def load_torch_get_current_cuda_stream():
+    """Create a faster get_current_cuda_stream for torch through cpp extension.
+    """
+    from torch.utils import cpp_extension
+
+    source = """
+    #include <c10/cuda/CUDAStream.h>
+
+    int64_t get_current_cuda_stream(int device_id) {
+        at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
+        // fast invariant, default stream is always 0
+        if (stream.id() == 0) return 0;
+        // convert to cudaStream_t
+        return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
+    }
+    """
+    def fallback_get_current_cuda_stream(device_id):
+        """Fallback with python api"""
+        return torch.cuda.current_stream(device_id).cuda_stream
+
+    try:
+        result = cpp_extension.load_inline(
+            name="get_current_cuda_stream",
+            cpp_sources=[source],
+            cuda_sources=[],
+            extra_cflags=["-O3"],
+            extra_include_paths=cpp_extension.include_paths("cuda"),
+            functions=["get_current_cuda_stream"],
+        )
+        return result.get_current_cuda_stream
+    except Exception:
+        return fallback_get_current_cuda_stream
+
+if torch is not None:
+    # when torch is available, jit compile the get_current_cuda_stream function
+    # torch caches the extension, so subsequent loads are faster
+    torch_get_current_cuda_stream = load_torch_get_current_cuda_stream()
+
+
 cdef inline object make_ret_small_str(TVMFFIAny result):
     """convert small string to return value."""
     cdef TVMFFIByteArray bytes
@@ -76,9 +116,13 @@ cdef inline object make_ret(TVMFFIAny result):
     raise ValueError("Unhandled type index %d" % type_index)
 
 
-cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except -1:
+cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
+                          int* ctx_dev_type, int* ctx_dev_id, TVMFFIStreamHandle* ctx_stream) except -1:
     """Pack arguments into c args tvm call accept"""
-    cdef unsigned long long ptr
+    cdef unsigned long long temp_ptr
+    cdef DLTensor* temp_dltensor
+    cdef int is_cuda = 0
+
     for i, arg in enumerate(py_args):
         # clear the value to ensure zero padding on 32bit platforms
         if sizeof(void*) != 8:
@@ -96,10 +140,18 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except
             out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
             out[i].v_ptr = (<Object>arg).chandle
         elif torch is not None and isinstance(arg, torch.Tensor):
+            is_cuda = arg.is_cuda
             arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg),
                               required_alignment=__dlpack_auto_import_required_alignment__)
             out[i].type_index = kTVMFFINDArray
             out[i].v_ptr = (<NDArray>arg).chandle
+            temp_dltensor = TVMFFINDArrayGetDLTensorPtr((<NDArray>arg).chandle)
+            # record the stream and device for torch context
+            if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1:
+                ctx_dev_type[0] = temp_dltensor.device.device_type
+                ctx_dev_id[0] = temp_dltensor.device.device_id
+                temp_ptr = torch_get_current_cuda_stream(temp_dltensor.device.device_id)
+                ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
             temp_args.append(arg)
         elif hasattr(arg, "__dlpack__"):
             arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__)
@@ -177,12 +229,27 @@ cdef inline int FuncCall3(void* chandle,
     # fast path with stack alloca for less than 3 args
     cdef TVMFFIAny[3] packed_args
     cdef int nargs = len(args)
+    cdef int ctx_dev_type = -1
+    cdef int ctx_dev_id = 0
+    cdef TVMFFIStreamHandle ctx_stream = NULL
+    cdef TVMFFIStreamHandle prev_stream = NULL
     temp_args = []
-    make_args(args, &packed_args[0], temp_args)
+    make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, 
&ctx_stream)
     with nogil:
+        if ctx_dev_type != -1:
+            # set the stream based on ctx stream
+            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
+            if c_api_ret_code[0] != 0:
+                return 0
         c_api_ret_code[0] = TVMFFIFunctionCall(
             chandle, &packed_args[0], nargs, result
         )
+        # restore the original stream if it is not the same as the context stream
+        if ctx_dev_type != -1 and prev_stream != ctx_stream:
+            # restore the original stream
+            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
+            if c_api_ret_code[0] != 0:
+                return 0
     return 0
 
 
@@ -191,6 +258,10 @@ cdef inline int FuncCall(void* chandle,
                          TVMFFIAny* result,
                          int* c_api_ret_code) except -1:
     cdef int nargs = len(args)
+    cdef int ctx_dev_type = -1
+    cdef int ctx_dev_id = 0
+    cdef TVMFFIStreamHandle ctx_stream = NULL
+    cdef TVMFFIStreamHandle prev_stream = NULL
 
     if nargs <= 3:
         FuncCall3(chandle, args, result, c_api_ret_code)
@@ -200,10 +271,19 @@ cdef inline int FuncCall(void* chandle,
     packed_args.resize(nargs)
 
     temp_args = []
-    make_args(args, &packed_args[0], temp_args)
+    make_args(args, &packed_args[0], temp_args, &ctx_dev_type, &ctx_dev_id, 
&ctx_stream)
 
     with nogil:
+        if ctx_dev_type != -1:
+            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, ctx_stream, &prev_stream)
+            if c_api_ret_code[0] != 0:
+                return 0
         c_api_ret_code[0] = TVMFFIFunctionCall(chandle, &packed_args[0], nargs, result)
+        # restore the original stream if it is not the same as the context stream
+        if ctx_dev_type != -1 and prev_stream != ctx_stream:
+            c_api_ret_code[0] = TVMFFIEnvSetStream(ctx_dev_type, ctx_dev_id, prev_stream, NULL)
+            if c_api_ret_code[0] != 0:
+                return 0
 
     return 0
 
@@ -274,7 +354,7 @@ cdef class FieldSetter:
         cdef void* field_ptr = (<char*>(<Object>obj).chandle) + self.offset
         cdef int nargs = 1
         temp_args = []
-        make_args((value,), &packed_args[0], temp_args)
+        make_args((value,), &packed_args[0], temp_args, NULL, NULL, NULL)
         c_api_ret_code = self.setter(field_ptr, &packed_args[0])
         # NOTE: logic is same as check_call
         # directly inline here to simplify traceback
@@ -412,7 +492,7 @@ cdef int tvm_ffi_callback(void* context,
         return -1
 
     temp_args = []
-    make_args((rv,), &temp_result, temp_args)
+    make_args((rv,), &temp_result, temp_args, NULL, NULL, NULL)
     CHECK_CALL(TVMFFIAnyViewToOwnedAny(&temp_result, result))
 
     return 0
