This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7b28787819 [FFI] Update torch stream getter to use native torch c api 
(#18266)
7b28787819 is described below

commit 7b28787819e5ffbc7fe4234c49a7eac64a2398a5
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Sep 4 07:22:38 2025 -0400

    [FFI] Update torch stream getter to use native torch c api (#18266)
    
    This PR updates the torch stream getter to use _cuda_getCurrentRawStream
    from the torch C API, which is also used by dynamo; this saves us from
    having to load_inline a custom module.
---
 ffi/pyproject.toml                     |  2 +-
 ffi/python/tvm_ffi/cython/function.pxi | 40 ++--------------------------------
 2 files changed, 3 insertions(+), 39 deletions(-)

diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml
index 083a60fc36..ab2a7f84df 100644
--- a/ffi/pyproject.toml
+++ b/ffi/pyproject.toml
@@ -17,7 +17,7 @@
 
 [project]
 name = "apache-tvm-ffi"
-version = "0.1.0a6"
+version = "0.1.0a7"
 description = "tvm ffi"
 
 authors = [{ name = "TVM FFI team" }]
diff --git a/ffi/python/tvm_ffi/cython/function.pxi 
b/ffi/python/tvm_ffi/cython/function.pxi
index a223da90cb..064473e134 100644
--- a/ffi/python/tvm_ffi/cython/function.pxi
+++ b/ffi/python/tvm_ffi/cython/function.pxi
@@ -24,41 +24,6 @@ except ImportError:
     torch = None
 
 
-def load_torch_get_current_cuda_stream():
-    """Create a faster get_current_cuda_stream for torch through cpp extension.
-    """
-    source = """
-    #include <c10/cuda/CUDAStream.h>
-
-    int64_t get_current_cuda_stream(int device_id) {
-        at::cuda::CUDAStream stream = 
at::cuda::getCurrentCUDAStream(device_id);
-        // fast invariant, default stream is always 0
-        if (stream.id() == 0) return 0;
-        // convert to cudaStream_t
-        return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
-    }
-    """
-    def fallback_get_current_cuda_stream(device_id):
-        """Fallback with python api"""
-        return torch.cuda.current_stream(device_id).cuda_stream
-    try:
-        from torch.utils import cpp_extension
-        result = cpp_extension.load_inline(
-            name="get_current_cuda_stream",
-            cpp_sources=[source],
-            cuda_sources=[],
-            extra_cflags=["-O3"],
-            extra_include_paths=cpp_extension.include_paths("cuda"),
-            functions=["get_current_cuda_stream"],
-        )
-        return result.get_current_cuda_stream
-    except Exception:
-        return fallback_get_current_cuda_stream
-
-
-torch_get_current_cuda_stream = None
-
-
 cdef inline object make_ret_small_str(TVMFFIAny result):
     """convert small string to return value."""
     cdef TVMFFIByteArray bytes
@@ -146,9 +111,8 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, 
list temp_args,
             if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1:
                 ctx_dev_type[0] = temp_dltensor.device.device_type
                 ctx_dev_id[0] = temp_dltensor.device.device_id
-                if torch_get_current_cuda_stream is None:
-                    torch_get_current_cuda_stream = 
load_torch_get_current_cuda_stream()
-                temp_ptr = 
torch_get_current_cuda_stream(temp_dltensor.device.device_id)
+                # This is an API that dynamo and other uses to get the raw 
stream from torch
+                temp_ptr = 
torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id)
                 ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
             temp_args.append(arg)
         elif hasattr(arg, "__dlpack__"):

Reply via email to