This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 65b5e905 [Feature] Support AMD HIP for cpp extension (#460)
65b5e905 is described below

commit 65b5e90576185cb6300f43bc1307158dc99afb54
Author: DarkSharpness <[email protected]>
AuthorDate: Thu Feb 19 01:54:32 2026 +0800

    [Feature] Support AMD HIP for cpp extension (#460)
    
    Related issue: #458.
    
    I'm not familiar with AMD, and most of this code was generated by
    Claude Code. I've only cleaned it up a little and tested the following
    example on one AMD machine. Reviews from AMD experts are needed.
    
    ```python
    import torch
    from tvm_ffi import Module
    import tvm_ffi.cpp

    # HIP source code; it is compiled with hipcc when the HIP backend is used
    cpp_source = '''
    #include <hip/hip_runtime.h>

    __global__ void add_one_kernel(const float* __restrict__ x,
                                   float* __restrict__ y,
                                   int64_t n) {
        int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) y[i] = x[i] + 1.0f;
    }

    void add_one_hip(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
        int64_t n = x.size(0);
        // Skip the launch for empty input: a grid with zero blocks is not a
        // valid launch configuration.
        if (n == 0) return;
        const float* x_ptr = static_cast<const float*>(x.data_ptr());
        float* y_ptr = static_cast<float*>(y.data_ptr());
        constexpr int threads = 256;
        int blocks = (int)((n + threads - 1) / threads);
        // Default (null) stream; replace if your runtime provides a stream.
        hipStream_t stream = 0;
        hipLaunchKernelGGL(add_one_kernel,
                           dim3(blocks), dim3(threads),
                           0, stream,
                           x_ptr, y_ptr, n);
    }
    '''

    # compile the source code and load the module; request the HIP backend
    # explicitly rather than relying on auto-detection
    mod: Module = tvm_ffi.cpp.load_inline(
        name="hello",
        cuda_sources=cpp_source,
        functions="add_one_hip",
        backend="hip",
    )

    # use the function from the loaded module to perform the computation.
    # ROCm builds of PyTorch expose AMD GPUs under the "cuda" device type.
    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
    y = torch.empty_like(x)
    mod.add_one_hip(x, y)
    torch.testing.assert_close(x + 1, y)
    ```
    
    ---------
    
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 python/tvm_ffi/cpp/extension.py | 173 +++++++++++++++++++++++++++++++++++-----
 1 file changed, 155 insertions(+), 18 deletions(-)

diff --git a/python/tvm_ffi/cpp/extension.py b/python/tvm_ffi/cpp/extension.py
index b558fb09..af30348c 100644
--- a/python/tvm_ffi/cpp/extension.py
+++ b/python/tvm_ffi/cpp/extension.py
@@ -27,13 +27,39 @@ import sys
 from collections.abc import Mapping, Sequence
 from contextlib import nullcontext
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal
 
 from tvm_ffi.libinfo import find_dlpack_include_path, find_include_path, 
find_libtvm_ffi
 from tvm_ffi.module import Module, load_module
 from tvm_ffi.utils import FileLock
 
 IS_WINDOWS = sys.platform == "win32"
+BACKEND_STR = Literal["cuda", "hip"]
+
+
+@functools.lru_cache
+def _detect_gpu_backend() -> BACKEND_STR:
+    """Auto-detect whether to use CUDA or HIP (ROCm).
+
+    Returns 'hip' if ROCm/HIP is available, 'cuda' otherwise.
+    """
+    # Check environment variable override first
+    backend = os.environ.get("TVM_FFI_GPU_BACKEND", "").lower()
+    if backend in ("cuda", "hip"):
+        return backend  # type: ignore[return-value]
+    try:
+        _find_rocm_home()
+        return "hip"
+    except RuntimeError:
+        return "cuda"
+
+
+def _resolve_gpu_backend(backend: str | None) -> BACKEND_STR:
+    if backend is not None:
+        if backend in ("cuda", "hip"):
+            return backend  # type: ignore[return-value]
+        raise ValueError(f"Invalid backend: {backend}. Supported backends are 
'cuda' and 'hip'.")
+    return _detect_gpu_backend()
 
 
 def _hash_sources(
@@ -177,6 +203,72 @@ def _get_cuda_target() -> str:
             )
 
 
+@functools.lru_cache
+def _find_rocm_home() -> str:
+    """Find the ROCm install path."""
+    # Guess #1: check environment variables
+    rocm_home = os.environ.get("ROCM_HOME") or os.environ.get("ROCM_PATH")
+    if rocm_home is None:
+        hipcc_path = shutil.which("hipcc")
+        # Guess #2: find hipcc in PATH and resolve ROCm home from it
+        if hipcc_path is not None:
+            rocm_home = str(Path(hipcc_path).resolve().parent.parent)
+            if Path(rocm_home).name == "hip":
+                rocm_home = str(Path(rocm_home).parent)
+        else:
+            # Guess #3: use default installation path
+            rocm_home = "/opt/rocm"
+            if not Path(rocm_home).exists():
+                raise RuntimeError(
+                    "Could not find ROCm installation. Please set ROCM_HOME 
environment variable."
+                )
+    return rocm_home
+
+
+def _get_rocm_target() -> list[str]:
+    """Get the ROCm target architecture flags (--offload-arch=gfxXXXX)."""
+    if "TVM_FFI_ROCM_ARCH_LIST" in os.environ:
+        arch_list = os.environ["TVM_FFI_ROCM_ARCH_LIST"].split()  # e.g., 
"gfx90a gfx942"
+        return [f"--offload-arch={arch}" for arch in arch_list]
+    # Try rocm_agent_enumerator
+    try:
+        agent_enum = str(Path(_find_rocm_home()) / "bin" / 
"rocm_agent_enumerator")
+        if not Path(agent_enum).exists():
+            agent_enum = "rocm_agent_enumerator"
+        status = subprocess.run(args=[agent_enum], capture_output=True, 
check=True, text=True)
+        archs = list(
+            dict.fromkeys(
+                line.strip()
+                for line in status.stdout.strip().split("\n")
+                if line.strip() and line.strip() != "gfx000"
+            )
+        )
+        if archs:
+            return [f"--offload-arch={arch}" for arch in archs]
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        pass
+    # Try rocminfo
+    try:
+        status = subprocess.run(args=["rocminfo"], capture_output=True, 
check=True, text=True)
+        archs = list(
+            dict.fromkeys(
+                line.split(":")[-1].strip()
+                for line in status.stdout.split("\n")
+                if "Name:" in line
+                and "gfx" in line.lower()
+                and line.split(":")[-1].strip() != "gfx000"
+            )
+        )
+        if archs:
+            return [f"--offload-arch={arch}" for arch in archs]
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        pass
+    raise RuntimeError(
+        "Could not detect ROCm GPU architecture automatically. "
+        "Please set TVM_FFI_ROCM_ARCH_LIST environment variable (e.g. 'gfx90a 
gfx942')."
+    )
+
+
 def _run_command_in_dev_prompt(
     args: list[str],
     cwd: str | os.PathLike[str],
@@ -242,7 +334,6 @@ def _run_command_in_dev_prompt(
 
 def _generate_ninja_build(  # noqa: PLR0915, PLR0912
     name: str,
-    with_cuda: bool,
     extra_cflags: Sequence[str],
     extra_cuda_cflags: Sequence[str],
     extra_ldflags: Sequence[str],
@@ -250,8 +341,13 @@ def _generate_ninja_build(  # noqa: PLR0915, PLR0912
     cpp_files: Sequence[str],
     cuda_files: Sequence[str],
     embed_cubin: Mapping[str, bytes] | None = None,
+    backend: str | None = None,
 ) -> str:
     """Generate the content of build.ninja for building the module."""
+    with_hip = backend == "hip"
+    with_cuda = backend == "cuda"
+    with_backend = with_hip or with_cuda
+
     default_include_paths = [find_include_path(), find_dlpack_include_path()]
     tvm_ffi_lib = Path(find_libtvm_ffi())
     tvm_ffi_lib_path = str(tvm_ffi_lib.parent)
@@ -280,11 +376,20 @@ def _generate_ninja_build(  # noqa: PLR0915, PLR0912
         ]
     else:
         default_cflags = ["-std=c++17", "-fPIC", "-O2"]
-        default_cuda_cflags = ["-Xcompiler", "-fPIC", "-std=c++17", "-O2"]
+        default_cuda_cflags = ["-std=c++17", "-O2"]
         default_ldflags = ["-shared", f"-L{tvm_ffi_lib_path}", "-ltvm_ffi"]
 
+        if with_hip:
+            rocm_home = _find_rocm_home()
+            default_cuda_cflags += ["-fPIC", "-D__HIP_PLATFORM_AMD__=1", 
"-fno-gpu-rdc"]
+            default_cuda_cflags += _get_rocm_target()
+            default_include_paths.append(str(Path(rocm_home) / "include"))
+            default_ldflags += [
+                f"-L{Path(rocm_home) / 'lib'!s}",
+                "-lamdhip64",
+            ]
         if with_cuda:
-            # determine the compute capability of the current GPU
+            default_cuda_cflags = ["-Xcompiler", "-fPIC", *default_cuda_cflags]
             default_cuda_cflags += [_get_cuda_target()]
             default_ldflags += [
                 "-L{}".format(str(Path(_find_cuda_home()) / "lib64")),
@@ -308,8 +413,11 @@ def _generate_ninja_build(  # noqa: PLR0915, PLR0912
     ninja.append("ninja_required_version = 1.3")
     ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS 
else "c++")))
     ninja.append("cflags = {}".format(" ".join(cflags)))
-    if with_cuda:
-        ninja.append("nvcc = {}".format(str(Path(_find_cuda_home()) / "bin" / 
"nvcc")))
+    if with_backend:
+        if with_hip:
+            ninja.append("nvcc = {}".format(str(Path(_find_rocm_home()) / 
"bin" / "hipcc")))
+        if with_cuda:
+            ninja.append("nvcc = {}".format(str(Path(_find_cuda_home()) / 
"bin" / "nvcc")))
         ninja.append("cuda_cflags = {}".format(" ".join(cuda_cflags)))
     ninja.append("ldflags = {}".format(" ".join(ldflags)))
 
@@ -325,13 +433,16 @@ def _generate_ninja_build(  # noqa: PLR0915, PLR0912
         ninja.append("  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out")
     ninja.append("")
 
-    if with_cuda:
+    if with_backend:
         ninja.append("rule compile_cuda")
         ninja.append("  depfile = $out.d")
         ninja.append("  deps = gcc")
-        ninja.append(
-            "  command = $nvcc --generate-dependencies-with-compile 
--dependency-output $out.d $cuda_cflags -c $in -o $out"
-        )
+        if with_hip:
+            ninja.append("  command = $nvcc $cuda_cflags -c $in -o $out")
+        else:
+            ninja.append(
+                "  command = $nvcc  --generate-dependencies-with-compile 
--dependency-output $out.d $cuda_cflags -c $in -o $out"
+            )
         ninja.append("")
 
     # Add rules for object merging and cubin embedding (Unix only)
@@ -482,7 +593,7 @@ def _str_seq2list(seq: Sequence[str] | str | None) -> 
list[str]:
         return list(seq)
 
 
-def _build_impl(
+def _build_impl(  # noqa: PLR0913
     name: str,
     cpp_files: Sequence[str] | str | None,
     cuda_files: Sequence[str] | str | None,
@@ -493,15 +604,17 @@ def _build_impl(
     build_directory: str | None,
     need_lock: bool = True,
     embed_cubin: Mapping[str, bytes] | None = None,
+    backend: str | None = None,
 ) -> str:
     """Real implementation of build function."""
     # need to resolve the path to make it unique
     cpp_path_list = [str(Path(p).resolve()) for p in _str_seq2list(cpp_files)]
     cuda_path_list = [str(Path(p).resolve()) for p in 
_str_seq2list(cuda_files)]
     with_cpp = bool(cpp_path_list)
-    with_cuda = bool(cuda_path_list)
-    assert with_cpp or with_cuda, "Either cpp_files or cuda_files must be 
provided."
+    with_backend = bool(cuda_path_list)
+    assert with_cpp or with_backend, "Either cpp_files or cuda_files must be 
provided."
 
+    resolved_backend = _resolve_gpu_backend(backend) if with_backend else None
     extra_ldflags_list = list(extra_ldflags) if extra_ldflags is not None else 
[]
     extra_cflags_list = list(extra_cflags) if extra_cflags is not None else []
     extra_cuda_cflags_list = list(extra_cuda_cflags) if extra_cuda_cflags is 
not None else []
@@ -541,7 +654,6 @@ def _build_impl(
     # generate build.ninja
     ninja_source = _generate_ninja_build(
         name=name,
-        with_cuda=with_cuda,
         extra_cflags=extra_cflags_list,
         extra_cuda_cflags=extra_cuda_cflags_list,
         extra_ldflags=extra_ldflags_list,
@@ -549,6 +661,7 @@ def _build_impl(
         cpp_files=cpp_path_list,
         cuda_files=cuda_path_list,
         embed_cubin=embed_cubin,
+        backend=resolved_backend,
     )
 
     # may not hold lock when build_directory is specified, prevent deadlock
@@ -562,7 +675,7 @@ def _build_impl(
         return str((build_dir / f"{name}{ext}").resolve())
 
 
-def build_inline(
+def build_inline(  # noqa: PLR0913
     name: str,
     *,
     cpp_sources: Sequence[str] | str | None = None,
@@ -574,6 +687,7 @@ def build_inline(
     extra_include_paths: Sequence[str] | None = None,
     build_directory: str | None = None,
     embed_cubin: Mapping[str, bytes] | None = None,
+    backend: str | None = None,
 ) -> str:
     """Compile and build a C++/CUDA module from inline source code.
 
@@ -648,6 +762,10 @@ def build_inline(
         the macro `TVM_FFI_EMBED_CUBIN_GET_KERNEL(name, kernel_name)` defined 
in the `tvm/ffi/extra/cuda/cubin_launcher.h` header.
         See the `examples/cubin_launcher` directory for examples how to use 
cubin launcher to launch CUBIN kernels in TVM-FFI.
 
+    backend
+        The GPU backend to use. It can be "cuda" or "hip".
+        If not specified, the backend will be automatically determined based 
on the available GPU and the provided source code.
+
     Returns
     -------
     lib_path: str
@@ -702,7 +820,7 @@ def build_inline(
 
     cuda_source_list = _str_seq2list(cuda_sources)
     cuda_source = "\n".join(cuda_source_list)
-    with_cuda = bool(cuda_source_list)
+    with_backend = bool(cuda_source_list)
     del cuda_source_list
 
     extra_ldflags_list = list(extra_ldflags) if extra_ldflags is not None else 
[]
@@ -753,13 +871,13 @@ def build_inline(
     with FileLock(str(build_dir / "lock")):
         # write source files if they do not already exist
         _maybe_write(cpp_file, cpp_source)
-        if with_cuda:
+        if with_backend:
             _maybe_write(cuda_file, cuda_source)
 
         return _build_impl(
             name=name,
             cpp_files=[cpp_file] if with_cpp else [],
-            cuda_files=[cuda_file] if with_cuda else [],
+            cuda_files=[cuda_file] if with_backend else [],
             extra_cflags=extra_cflags_list,
             extra_cuda_cflags=extra_cuda_cflags_list,
             extra_ldflags=extra_ldflags_list,
@@ -767,6 +885,7 @@ def build_inline(
             build_directory=str(build_dir),
             need_lock=False,  # already hold the lock
             embed_cubin=embed_cubin,
+            backend=backend,
         )
 
 
@@ -783,6 +902,7 @@ def load_inline(  # noqa: PLR0913
     build_directory: str | None = None,
     embed_cubin: Mapping[str, bytes] | None = None,
     keep_module_alive: bool = True,
+    backend: str | None = None,
 ) -> Module:
     """Compile, build and load a C++/CUDA module from inline source code.
 
@@ -859,6 +979,10 @@ def load_inline(  # noqa: PLR0913
         Whether to keep the module alive. If True, the module will be kept 
alive
         for the duration of the program until libtvm_ffi.so is unloaded.
 
+    backend
+        The GPU backend to use. It can be "cuda" or "hip".
+        If not specified, the backend will be automatically determined based 
on the available GPU and the provided source code.
+
     Returns
     -------
     mod: Module
@@ -919,6 +1043,7 @@ def load_inline(  # noqa: PLR0913
             extra_include_paths=extra_include_paths,
             build_directory=build_directory,
             embed_cubin=embed_cubin,
+            backend=backend,
         ),
         keep_module_alive=keep_module_alive,
     )
@@ -934,6 +1059,7 @@ def build(
     extra_ldflags: Sequence[str] | None = None,
     extra_include_paths: Sequence[str] | None = None,
     build_directory: str | None = None,
+    backend: str | None = None,
 ) -> str:
     """Compile and build a C++/CUDA module from source files.
 
@@ -997,6 +1123,10 @@ def build(
         cache directory is ``~/.cache/tvm-ffi``. You can also set the 
``TVM_FFI_CACHE_DIR`` environment variable to
         specify the cache directory.
 
+    backend
+        The GPU backend to use. It can be "cuda" or "hip".
+        If not specified, the backend will be automatically determined based 
on the available GPU and the provided source code.
+
     Returns
     -------
     lib_path: str
@@ -1060,6 +1190,7 @@ def build(
         extra_include_paths=extra_include_paths,
         build_directory=build_directory,
         need_lock=True,
+        backend=backend,
     )
 
 
@@ -1074,6 +1205,7 @@ def load(
     extra_include_paths: Sequence[str] | None = None,
     build_directory: str | None = None,
     keep_module_alive: bool = True,
+    backend: str | None = None,
 ) -> Module:
     """Compile, build and load a C++/CUDA module from source files.
 
@@ -1141,6 +1273,10 @@ def load(
         Whether to keep the module alive. If True, the module will be kept 
alive
         for the duration of the program until libtvm_ffi.so is unloaded.
 
+    backend
+        The GPU backend to use. It can be "cuda" or "hip".
+        If not specified, the backend will be automatically determined based 
on the available GPU and the provided source code.
+
     Returns
     -------
     mod: Module
@@ -1205,6 +1341,7 @@ def load(
             extra_ldflags=extra_ldflags,
             extra_include_paths=extra_include_paths,
             build_directory=build_directory,
+            backend=backend,
         ),
         keep_module_alive=keep_module_alive,
     )

Reply via email to