This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/dev by this push:
     new 3197cd0  Add stream context (#5)
3197cd0 is described below

commit 3197cd0949ae9bb9a41eb5a429a4b1519d061db4
Author: Yaxing Cai <[email protected]>
AuthorDate: Mon Sep 15 04:32:19 2025 -0700

    Add stream context (#5)
    
    This PR adds a stream context to ffi, so that the ffi environment stream can
    be updated. `tvm_ffi.use_torch_stream` wraps a torch stream/graph context,
    and the lower-level `tvm_ffi.use_raw_stream` creates a context from an
    existing stream handle.
    
    Example for `tvm_ffi.use_torch_stream`:
    
    case with torch stream:
    ```python
    stream = torch.cuda.Stream()
    stream_context = torch.cuda.stream(stream)
    with tvm_ffi.use_torch_stream(stream_context):
      ...
    ```
    
    case with torch cuda graph
    ```python
    graph = torch.cuda.CUDAGraph()
    graph_context = torch.cuda.graph(graph)
    with tvm_ffi.use_torch_stream(graph_context):
      ...
    ```
    
    case with current stream by default
    ```python
    stream = torch.cuda.Stream()
    stream_context = torch.cuda.stream(stream)
    with torch.cuda.stream(stream):
      with tvm_ffi.use_torch_stream():
        ...
    ```
    
    Example for `tvm_ffi.use_raw_stream`:
    
    ```python
    device = tvm_ffi.device(...)
    stream_handle = ...
    with tvm_ffi.use_raw_stream(device, stream_handle):
      ...
    ```
---
 docs/reference/python/index.rst |  10 +++
 python/tvm_ffi/__init__.py      |   1 +
 python/tvm_ffi/cython/base.pxi  |   9 +++
 python/tvm_ffi/stream.py        | 148 ++++++++++++++++++++++++++++++++++++++++
 tests/python/test_stream.py     | 115 +++++++++++++++++++++++++++++++
 5 files changed, 283 insertions(+)

diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst
index 482c19d..756af4c 100644
--- a/docs/reference/python/index.rst
+++ b/docs/reference/python/index.rst
@@ -68,6 +68,16 @@ Containers
   Map
 
 
+Stream Context
+--------------
+.. autosummary::
+  :toctree: generated/
+
+  StreamContext
+  use_torch_stream
+  use_raw_stream
+
+
 Utility
 -------
 
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 16f035e..9bafe2b 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -37,6 +37,7 @@ from ._tensor import Device, device, DLDeviceType
 from ._tensor import from_dlpack, Tensor, Shape
 from .container import Array, Map
 from .module import Module, system_lib, load_module
+from .stream import StreamContext, use_raw_stream, use_torch_stream
 from . import serialization
 from . import access_path
 from . import testing
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 08f7df2..77c9c7e 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -245,6 +245,15 @@ cdef extern from "tvm/ffi/extra/c_env_api.h":
                                   TVMFFIStreamHandle stream,
                                   TVMFFIStreamHandle* opt_out_original_stream) 
nogil
 
+def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
+    """Install ``stream`` as the FFI environment stream of a device.
+
+    Parameters
+    ----------
+    device_type : int
+        Device type code (StreamContext passes ``Device.dlpack_device_type()``).
+    device_id : int
+        Device index.
+    stream : int
+        The new stream handle, given as an integer address.
+
+    Returns
+    -------
+    prev_stream : int
+        The stream handle that was installed before this call, as an integer.
+    """
+    cdef TVMFFIStreamHandle new_handle = <void*>stream
+    cdef TVMFFIStreamHandle old_handle = NULL
+    # TVMFFIEnvSetStream reports the previously installed stream through its
+    # last (output) argument.
+    CHECK_CALL(TVMFFIEnvSetStream(device_type, device_id, new_handle, &old_handle))
+    return <uint64_t>old_handle
+
 
 cdef extern from "tvm_ffi_python_helpers.h":
     # no need to expose fields of the call context
diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py
new file mode 100644
index 0000000..598afca
--- /dev/null
+++ b/python/tvm_ffi/stream.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Stream context."""
+from ctypes import c_void_p
+from typing import Any, Optional, Union
+
+from . import core
+from ._tensor import device
+
+
+class StreamContext:
+    """StreamContext represents a stream context in the ffi system.
+
+    StreamContext sets up the ffi environment stream through the python
+    `with` statement. When entering the `with` scope, it caches the current
+    environment stream and installs the given new stream. When exiting the
+    `with` scope, it restores the cached environment stream.
+
+    Parameters
+    ----------
+    device : Device
+        The device to which the stream belongs.
+
+    stream : Union[int, c_void_p]
+        The stream handle.
+
+    See Also
+    --------
+    :py:func:`tvm_ffi.use_raw_stream`, :py:func:`tvm_ffi.use_torch_stream`
+    """
+
+    def __init__(self, device: core.Device, stream: Union[int, c_void_p]):
+        self.device_type = device.dlpack_device_type()
+        self.device_id = device.index
+        self.stream = stream
+
+    @staticmethod
+    def _as_handle(stream: Union[int, c_void_p]) -> int:
+        """Return ``stream`` as the integer address expected by core."""
+        # core._env_set_current_stream takes the handle as a uint64_t, which a
+        # ctypes pointer cannot be converted to implicitly. A NULL c_void_p
+        # has ``value`` None, which maps to the null handle 0.
+        if isinstance(stream, c_void_p):
+            return stream.value or 0
+        return stream
+
+    def __enter__(self):
+        # Remember the stream that was active so __exit__ can restore it.
+        self.prev_stream = core._env_set_current_stream(
+            self.device_type, self.device_id, self._as_handle(self.stream)
+        )
+        return self
+
+    def __exit__(self, *args):
+        # Restore the stream cached in __enter__. The returned handle (the
+        # stream this context installed) is not needed.
+        core._env_set_current_stream(
+            self.device_type, self.device_id, self.prev_stream
+        )
+
+
+try:
+    import torch
+
+    class TorchStreamContext:
+        """Mirror the torch current CUDA stream into the ffi environment.
+
+        On entry, the optional wrapped torch context (e.g. the object returned
+        by ``torch.cuda.stream(s)`` or ``torch.cuda.graph(g)``) is entered
+        first. Then ``torch.cuda.current_stream()`` is installed as the ffi
+        environment stream of its device. On exit, both steps are undone in
+        reverse order.
+
+        Parameters
+        ----------
+        context : Optional[Any]
+            The torch context manager to wrap, or ``None`` to use the current
+            torch stream as-is.
+        """
+
+        def __init__(self, context: Optional[Any]):
+            self.torch_context = context
+
+        def __enter__(self):
+            # Compare against None explicitly: a context manager object is not
+            # guaranteed to be truthy.
+            if self.torch_context is not None:
+                self.torch_context.__enter__()
+            try:
+                # Read the stream only after the torch context is active, so a
+                # stream/graph context has already switched it.
+                current_stream = torch.cuda.current_stream()
+                self.ffi_context = StreamContext(
+                    device(str(current_stream.device)), current_stream.cuda_stream
+                )
+                self.ffi_context.__enter__()
+            except BaseException as err:
+                # Do not leave the torch context entered if the ffi side fails.
+                if self.torch_context is not None:
+                    self.torch_context.__exit__(type(err), err, err.__traceback__)
+                raise
+            return self
+
+        def __exit__(self, *args):
+            # Unwind in reverse order of __enter__: restore the ffi stream
+            # first, then leave the torch context even if the restore raises.
+            try:
+                self.ffi_context.__exit__(*args)
+            finally:
+                if self.torch_context is not None:
+                    self.torch_context.__exit__(*args)
+
+    def use_torch_stream(context: Optional[Any] = None):
+        """
+        Create a ffi stream context with given torch stream,
+        cuda graph or current stream if `None` provided.
+
+        Parameters
+        ----------
+        context : Optional[Any]
+            The wrapped torch stream or cuda graph.
+
+        Returns
+        -------
+        context : tvm_ffi.stream.TorchStreamContext
+            The ffi stream context wrapping torch stream context.
+
+        Examples
+        --------
+        .. code-block:: python
+
+            s = torch.cuda.Stream()
+            with tvm_ffi.use_torch_stream(torch.cuda.stream(s)):
+                ...
+
+            g = torch.cuda.CUDAGraph()
+            with tvm_ffi.use_torch_stream(torch.cuda.graph(g)):
+                ...
+
+        Note
+        ----
+        When working with a raw cudaStream_t handle, use
+        :py:func:`tvm_ffi.use_raw_stream` instead.
+        """
+        return TorchStreamContext(context)
+
+except ImportError:
+
+    def use_torch_stream(context: Optional[Any] = None):
+        """Placeholder used when torch is not installed.
+
+        Raises
+        ------
+        ImportError
+            Always, because torch could not be imported.
+        """
+        raise ImportError("Cannot import torch")
+
+
+def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]):
+    """
+    Create a ffi stream context with given device and stream handle.
+
+    Parameters
+    ----------
+    device : tvm_ffi.Device
+        The device to which the stream belongs.
+
+    stream : Union[int, c_void_p]
+        The stream handle.
+
+    Returns
+    -------
+    context : tvm_ffi.StreamContext
+        The ffi stream context.
+
+    Raises
+    ------
+    ValueError
+        If ``stream`` is neither an ``int`` nor a ``c_void_p``.
+
+    Note
+    ----
+    When working with a torch stream or cuda graph, use
+    :py:func:`tvm_ffi.use_torch_stream` instead.
+    """
+    # Reject anything that is not a raw handle early, pointing torch users to
+    # use_torch_stream.
+    if not isinstance(stream, (int, c_void_p)):
+        raise ValueError(
+            "use_raw_stream only accepts int or c_void_p as stream input, "
+            "try use_torch_stream when using torch.cuda.Stream or torch.cuda.graph"
+        )
+    return StreamContext(device, stream)
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
new file mode 100644
index 0000000..c7b81a8
--- /dev/null
+++ b/tests/python/test_stream.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm_ffi
+import tvm_ffi.cpp
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+def gen_check_stream_mod():
+    """Build an inline module exposing ``check_stream``.
+
+    ``check_stream(device_type, device_id, stream)`` fails (via
+    TVM_FFI_ICHECK_EQ) unless the ffi environment stream of the given device
+    equals ``stream``.
+    """
+    return tvm_ffi.cpp.load_inline(
+        name="check_stream",
+        cpp_sources="""
+        void check_stream(int device_type, int device_id, uint64_t stream) {
+            uint64_t cur_stream = reinterpret_cast<uint64_t>(TVMFFIEnvGetStream(device_type, device_id));
+            TVM_FFI_ICHECK_EQ(cur_stream, stream);
+        }
+    """,
+        functions=["check_stream"],
+    )
+
+
+def test_raw_stream():
+    """Nested raw-stream contexts install and then restore their handles.
+
+    The handles are arbitrary integers. In this test they are only stored
+    and compared by ``check_stream``, never used to launch work.
+    """
+    mod = gen_check_stream_mod()
+    device = tvm_ffi.device("cuda:0")
+    device_type = device.dlpack_device_type()
+    stream_1 = 123456789
+    stream_2 = 987654321
+    with tvm_ffi.use_raw_stream(device, stream_1):
+        mod.check_stream(device_type, device.index, stream_1)
+
+        with tvm_ffi.use_raw_stream(device, stream_2):
+            mod.check_stream(device_type, device.index, stream_2)
+
+        # Leaving the inner context must restore the outer stream.
+        mod.check_stream(device_type, device.index, stream_1)
+
+
+@pytest.mark.skipif(
+    torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA"
+)
+def test_torch_stream():
+    """Wrapping ``torch.cuda.stream`` contexts keeps torch and ffi in sync.
+
+    Both the torch current stream and the ffi environment stream must follow
+    the nested contexts and be restored when the inner one exits.
+    """
+    mod = gen_check_stream_mod()
+    device_id = torch.cuda.current_device()
+    device = tvm_ffi.device("cuda", device_id)
+    device_type = device.dlpack_device_type()
+    stream_1 = torch.cuda.Stream(device_id)
+    stream_2 = torch.cuda.Stream(device_id)
+    with tvm_ffi.use_torch_stream(torch.cuda.stream(stream_1)):
+        assert torch.cuda.current_stream() == stream_1
+        mod.check_stream(device_type, device_id, stream_1.cuda_stream)
+
+        with tvm_ffi.use_torch_stream(torch.cuda.stream(stream_2)):
+            assert torch.cuda.current_stream() == stream_2
+            mod.check_stream(device_type, device_id, stream_2.cuda_stream)
+
+        # Leaving the inner context must restore the outer stream on both sides.
+        assert torch.cuda.current_stream() == stream_1
+        mod.check_stream(device_type, device_id, stream_1.cuda_stream)
+
+
+@pytest.mark.skipif(
+    torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA"
+)
+def test_torch_current_stream():
+    """``use_torch_stream()`` with no argument picks up the torch current stream."""
+    mod = gen_check_stream_mod()
+    device_id = torch.cuda.current_device()
+    device = tvm_ffi.device("cuda", device_id)
+    device_type = device.dlpack_device_type()
+    stream_1 = torch.cuda.Stream(device_id)
+    stream_2 = torch.cuda.Stream(device_id)
+    with torch.cuda.stream(stream_1):
+        assert torch.cuda.current_stream() == stream_1
+        with tvm_ffi.use_torch_stream():
+            mod.check_stream(device_type, device_id, stream_1.cuda_stream)
+
+        with torch.cuda.stream(stream_2):
+            assert torch.cuda.current_stream() == stream_2
+            with tvm_ffi.use_torch_stream():
+                mod.check_stream(device_type, device_id, stream_2.cuda_stream)
+
+        # After torch switches back, a fresh context must see stream_1 again.
+        assert torch.cuda.current_stream() == stream_1
+        with tvm_ffi.use_torch_stream():
+            mod.check_stream(device_type, device_id, stream_1.cuda_stream)
+
+
+@pytest.mark.skipif(
+    torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA"
+)
+def test_torch_graph():
+    """Inside a wrapped ``torch.cuda.graph`` capture, the capture stream is
+    both the torch current stream and the ffi environment stream."""
+    mod = gen_check_stream_mod()
+    device_id = torch.cuda.current_device()
+    device = tvm_ffi.device("cuda", device_id)
+    device_type = device.dlpack_device_type()
+    graph = torch.cuda.CUDAGraph()
+    stream = torch.cuda.Stream(device_id)
+    with tvm_ffi.use_torch_stream(torch.cuda.graph(graph, stream=stream)):
+        assert torch.cuda.current_stream() == stream
+        mod.check_stream(device_type, device_id, stream.cuda_stream)

Reply via email to