This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7b28787819 [FFI] Update torch stream getter to use native torch c api
(#18266)
7b28787819 is described below
commit 7b28787819e5ffbc7fe4234c49a7eac64a2398a5
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Sep 4 07:22:38 2025 -0400
[FFI] Update torch stream getter to use native torch c api (#18266)
This PR updates the torch stream getter to use _cuda_getCurrentRawStream
from the torch C API, which is also used by dynamo. This saves us from having
to build the custom module with load_inline.
---
ffi/pyproject.toml | 2 +-
ffi/python/tvm_ffi/cython/function.pxi | 40 ++--------------------------------
2 files changed, 3 insertions(+), 39 deletions(-)
diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml
index 083a60fc36..ab2a7f84df 100644
--- a/ffi/pyproject.toml
+++ b/ffi/pyproject.toml
@@ -17,7 +17,7 @@
[project]
name = "apache-tvm-ffi"
-version = "0.1.0a6"
+version = "0.1.0a7"
description = "tvm ffi"
authors = [{ name = "TVM FFI team" }]
diff --git a/ffi/python/tvm_ffi/cython/function.pxi
b/ffi/python/tvm_ffi/cython/function.pxi
index a223da90cb..064473e134 100644
--- a/ffi/python/tvm_ffi/cython/function.pxi
+++ b/ffi/python/tvm_ffi/cython/function.pxi
@@ -24,41 +24,6 @@ except ImportError:
torch = None
-def load_torch_get_current_cuda_stream():
- """Create a faster get_current_cuda_stream for torch through cpp extension.
- """
- source = """
- #include <c10/cuda/CUDAStream.h>
-
- int64_t get_current_cuda_stream(int device_id) {
- at::cuda::CUDAStream stream =
at::cuda::getCurrentCUDAStream(device_id);
- // fast invariant, default stream is always 0
- if (stream.id() == 0) return 0;
- // convert to cudaStream_t
- return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
- }
- """
- def fallback_get_current_cuda_stream(device_id):
- """Fallback with python api"""
- return torch.cuda.current_stream(device_id).cuda_stream
- try:
- from torch.utils import cpp_extension
- result = cpp_extension.load_inline(
- name="get_current_cuda_stream",
- cpp_sources=[source],
- cuda_sources=[],
- extra_cflags=["-O3"],
- extra_include_paths=cpp_extension.include_paths("cuda"),
- functions=["get_current_cuda_stream"],
- )
- return result.get_current_cuda_stream
- except Exception:
- return fallback_get_current_cuda_stream
-
-
-torch_get_current_cuda_stream = None
-
-
cdef inline object make_ret_small_str(TVMFFIAny result):
"""convert small string to return value."""
cdef TVMFFIByteArray bytes
@@ -146,9 +111,8 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out,
list temp_args,
if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1:
ctx_dev_type[0] = temp_dltensor.device.device_type
ctx_dev_id[0] = temp_dltensor.device.device_id
- if torch_get_current_cuda_stream is None:
- torch_get_current_cuda_stream =
load_torch_get_current_cuda_stream()
- temp_ptr =
torch_get_current_cuda_stream(temp_dltensor.device.device_id)
+ # This is an API that dynamo and other uses to get the raw
stream from torch
+ temp_ptr =
torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id)
ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
temp_args.append(arg)
elif hasattr(arg, "__dlpack__"):