This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/dev by this push:
new 3197cd0 Add stream context (#5)
3197cd0 is described below
commit 3197cd0949ae9bb9a41eb5a429a4b1519d061db4
Author: Yaxing Cai <[email protected]>
AuthorDate: Mon Sep 15 04:32:19 2025 -0700
Add stream context (#5)
This PR adds the stream context into ffi, so that ffi env stream can be
updated. The `tvm_ffi.use_torch_stream` is for wrapping the torch stream/graph
context. And lower-level `tvm_ffi.use_raw_stream` is for creating context with
existing stream handle.
Example for `tvm_ffi.use_torch_stream`:
case with torch stream:
```python
stream = torch.cuda.Stream()
stream_context = torch.cuda.stream(stream)
with tvm_ffi.use_torch_stream(stream_context):
...
```
case with torch cuda graph
```python
graph = torch.cuda.CUDAGraph()
graph_context = torch.cuda.graph(graph)
with tvm_ffi.use_torch_stream(graph_context):
...
```
case with current stream by default
```python
stream = torch.cuda.Stream()
stream_context = torch.cuda.stream(stream)
with torch.cuda.stream(stream):
with tvm_ffi.use_torch_stream():
...
```
Example for `tvm_ffi.use_raw_stream`:
```python
device = tvm_ffi.device(...)
stream_handle = ...
with tvm_ffi.use_raw_stream(device, stream_handle):
...
```
---
docs/reference/python/index.rst | 10 +++
python/tvm_ffi/__init__.py | 1 +
python/tvm_ffi/cython/base.pxi | 9 +++
python/tvm_ffi/stream.py | 148 ++++++++++++++++++++++++++++++++++++++++
tests/python/test_stream.py | 115 +++++++++++++++++++++++++++++++
5 files changed, 283 insertions(+)
diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst
index 482c19d..756af4c 100644
--- a/docs/reference/python/index.rst
+++ b/docs/reference/python/index.rst
@@ -68,6 +68,16 @@ Containers
Map
+Stream Context
+--------------
+.. autosummary::
+ :toctree: generated/
+
+ StreamContext
+ use_torch_stream
+ use_raw_stream
+
+
Utility
-------
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 16f035e..9bafe2b 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -37,6 +37,7 @@ from ._tensor import Device, device, DLDeviceType
from ._tensor import from_dlpack, Tensor, Shape
from .container import Array, Map
from .module import Module, system_lib, load_module
+from .stream import StreamContext, use_raw_stream, use_torch_stream
from . import serialization
from . import access_path
from . import testing
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 08f7df2..77c9c7e 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -245,6 +245,15 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
TVMFFIStreamHandle stream,
TVMFFIStreamHandle* opt_out_original_stream)
nogil
+def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
+ cdef TVMFFIStreamHandle prev_stream = NULL
+ CHECK_CALL(TVMFFIEnvSetStream(
+ device_type,
+ device_id,
+ <void*>stream,
+ &prev_stream))
+ return <uint64_t>prev_stream
+
cdef extern from "tvm_ffi_python_helpers.h":
# no need to expose fields of the call context
diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py
new file mode 100644
index 0000000..598afca
--- /dev/null
+++ b/python/tvm_ffi/stream.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Stream context."""
+from ctypes import c_void_p
+from typing import Any, Optional, Union
+
+from . import core
+from ._tensor import device
+
+
+class StreamContext:
+ """StreamContext represents a stream context in the ffi system.
+ StreamContext helps setup ffi environment stream by python `with`
statement.
+ When entering `with` scope, it caches the current environment stream and
+ setup the given new stream.
+ When exiting `with` scope, it recovers the stream to the cached
environment stream.
+
+ Parameters
+ ----------
+ device : Device
+ The device to which the stream belongs.
+
+ stream : Union[int, c_void_p]
+ The stream handle.
+
+ See Also
+ --------
+ :py:func:`tvm_ffi.use_raw_stream`, :py:func:`tvm_ffi.use_torch_stream`
+ """
+
+ def __init__(self, device: core.Device, stream: Union[int, c_void_p]):
+ self.device_type = device.dlpack_device_type()
+ self.device_id = device.index
+ self.stream = stream
+
+ def __enter__(self):
+ self.prev_stream = core._env_set_current_stream(
+ self.device_type, self.device_id, self.stream
+ )
+
+ def __exit__(self, *args):
+ self.prev_stream = core._env_set_current_stream(
+ self.device_type, self.device_id, self.prev_stream
+ )
+
+
+try:
+ import torch
+
+ class TorchStreamContext:
+ def __init__(self, context: Optional[Any]):
+ self.torch_context = context
+
+ def __enter__(self):
+ if self.torch_context:
+ self.torch_context.__enter__()
+ current_stream = torch.cuda.current_stream()
+ self.ffi_context = StreamContext(
+ device(str(current_stream.device)), current_stream.cuda_stream
+ )
+ self.ffi_context.__enter__()
+
+ def __exit__(self, *args):
+ if self.torch_context:
+ self.torch_context.__exit__(*args)
+ self.ffi_context.__exit__(*args)
+
+ def use_torch_stream(context: Optional[Any] = None):
+ """
+ Create a ffi stream context with given torch stream,
+ cuda graph or current stream if `None` provided.
+
+ Parameters
+ ----------
+ context : Optional[Any]
+ The wrapped torch stream or cuda graph.
+
+ Returns
+ -------
+ context : tvm_ffi.TorchStreamContext
+ The ffi stream context wrapping torch stream context.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ s = torch.cuda.Stream()
+ with tvm_ffi.use_torch_stream(torch.cuda.stream(s)):
+ ...
+
+ g = torch.cuda.CUDAGraph()
+ with tvm_ffi.use_torch_stream(torch.cuda.graph(g)):
+ ...
+
+ Note
+ ----
+ When working with a raw cudaStream_t handle, use
:py:func:`tvm_ffi.use_raw_stream` instead.
+ """
+ return TorchStreamContext(context)
+
+except ImportError:
+
+ def use_torch_stream(context: Optional[Any] = None):
+ raise ImportError("Cannot import torch")
+
+
+def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]):
+ """
+ Create a ffi stream context with given device and stream handle.
+
+ Parameters
+ ----------
+ device : tvm_ffi.Device
+ The device to which the stream belongs.
+
+ stream : Union[int, c_void_p]
+ The stream handle.
+
+ Returns
+ -------
+ context : tvm_ffi.StreamContext
+ The ffi stream context.
+
+ Note
+ ----
+ When working with a torch stream or cuda graph, use
:py:func:`tvm_ffi.use_torch_stream` instead.
+ """
+ if not isinstance(stream, (int, c_void_p)):
+ raise ValueError(
+ "use_raw_stream only accepts int or c_void_p as stram input, "
+ "try use_torch_stream when using torch.cuda.Stream or
torch.cuda.graph"
+ )
+ return StreamContext(device, stream)
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
new file mode 100644
index 0000000..c7b81a8
--- /dev/null
+++ b/tests/python/test_stream.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm_ffi
+import tvm_ffi.cpp
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
+
+def gen_check_stream_mod():
+ return tvm_ffi.cpp.load_inline(
+ name="check_stream",
+ cpp_sources="""
+ void check_stream(int device_type, int device_id, uint64_t stream) {
+ uint64_t cur_stream =
reinterpret_cast<uint64_t>(TVMFFIEnvGetStream(device_type, device_id));
+ TVM_FFI_ICHECK_EQ(cur_stream, stream);
+ }
+ """,
+ functions=["check_stream"],
+ )
+
+
+def test_raw_stream():
+ mod = gen_check_stream_mod()
+ device = tvm_ffi.device("cuda:0")
+ stream_1 = 123456789
+ stream_2 = 987654321
+ with tvm_ffi.use_raw_stream(device, stream_1):
+ mod.check_stream(device.dlpack_device_type(), device.index, stream_1)
+
+ with tvm_ffi.use_raw_stream(device, stream_2):
+ mod.check_stream(device.dlpack_device_type(), device.index,
stream_2)
+
+ mod.check_stream(device.dlpack_device_type(), device.index, stream_1)
+
+
[email protected](
+ torch is None or not torch.cuda.is_available(), reason="Requires torch and
CUDA"
+)
+def test_torch_stream():
+ mod = gen_check_stream_mod()
+ device_id = torch.cuda.current_device()
+ device = tvm_ffi.device("cuda", device_id)
+ device_type = device.dlpack_device_type()
+ stream_1 = torch.cuda.Stream(device_id)
+ stream_2 = torch.cuda.Stream(device_id)
+ with tvm_ffi.use_torch_stream(torch.cuda.stream(stream_1)):
+ assert torch.cuda.current_stream() == stream_1
+ mod.check_stream(device_type, device_id, stream_1.cuda_stream)
+
+ with tvm_ffi.use_torch_stream(torch.cuda.stream(stream_2)):
+ assert torch.cuda.current_stream() == stream_2
+ mod.check_stream(device_type, device_id, stream_2.cuda_stream)
+
+ assert torch.cuda.current_stream() == stream_1
+ mod.check_stream(device_type, device_id, stream_1.cuda_stream)
+
+
[email protected](
+ torch is None or not torch.cuda.is_available(), reason="Requires torch and
CUDA"
+)
+def test_torch_current_stream():
+ mod = gen_check_stream_mod()
+ device_id = torch.cuda.current_device()
+ device = tvm_ffi.device("cuda", device_id)
+ device_type = device.dlpack_device_type()
+ stream_1 = torch.cuda.Stream(device_id)
+ stream_2 = torch.cuda.Stream(device_id)
+ with torch.cuda.stream(stream_1):
+ assert torch.cuda.current_stream() == stream_1
+ with tvm_ffi.use_torch_stream():
+ mod.check_stream(device_type, device_id, stream_1.cuda_stream)
+
+ with torch.cuda.stream(stream_2):
+ assert torch.cuda.current_stream() == stream_2
+ with tvm_ffi.use_torch_stream():
+ mod.check_stream(device_type, device_id, stream_2.cuda_stream)
+
+ assert torch.cuda.current_stream() == stream_1
+ with tvm_ffi.use_torch_stream():
+ mod.check_stream(device_type, device_id, stream_1.cuda_stream)
+
+
[email protected](
+ torch is None or not torch.cuda.is_available(), reason="Requires torch and
CUDA"
+)
+def test_torch_graph():
+ mod = gen_check_stream_mod()
+ device_id = torch.cuda.current_device()
+ device = tvm_ffi.device("cuda", device_id)
+ device_type = device.dlpack_device_type()
+ graph = torch.cuda.CUDAGraph()
+ stream = torch.cuda.Stream(device_id)
+ with tvm_ffi.use_torch_stream(torch.cuda.graph(graph, stream=stream)):
+ assert torch.cuda.current_stream() == stream
+ mod.check_stream(device_type, device_id, stream.cuda_stream)