This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 7a355c7 [TORCH] Remove ninja dep on non-windows when JIT optional torch-c-dlpack (#272)
7a355c7 is described below
commit 7a355c77b7e74a9ac5a1269c3c55497e88fb76a4
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Nov 17 10:16:52 2025 -0500
[TORCH] Remove ninja dep on non-windows when JIT optional torch-c-dlpack (#272)
---
python/tvm_ffi/_optional_torch_c_dlpack.py | 25 +++-
.../utils/_build_optional_torch_c_dlpack.py | 144 ++++++++++++++-------
2 files changed, 114 insertions(+), 55 deletions(-)
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py
b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 7f259d1..949d79f 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -33,6 +33,7 @@ subsequent calls will be much faster.
from __future__ import annotations
import ctypes
+import logging
import os
import subprocess
import sys
@@ -40,8 +41,10 @@ import warnings
from pathlib import Path
from typing import Any
+logger = logging.getLogger(__name__) # type: ignore
-def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912
+
+def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912, PLR0915
try:
import torch # noqa: PLC0415
@@ -82,6 +85,7 @@ def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912
libname =
f"libtorch_c_dlpack_addon_torch{major}{minor}-{device}{suffix}"
lib_path = addon_output_dir / libname
if not lib_path.exists():
+ logger.info("JIT-compiling torch-c-dlpack-ext to cache...")
build_script_path = (
Path(__file__).parent / "utils" /
"_build_optional_torch_c_dlpack.py"
)
@@ -97,11 +101,18 @@ def load_torch_c_dlpack_extension() -> Any: # noqa:
PLR0912
args.append("--build-with-cuda")
elif device == "rocm":
args.append("--build-with-rocm")
- subprocess.run(
- args,
- check=True,
- )
- assert lib_path.exists(), "Failed to build torch c dlpack addon."
+
+ # use capture_output to reduce noise when building the torch c
dlpack addon
+ result = subprocess.run(args, check=False, capture_output=True)
+ if result.returncode != 0:
+ msg = [f"Build failed with status {result.returncode}"]
+ if result.stdout:
+ msg.append(f"stdout:\n{result.stdout.decode('utf-8')}")
+ if result.stderr:
+ msg.append(f"stderr:\n{result.stderr.decode('utf-8')}")
+ raise RuntimeError("\n".join(msg))
+ if not lib_path.exists():
+ raise RuntimeError("Failed to build torch c dlpack addon.")
lib = ctypes.CDLL(str(lib_path))
func = lib.TorchDLPackExchangeAPIPtr
@@ -117,7 +128,7 @@ def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912
except Exception:
warnings.warn(
"Failed to JIT torch c dlpack extension, EnvTensorAllocator will
not be enabled.\n"
- "You may try AOT-module via `pip install torch-c-dlpack-ext`"
+ "We recommend installing via `pip install torch-c-dlpack-ext`"
)
return None
diff --git a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
index 294e340..984a632 100644
--- a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
@@ -21,6 +21,7 @@ from __future__ import annotations
import argparse
import os
import shutil
+import subprocess
import sys
import sysconfig
import tempfile
@@ -587,7 +588,7 @@ def parse_env_flags(env_var_name: str) -> list[str]:
return []
-def _generate_ninja_build(
+def _run_build_on_linux_like(
build_dir: Path,
libname: str,
source_path: Path,
@@ -595,28 +596,16 @@ def _generate_ninja_build(
extra_ldflags: Sequence[str],
extra_include_paths: Sequence[str],
) -> None:
- """Generate the content of build.ninja for building the module."""
+ """Build the module directly by invoking compiler commands (non-Windows
only)."""
from tvm_ffi.libinfo import find_dlpack_include_path # noqa: PLC0415
- if IS_WINDOWS:
- default_cflags = [
- "/std:c++17",
- "/MD",
- "/wd4819",
- "/wd4251",
- "/wd4244",
- "/wd4267",
- "/wd4275",
- "/wd4018",
- "/wd4190",
- "/wd4624",
- "/wd4067",
- "/wd4068",
- "/EHsc",
- ]
- default_ldflags = ["/DLL"]
+ default_cflags = ["-std=c++17", "-fPIC", "-O3"]
+ # Platform-specific linker flags
+ if IS_DARWIN:
+ # macOS doesn't support --no-as-needed and uses @loader_path instead
of $ORIGIN
+ default_ldflags = ["-shared", "-Wl,-rpath,@loader_path"]
else:
- default_cflags = ["-std=c++17", "-fPIC", "-O3"]
+ # Linux uses $ORIGIN and supports --no-as-needed
default_ldflags = ["-shared", "-Wl,-rpath,$ORIGIN",
"-Wl,--no-as-needed"]
cflags = default_cflags + [flag.strip() for flag in extra_cflags]
@@ -625,43 +614,91 @@ def _generate_ninja_build(
str(Path(path).resolve()) for path in extra_include_paths
]
+ # append include paths
+ for path in include_paths:
+ cflags.extend(["-I", str(path)])
+
+ # Get compiler and build paths
+ cxx = os.environ.get("CXX", "c++")
+ source_path_resolved = source_path.resolve()
+ lib_path = build_dir / libname
+
+ # Build command: compile and link in one step
+ build_cmd = [cxx, *cflags, str(source_path_resolved), *ldflags, "-o",
str(lib_path)]
+
+ # Run build command
+ status = subprocess.run(build_cmd, cwd=str(build_dir),
capture_output=True, check=False)
+ if status.returncode != 0:
+ msg = [f"Build failed with status {status.returncode}"]
+ if status.stdout:
+ msg.append(f"stdout:\n{status.stdout.decode('utf-8')}")
+ if status.stderr:
+ msg.append(f"stderr:\n{status.stderr.decode('utf-8')}")
+ raise RuntimeError("\n".join(msg))
+
+
+def _generate_ninja_build_windows(
+ build_dir: Path,
+ libname: str,
+ source_path: Path,
+ extra_cflags: Sequence[str],
+ extra_ldflags: Sequence[str],
+ extra_include_paths: Sequence[str],
+) -> None:
+ """Generate the content of build.ninja for building the module on
Windows."""
+ from tvm_ffi.libinfo import find_dlpack_include_path # noqa: PLC0415
+
+ default_cflags = [
+ "/std:c++17",
+ "/MD",
+ "/wd4819",
+ "/wd4251",
+ "/wd4244",
+ "/wd4267",
+ "/wd4275",
+ "/wd4018",
+ "/wd4190",
+ "/wd4624",
+ "/wd4067",
+ "/wd4068",
+ "/EHsc",
+ ]
+ default_ldflags = ["/DLL"]
+
+ cflags = default_cflags + [flag.strip() for flag in extra_cflags]
+ ldflags = default_ldflags + [flag.strip() for flag in extra_ldflags]
+ include_paths = [find_dlpack_include_path()] + [
+ str(Path(path).resolve()) for path in extra_include_paths
+ ]
+
# append include paths
for path in include_paths:
path_str = str(path)
if " " in path_str:
path_str = f'"{path_str}"'
- if IS_WINDOWS:
- path_str = path_str.replace(":", "$:")
+ path_str = path_str.replace(":", "$:")
cflags.append(f"-I{path_str}")
# flags
ninja = []
ninja.append("ninja_required_version = 1.3")
- ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS
else "c++")))
+ ninja.append("cxx = {}".format(os.environ.get("CXX", "cl")))
ninja.append("cflags = {}".format(" ".join(cflags)))
ninja.append("ldflags = {}".format(" ".join(ldflags)))
# rules
ninja.append("")
ninja.append("rule compile")
- if IS_WINDOWS:
- ninja.append(" command = $cxx /showIncludes $cflags -c $in /Fo$out")
- ninja.append(" deps = msvc")
- else:
- ninja.append(" depfile = $out.d")
- ninja.append(" deps = gcc")
- ninja.append(" command = $cxx -MMD -MF $out.d $cflags -c $in -o $out")
+ ninja.append(" command = $cxx /showIncludes $cflags -c $in /Fo$out")
+ ninja.append(" deps = msvc")
ninja.append("")
ninja.append("rule link")
- if IS_WINDOWS:
- ninja.append(" command = $cxx $in /link $ldflags /out:$out")
- else:
- ninja.append(" command = $cxx $in $ldflags -o $out")
+ ninja.append(" command = $cxx $in /link $ldflags /out:$out")
ninja.append("")
# build targets
- obj_name = "main.obj" if IS_WINDOWS else "main.o"
+ obj_name = "main.obj"
ninja.append(
"build {}: compile {}".format(obj_name,
str(source_path.resolve()).replace(":", "$:"))
)
@@ -692,7 +729,6 @@ def main() -> None: # noqa: PLR0912, PLR0915
"""Build the torch c dlpack extension."""
# we need to set the following env to avoid tvm_ffi to build the torch
c-dlpack addon during importing
os.environ["TVM_FFI_DISABLE_TORCH_C_DLPACK"] = "1"
- from tvm_ffi.cpp.extension import build_ninja # noqa: PLC0415
from tvm_ffi.utils.lockfile import FileLock # noqa: PLC0415
parser = argparse.ArgumentParser(
@@ -794,7 +830,7 @@ def main() -> None: # noqa: PLR0912, PLR0915
if IS_WINDOWS:
ldflags.append(f"/LIBPATH:{lib_dir}")
else:
- ldflags.append(f"-L{lib_dir}")
+ ldflags.extend(["-L", str(lib_dir)])
# Add all required PyTorch libraries
if IS_WINDOWS:
@@ -842,18 +878,30 @@ def main() -> None: # noqa: PLR0912, PLR0915
if env_cflags:
cflags.extend(env_cflags)
- # generate ninja build file
- _generate_ninja_build(
- build_dir=build_dir,
- libname=tmp_libname,
- source_path=source_path,
- extra_cflags=cflags,
- extra_ldflags=ldflags,
- extra_include_paths=include_paths,
- )
-
# build the shared library
- build_ninja(build_dir=str(build_dir))
+ if IS_WINDOWS:
+ # Use ninja on Windows
+ _generate_ninja_build_windows(
+ build_dir=build_dir,
+ libname=tmp_libname,
+ source_path=source_path,
+ extra_cflags=cflags,
+ extra_ldflags=ldflags,
+ extra_include_paths=include_paths,
+ )
+ from tvm_ffi.cpp.extension import build_ninja # noqa: PLC0415
+
+ build_ninja(build_dir=str(build_dir))
+ else:
+ # Use direct command on non-Windows
+ _run_build_on_linux_like(
+ build_dir=build_dir,
+ libname=tmp_libname,
+ source_path=source_path,
+ extra_cflags=cflags,
+ extra_ldflags=ldflags,
+ extra_include_paths=include_paths,
+ )
# rename the tmp file to final libname
shutil.move(str(build_dir / tmp_libname), str(output_dir / libname))